Add lua argument checker

This commit is contained in:
Thomas Goyne 2014-07-21 17:11:55 -07:00
parent 74a215f642
commit e3c60514cd
6 changed files with 191 additions and 57 deletions

View file

@ -0,0 +1,78 @@
-- Copyright (c) 2014, Thomas Goyne <plorkyeran@aegisub.org>
--
-- Permission to use, copy, modify, and distribute this software for any
-- purpose with or without fee is hereby granted, provided that the above
-- copyright notice and this permission notice appear in all copies.
--
-- THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
-- WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
-- MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
-- ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
-- WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
-- ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
-- OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
--
-- Aegisub Project http://www.aegisub.org/
assert = assert
error = error
select = select
tostring = tostring
type = type
is_type = (v, ty, expected) ->
ty == expected or (ty == 'table' and v.__class and v.__class.__name == expected)
(argfmt) ->
assert type(argfmt) == 'string'
min_args = 0
max_args = 0
checks = {}
for arg in argfmt\gmatch '[^ ]+'
if arg == '...'
max_args = nil
break
max_args += 1
optional = arg\sub(1, 1) == '?'
if optional
arg = arg\sub 2
else
min_args += 1
if arg\find '|'
types = [ty for ty in arg\gmatch '[^|]+']
checks[max_args] = (i, v) ->
if v == nil
return if optional
error "Argument ##{i} should be a #{arg}, is nil", 4
ty = type v
for argtype in *types
return if is_type v, ty, argtype
error "Argument ##{i} should be a #{arg}, is #{ty} (#{v})", 3
else
checks[max_args] = (i, v) ->
if v == nil
return if optional
error "Argument ##{i} should be a #{arg}, is nil", 4
ty = type v
return if is_type v, ty, arg
error "Argument ##{i} should be a #{arg}, is #{ty} (#{v})", 3
(fn) -> (...) ->
arg_count = select '#', ...
if arg_count < min_args or (max_args and arg_count > max_args)
if min_args == max_args
error "Expected #{min_args} arguments, got #{arg_count}", 3
else if max_args
error "Expected #{min_args}-#{max_args} arguments, got #{arg_count}", 3
else
error "Expected at least #{min_args} arguments, got #{arg_count}", 3
for i=1,arg_count
if not checks[i] then break
checks[i] i, select i, ...
fn ...

View file

@ -15,6 +15,8 @@
-- Aegisub Project http://www.aegisub.org/ -- Aegisub Project http://www.aegisub.org/
impl = require 'aegisub.__lfs_impl' impl = require 'aegisub.__lfs_impl'
check = require 'aegisub.argcheck'
ffi = require 'ffi' ffi = require 'ffi'
ffi_util = require 'aegisub.ffi' ffi_util = require 'aegisub.ffi'
@ -29,7 +31,7 @@ number_ret = (f) -> (...) ->
res, err = f ... res, err = f ...
tonumber(res), err tonumber(res), err
attributes = (path, field) -> attributes = check'string ?string' (path, field) ->
switch field switch field
when 'mode' when 'mode'
res, err = impl.get_mode path res, err = impl.get_mode path
@ -62,7 +64,7 @@ class dir_iter
if err then error err, 2 if err then error err, 2
ffi_util.string str ffi_util.string str
dir = (path) -> dir = check'string' (path) ->
obj, err = impl.dir_new path obj, err = impl.dir_new path
if err if err
error 2, err error 2, err
@ -71,10 +73,10 @@ dir = (path) ->
return { return {
:attributes :attributes
chdir: number_ret impl.chdir chdir: check'string' number_ret impl.chdir
currentdir: string_ret impl.currentdir currentdir: check'' string_ret impl.currentdir
:dir :dir
mkdir: number_ret impl.mkdir mkdir: check'string' number_ret impl.mkdir
rmdir: number_ret impl.rmdir rmdir: check'string'number_ret impl.rmdir
touch: number_ret impl.touch touch: check'string'number_ret impl.touch
} }

View file

