Add a type system

Add a check_types method that can check lua types for correctness.
This commit is contained in:
Robin Malley 2022-02-20 00:17:36 +00:00
parent 87556f77cc
commit 138cf12028
3 changed files with 98 additions and 0 deletions

21
spec/typeing.lua Normal file
View File

@ -0,0 +1,21 @@
--Make sure the type checking works
describe("smr type checking",function()
it("should load without errors",function()
local types = require("types")
end)
it("should not error when an argument is a number",function()
local types = require("types")
local n = 5
assert(types.number(n))
end)
it("should error when an argument is a table",function()
local types = require("types")
local t = {}
assert.has.errors(function()
types.number(t)
end)
end)
end)

47
src/lua/types.lua Normal file
View File

@ -0,0 +1,47 @@
--[[
Type checking, vaguely inspired by Python3's typing module.
]]
local types = {}
function types.positive(arg)
local is_number, err = types.number(arg)
if not is_number then
return false, err
end
if arg < 0 then
return false, string.format("was not positive")
end
return true
end
--Basic lua types
local builtin_types = {
"nil","boolean","number","string","table","function","coroutine","userdata"
}
for _,type_ in pairs(builtin_types) do
types[type_] = function(arg)
local argtype = type(arg)
if not argtype == type_ then
return false, string.format("was not a %s, was a %s",type_,argtype)
end
end
end
function types.matches_pattern(pattern)
return function(arg)
local is_string, err = types.string(arg)
if not is_string then
return false, err
end
if not string.match(arg, pattern) then
return false, string.format(
"Expected %q to match pattern %q, but it did not.",
arg,
pattern
)
end
end
end
return types

View File

@ -222,6 +222,36 @@ function util.parse_tags(str)
return tags
end
if config.debugging then
function util.checktypes(...)
local args = {...}
if #args == 1 then
args = table.unpack(args)
end
assert(
#args % 3 == 0,
"Arguments to checktypes() must be triplets of " ..
"<variable>, <lua type>, <type check function> "
)
for i = 1,#args,3 do
local var, ltype, veri_f = args[i+0], args[i+1], args[i+2]
assert(
type(var) == ltype,
string.format(
"Expected argument %d (%q) to be type %s, but was %s",
i/3
)
)
if veri_f then
assert(veri_f(var))
end
end
end
else
function util.checktypes(...)
end
end
local function decodeentity(capture)
return string.char(tonumber(capture,16)) --Decode base 16 and conver to character
end