@ -20,6 +20,7 @@ type = type
bit = require 'bit' bit = require 'bit'
ffi = require 'ffi' ffi = require 'ffi'
ffi_util = require 'aegisub.ffi' ffi_util = require 'aegisub.ffi'
check = require 'aegisub.argcheck'
ffi.cdef[[ ffi.cdef[[
typedef struct agi_re_flag { typedef struct agi_re_flag {
@ -91,12 +92,6 @@ unpack_args = (...) ->
return 0, ... unless flags_start return 0, ... unless flags_start
process_flags(select flags_start, ...), select_first flags_start - 1, ... process_flags(select flags_start, ...), select_first flags_start - 1, ...
-- Typecheck a variable and throw an error if it fails
check_arg = (arg, expected_type, argn, func_name, level) ->
if type(arg) != expected_type
error "Argument #{argn} to #{func_name} should be a '#{expected_type}', is '#{type(arg)}' (#{arg})",
level + 1
-- Replace a match with the value returned from func when passed the match -- Replace a match with the value returned from func when passed the match
replace_match = (match, func, str, last, acc) -> replace_match = (match, func, str, last, acc) ->
-- Copy everything between the last match and this match -- Copy everything between the last match and this match
@ -155,9 +150,7 @@ class RegEx
new: (@_regex, @_level) => new: (@_regex, @_level) =>
gsplit: (str, skip_empty, max_split) => gsplit: check'RegEx string ?boolean ?number' (str, skip_empty, max_split) =>
@_check_self!
check_arg str, 'string', 2, 'gsplit', @_level
if not max_split or max_split <= 0 then max_split = str\len() if not max_split or max_split <= 0 then max_split = str\len()
start = 0 start = 0
@ -187,15 +180,10 @@ class RegEx
do_split do_split
split: (str, skip_empty, max_split) => split: check'RegEx string ?boolean ?number' (str, skip_empty, max_split) =>
@_check_self!
check_arg str, 'string', 2, 'split', @_level
[v for v in @gsplit str, skip_empty, max_split] [v for v in @gsplit str, skip_empty, max_split]
gfind: (str) => gfind: check'RegEx string' (str) =>
@_check_self!
check_arg str, 'string', 2, 'gfind', @_level
start = 0 start = 0
-> ->
first, last = search(@_regex, str, start) first, last = search(@_regex, str, start)
@ -204,31 +192,19 @@ class RegEx
start = if last > start then last else start + 1 start = if last > start then last else start + 1
str\sub(first, last), first, last str\sub(first, last), first, last
find: (str) => find: check'RegEx string' (str) =>
@_check_self!
check_arg str, 'string', 2, 'find', @_level
ret = [str: s, first: f, last: l for s, f, l in @gfind(str)] ret = [str: s, first: f, last: l for s, f, l in @gfind(str)]
next(ret) and ret next(ret) and ret
sub: (str, repl, max_count) => sub: check'RegEx string string|function ?number' (str, repl, max_count) =>
@_check_self!
check_arg str, 'string', 2, 'sub', @_level
if max_count != nil
check_arg max_count, 'number', 4, 'sub', @_level
max_count = str\len() + 1 if not max_count or max_count == 0 max_count = str\len() + 1 if not max_count or max_count == 0
if type(repl) == 'function' if type(repl) == 'function'
do_replace_fun @, repl, str, max_count do_replace_fun @, repl, str, max_count
elseif type(repl) == 'string' elseif type(repl) == 'string'
replace @_regex, repl, str, max_count replace @_regex, repl, str, max_count
else
error "Argument 2 to sub should be a string or function, is '#{type(repl)}' (#{repl})", @_level
gmatch: (str, start) => gmatch: check'RegEx string ?number' (str, start) =>
@_check_self!
check_arg str, 'string', 2, 'gmatch', @_level
start = if start then start - 1 else 0 start = if start then start - 1 else 0
m = match @_regex, str, start m = match @_regex, str, start
@ -245,10 +221,7 @@ class RegEx
last: last + start last: last + start
} }
match: (str, start) => match: check'RegEx string ?number' (str, start) =>
@_check_self!
check_arg(str, 'string', 2, 'match', @_level)
ret = [v for v in @gmatch str, start] ret = [v for v in @gmatch str, start]
-- Return nil rather than a empty table so that if re.match(...) works -- Return nil rather than a empty table so that if re.match(...) works
return nil if next(ret) == nil return nil if next(ret) == nil
@ -271,16 +244,13 @@ invoke = (str, pattern, fn, flags, ...) ->
compiled_regex[fn](compiled_regex, str, ...) compiled_regex[fn](compiled_regex, str, ...)
-- Generate a static version of a method with arg type checking -- Generate a static version of a method with arg type checking
gen_wrapper = (impl_name) -> (str, pattern, ...) -> gen_wrapper = (impl_name) -> check'string string ...' (str, pattern, ...) ->
check_arg str, 'string', 1, impl_name, 2
check_arg pattern, 'string', 2, impl_name, 2
invoke str, pattern, impl_name, unpack_args ... invoke str, pattern, impl_name, unpack_args ...
-- And now at last the actual public API -- And now at last the actual public API
do do
re = { re = {
compile: (pattern, ...) -> compile: check'string ...' (pattern, ...) ->
check_arg pattern, 'string', 1, 'compile', 2
real_compile pattern, 2, process_flags(...), 2 real_compile pattern, 2, process_flags(...), 2
split: gen_wrapper 'split' split: gen_wrapper 'split'

View file

@ -29,11 +29,13 @@
-- http://www.ietf.org/rfc/rfc2279.txt -- http://www.ietf.org/rfc/rfc2279.txt
impl = require 'aegisub.__unicode_impl' impl = require 'aegisub.__unicode_impl'
check = require 'aegisub.argcheck'
ffi = require 'ffi' ffi = require 'ffi'
ffi_util = require 'aegisub.ffi' ffi_util = require 'aegisub.ffi'
err_buff = ffi.new 'char *[1]' err_buff = ffi.new 'char *[1]'
conv_func = (f) -> (str) -> conv_func = (f) -> check'string' (str) ->
err_buff[0] = nil err_buff[0] = nil
result = f str, err_buff result = f str, err_buff
errmsg = ffi_util.string err_buff[0] errmsg = ffi_util.string err_buff[0]
@ -44,7 +46,7 @@ conv_func = (f) -> (str) ->
local unicode local unicode
unicode = unicode =
-- Return the number of bytes occupied by the character starting at the i'th byte in s -- Return the number of bytes occupied by the character starting at the i'th byte in s
charwidth: (s, i) -> charwidth: check'string ?number' (s, i) ->
b = s\byte i or 1 b = s\byte i or 1
-- FIXME, something in karaskel results in this case, shouldn't happen -- FIXME, something in karaskel results in this case, shouldn't happen
-- What would "proper" behaviour be? Zero? Or just explode? -- What would "proper" behaviour be? Zero? Or just explode?
@ -55,7 +57,7 @@ unicode =
else 4 else 4
-- Returns an iterator function for iterating over the characters in s -- Returns an iterator function for iterating over the characters in s
chars: (s) -> chars: check'string' (s) ->
curchar, i = 0, 1 curchar, i = 0, 1
-> ->
return if i > s\len() return if i > s\len()
@ -67,13 +69,13 @@ unicode =
-- Returns the number of characters in s -- Returns the number of characters in s
-- Runs in O(s:len()) time! -- Runs in O(s:len()) time!
len: (s) -> len: check'string' (s) ->
n = 0 n = 0
n += 1 for c in unicode.chars s n += 1 for c in unicode.chars s
n n
-- Get codepoint of first char in s -- Get codepoint of first char in s
codepoint: (s) -> codepoint: check'string' (s) ->
-- Basic case, ASCII -- Basic case, ASCII
b = s\byte 1 b = s\byte 1
return b if b < 128 return b if b < 128

View file

@ -20,14 +20,16 @@ sformat = string.format
tonumber = tonumber tonumber = tonumber
type = type type = type
check = require 'aegisub.argcheck'
local * local *
-- Make a shallow copy of a table -- Make a shallow copy of a table
copy = (tbl) -> {k, v for k, v in pairs tbl} copy = check'table' (tbl) -> {k, v for k, v in pairs tbl}
-- Make a deep copy of a table -- Make a deep copy of a table
-- Retains equality of table references inside the copy and handles self-referencing structures -- Retains equality of table references inside the copy and handles self-referencing structures
deep_copy = (tbl) -> deep_copy = check'table' (tbl) ->
seen = {} seen = {}
copy = (val) -> copy = (val) ->
return val if type(val) != 'table' return val if type(val) != 'table'
@ -44,7 +46,7 @@ ass_alpha = (a) -> sformat "&H%02X&", a
ass_style_color = (r, g, b, a) -> sformat "&H%02X%02X%02X%02X", a, b, g, r ass_style_color = (r, g, b, a) -> sformat "&H%02X%02X%02X%02X", a, b, g, r
-- Extract colour components of an ASS colour -- Extract colour components of an ASS colour
extract_color = (s) -> extract_color = check'string' (s) ->
local a, b, g, r local a, b, g, r
-- Try a style first -- Try a style first
@ -68,10 +70,10 @@ extract_color = (s) ->
return tonumber(r, 16), tonumber(g, 16) or 0, tonumber(b, 16) or 0, tonumber(a, 16) or 0 return tonumber(r, 16), tonumber(g, 16) or 0, tonumber(b, 16) or 0, tonumber(a, 16) or 0
-- Create an alpha override code from a style definition colour code -- Create an alpha override code from a style definition colour code
alpha_from_style = (scolor) -> ass_alpha select 4, extract_color scolor alpha_from_style = check'string' (scolor) -> ass_alpha select 4, extract_color scolor
-- Create an colour override code from a style definition colour code -- Create an colour override code from a style definition colour code
color_from_style = (scolor) -> color_from_style = check'string' (scolor) ->
r, g, b = extract_color scolor r, g, b = extract_color scolor
ass_color r or 0, g or 0, b or 0 ass_color r or 0, g or 0, b or 0

View file

@ -0,0 +1,80 @@
-- Copyright (c) 2014, Thomas Goyne <plorkyeran@aegisub.org>
--
-- Permission to use, copy, modify, and distribute this software for any
-- purpose with or without fee is hereby granted, provided that the above
-- copyright notice and this permission notice appear in all copies.
--
-- THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
-- WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
-- MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
-- ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
-- WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
-- ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
-- OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
--
-- Aegisub Project http://www.aegisub.org/
check = require 'aegisub.argcheck'
describe 'argcheck', ->
it 'should permit simple valid calls', ->
assert.has_no.errors -> (check'string' ->) ''
assert.has_no.errors -> (check'number' ->) 10
assert.has_no.errors -> (check'boolean' ->) true
assert.has_no.errors -> (check'table' ->) {}
it 'should support multiple arguments', ->
assert.has_no.errors -> (check'string number' ->) '', 10
assert.has_no.errors -> (check'string table number' ->) '', {}, 10
it 'should support moonscript classes', ->
class Foo
assert.has_no.errors -> (check'Foo' (->) Foo)
it 'should support optional arguments', ->
assert.has_no.errors -> (check'?number' ->) nil
assert.has_no.errors -> (check'?number ?number' ->) 5
it 'should support ...', ->
assert.has_no.errors -> (check'number ...' ->) 5
assert.has_no.errors -> (check'number ...' ->) 5, 5
assert.has_no.errors -> (check'number ...' ->) 5, 5, ''
it 'should support alternates', ->
assert.has_no.errors -> (check'number|string' ->) 5
assert.has_no.errors -> (check'number|string' ->) ''
it 'should support optional alternates', ->
assert.has_no.errors -> (check'?number|string' ->) 5
assert.has_no.errors -> (check'?number|string' ->) ''
assert.has_no.errors -> (check'?number|string' ->) nil
it 'should reject simple invalid calls', ->
assert.has.errors -> (check'string' ->) 10
assert.has.errors -> (check'number' ->) ''
it 'should reject inccorect numbers of arguments', ->
assert.has.errors -> (check'string number' ->) ''
assert.has_no.errors -> (check'string ?number' ->) ''
assert.has.errors -> (check'string number' ->) '', 5, 5
it 'should reject non-optional nil arguments', ->
assert.has.errors -> (check'string number' ->) nil, nil
it 'should reject invalid matches with alternates', ->
assert.has.errors -> (check'number|string' ->) {}
it 'should report the correct error levels', ->
valid_err_loc = (fn) ->
_, err = pcall fn
err\find('tests/modules/argcheck.moon') != nil
assert.is.true valid_err_loc -> (check'number' ->) {}
assert.is.true valid_err_loc -> (check'number' ->) nil
assert.is.true valid_err_loc -> (check'number|string' ->) {}
assert.is.true valid_err_loc -> (check'number|string' ->) nil
assert.is.true valid_err_loc -> (check'number string' ->) {}
assert.is.true valid_err_loc -> (check'?number ?string' ->) 1, 2, 3
assert.is.true valid_err_loc -> (check'number string ...' ->) {}
prevent_tail_call_so_that_this_shows_up_in_backtrace = 1