256 lines
6.5 KiB
Lua
256 lines
6.5 KiB
Lua
|
|
local config = require("config")
|
|
-- Point the loaded config at the unit-test database file.
config.db = "data/unittest.db"

-- `mock` is the table this file returns; `env` starts out empty and is
-- exposed to callers as mock.env.
local env = {}
local mock = {env = env}

--Mirror print prior to lua 5.4
--local oldprint = print
-- Forward declaration: print_table (below) captures this local before the
-- function body is assigned further down the file.
local ntostring

-- Modules that get required lazily
-- (login_post, fuzzy and claim_post are assigned inside mock.session;
-- nothing visible in this file assigns `session`.)
local login_post, fuzzy, claim_post, session
|
|
--- Debug helper: print all arguments on a single tab-separated line,
-- converting each one with ntostring so tables are expanded.
-- Assigned as a global (there is no `local`), so it is reachable from any
-- file once this module has been loaded.
-- Also prints "Print called" once and "mapping <value>" for every argument.
-- @param ... any values, nil included
print_table = function(...)
	print("Print called")
	-- select("#", ...) counts embedded and trailing nils; iterating with
	-- ipairs({...}) would stop at the first nil and drop everything after it.
	local count = select("#", ...)
	local mapped_args = {}
	for i = 1, count do
		-- Parentheses truncate select's results to just the i-th argument.
		local v = (select(i, ...))
		print("mapping", v)
		mapped_args[i] = ntostring(v)
	end
	print(table.concat(mapped_args, "\t"))
end
|
|
|
|
--- Convert any value to a printable string.
-- Non-table values are returned as tostring(value). Tables are expanded
-- recursively: an opening "{" line, one `"key":"value"` line per entry
-- (both sides formatted with %q, nested tables rendered first and then
-- quoted), and a closing "}" line, indented with one tab per nesting level.
-- A table that has already been expanded during the same call is rendered
-- as tostring(table) instead, which also keeps cyclic tables from
-- recursing forever.
-- Every call also writes a "Calling tostring with:" trace line to stdout.
-- Entry order follows pairs() and is therefore unspecified.
-- @param arg any value
-- @return string
function ntostring(arg)
	io.stdout:write("Calling tostring with:",tostring(arg),"\n")
	if type(arg) ~= "table" then
		return tostring(arg)
	end

	-- Tables expanded so far during THIS call. Keeping the set local (rather
	-- than in a file-level table that is reset after the walk) means an error
	-- raised halfway through cannot leave stale entries behind that would
	-- make later calls print "table: 0x..." instead of the contents.
	local seen = {}

	local function tbl_to_string(tbl, indent)
		-- Validate before the value is used as a key in `seen`.
		if type(tbl) ~= "table" then
			error("tbl_to_string must be called with a table, got a " .. type(tbl))
		end
		if seen[tbl] then
			return tostring(tbl)
		end
		seen[tbl] = true

		-- Render one key or value as a %q-quoted string.
		local function quoted(item)
			if type(item) == "table" then
				return string.format("%q", tbl_to_string(item, indent + 1))
			end
			return string.format("%q", tostring(item))
		end

		local lines = {string.rep("\t", indent) .. "{"}
		for k, v in pairs(tbl) do
			-- Key first, then value, as two explicit statements so the
			-- order in which nested tables get marked as seen is fixed.
			local key_text = quoted(k)
			local value_text = quoted(v)
			lines[#lines + 1] = string.rep("\t", indent + 1) .. key_text .. ":" .. value_text
		end
		lines[#lines + 1] = string.rep("\t", indent) .. "}"
		return table.concat(lines, "\n")
	end

	--It's a table
	return tbl_to_string(arg, 0)
end
|
|
|
|
--- Mock implementations of the globals the application code expects to find
-- in its environment. The table is reached through smr_mock_env_m.__index
-- once mock.setup() has installed that metatable on _G.
-- Every function is wrapped in spy.new so tests can inspect how it was
-- called (`spy` is not defined in this file; presumably luassert's spy as
-- provided by busted - confirm).
-- Requests are plain tables; the mocks read and write fields on them.
local smr_mock_env = {
	--An empty function that gets called to set up databases and do other
	--startup-time stuff, runs once for each worker process.
	configure = spy.new(function(...) end),

	-- Request accessors, with fixed fallbacks when the field is missing.
	http_request_get_host = spy.new(function(req) return req.host or "test.host" end),
	http_request_get_path = spy.new(function(req) return req.path or "/" end),

	-- The populate mocks only record that they ran; the getters below assert
	-- on these flags so tests catch a missing populate call.
	http_request_populate_qs = spy.new(function(req) req.qs_populated = true end),
	http_request_populate_post = spy.new(function(req) req.post_populated = true end),
	http_populate_multipart_form = spy.new(function(req)
		req.post_populated = true
		req.multipart_form_populated = true
	end),

	-- Returns req.args[str]; requires a .args table and the populate call
	-- matching the request method.
	http_argument_get_string = spy.new(function(req,str)
		assert(req.args,"requests should have a .args table")
		assert(
			req.method == "GET" and req.qs_populated or
			req.method == "POST" and req.post_populated,[[
http_argument_get_string() can only be called after
the appropriate populate method has been called, either
http_request_populate_qs(req) or
http_request_populate_post(req)]]
		)
		return req.args[str]
	end),

	-- Returns the uploaded file stored under `filename` in req.file.
	-- Previously the name was ignored and req.file["pass"] was always
	-- returned, so a request carrying any other file could not be mocked.
	http_file_get = spy.new(function(req,filename)
		assert(req.multipart_form_populated,[[
http_file_get() can only be called after the approriate
populate method has been called. (http_populate_multipart_form())
]])
		return req.file[filename]
	end),

	-- Response mocks store what the endpoint sent back on the request table.
	http_response = spy.new(function(req,errcode,html)
		req.responsecode = errcode
		req.response = html
	end),
	http_response_header = spy.new(function(req,name,value)
		req.response_headers = req.response_headers or {}
		req.response_headers[name] = value
	end),
	http_method_text = spy.new(function(req) return req.method end),

	-- Cookies: populate must run before http_request_cookie may be used.
	http_populate_cookies = spy.new(function(req)
		req.cookies_populated = true
		req.cookies = req.cookies or {}
	end),
	http_request_cookie = spy.new(function(req,cookie_name)
		assert(req.cookies_populated,[[
http_request_cookie() can only be called after
http_populate_cookies() has been called.
]])
		return req.cookies[cookie_name]
	end),
	-- NOTE(review): this replaces the whole req.cookies table, dropping any
	-- cookies already present - confirm that is intended.
	http_response_cookie = spy.new(function(req,name,value) req.cookies = {[name] = value} end),

	-- Logging is a no-op; the print is left commented out for debugging.
	log = spy.new(function(priority, message) --[[print(string.format("[LOG %q]: %s",priority,message))]] end),
	--Logging:
	LOG_DEBUG = "debug",
	LOG_INFO = "info",
	LOG_NOTICE = "notice",
	LOG_WARNING = "warning",
	LOG_ERR = "error",
	LOG_CRIT = "critical",
	LOG_ALERT = "alert",
	LOG_EMERG = "emergency",

	-- Fixed digest so anything derived from a hash is deterministic.
	sha3 = spy.new(function(message) return "digest" end),
}
|
|
|
|
--- Metatable installed on _G by mock.setup().
-- Reads of unknown globals fall through to smr_mock_env. Creating a new
-- global is only permitted when the assignment comes from C code, from
-- ./global.lua, or targets the name "configure"; anything else raises an
-- error (reported at the assigning function) that includes a traceback.
local smr_mock_env_m = {
	__index = smr_mock_env,
	__newindex = function(self, key, value)
		-- Level 2 is the function performing the assignment.
		local origin = debug.getinfo(2).source
		local permitted = origin == "=[C]"
			or origin == "@./global.lua"
			or key == "configure"
		if permitted then
			rawset(self, key, value)
			return
		end
		local message = string.format(
			"Tried to create a global %q with value %s\n%s",
			key,
			tostring(value),
			debug.traceback()
		)
		error(message, 2)
	end
}
|
|
|
|
-- Keep a reference to the real string.format before anything overrides it.
local sfmt = string.format
-- `unpack` became `table.unpack` in Lua 5.2; accept whichever exists.
local unpack_args = table.unpack or unpack

--- Replacement for the `string` library whose format() tolerates nil
-- arguments by substituting the text "nil". Every other field falls through
-- to the real `string` table via __index.
-- Not installed by default: its entry in smr_override_env is commented out.
local string_fmt_override = {
	format = spy.new(function(fmt,...)
		-- The true argument count; #args is not defined for tables with nil
		-- holes, which is exactly the case this function exists to handle.
		local count = select("#", ...)
		local args = {...}
		for i = 1, count do
			if args[i] == nil then
				args[i] = "nil"
			end
		end
		-- Explicit bounds so every argument is forwarded regardless of
		-- what # would report.
		return sfmt(fmt, unpack_args(args, 1, count))
	end)
}
setmetatable(string_fmt_override,{__index = string})
|
|
-- Globals that mock.setup() swaps in (key = global name, value = replacement).
-- Currently empty; the two candidate overrides are kept here, disabled:
--Detour assert so we don't actually perform any checks
--assert = spy.new(function(bool,msg,level) return bool end),
--Allow string.format to accept nil as arguments
--string = string_fmt_override
local smr_override_env = {}

-- Values of the overridden globals as they were before mock.setup() replaced
-- them; mock.teardown() copies them back into _G.
mock.olds = {}
|
|
|
|
--- Activate the mock environment.
-- Installs smr_mock_env_m as the metatable of _G, then replaces every global
-- named in smr_override_env, remembering the previous value in mock.olds.
function mock.setup()
	-- setmetatable returns its first argument, i.e. _G itself.
	local globals = setmetatable(_G, smr_mock_env_m)
	local saved = mock.olds
	for name, replacement in pairs(smr_override_env) do
		saved[name] = globals[name]
		globals[name] = replacement
	end
end
|
|
|
|
--- Switch the application to a fresh database and re-initialise the db
-- module against it.
-- Sets config.db to ":memory:" (presumably SQLite's in-memory database name
-- - confirm against the db module), forces the "db" module to be loaded
-- again, then calls the global configure().
function mock.mockdb()
	local config = require("config")
	--config.db = "data/unittest.db"
	config.db = ":memory:"
	-- Only an on-disk database has a file to delete. The previous code ran
	-- `rm :memory:` through the shell and asserted on the result, which
	-- always fails on Lua 5.2+ (os.execute returns nil for a non-zero exit).
	-- os.remove also avoids building a shell command from a config value.
	if config.db ~= ":memory:" then
		-- Best effort: a missing file simply means there is nothing to clean.
		os.remove(config.db)
	end
	-- Drop the cached module so require() runs its body again.
	package.loaded.db = nil
	require("db")
	configure()
end
|
|
|
|
--- Deactivate the mock environment.
-- Replaces the metatable of _G with an empty one and copies every value
-- saved in mock.olds back into _G.
function mock.teardown()
	-- setmetatable returns its first argument, i.e. _G itself.
	local globals = setmetatable(_G, {})
	for name, original in pairs(mock.olds) do
		globals[name] = original
	end
end
|
|
|
|
--- Metatable for session objects; its __index table provides login, logout
-- and req. The methods are unfinished: login and logout both end by raising
-- "TODO" and req has an empty body.
local session_m = {}
do
	local methods = {}

	-- Store the credentials in self.args and run the login_post endpoint on
	-- self. Raises when self has no .args table. Always raises "TODO" once
	-- login_post returns.
	function methods:login(who, pass)
		local args = self.args
		if not args then
			error("Request should have a .args table")
		end
		print("Right before requireing login_post endpoint, self.args is " .. tostring(args))
		print("After requireing login_post edpoint, self.args is " .. tostring(args))
		args.user = who
		args.pass = pass
		login_post(self)
		error("TODO")
	end

	-- Not implemented yet.
	function methods:logout()
		error("TODO")
	end

	-- Placeholder: accepts args and does nothing.
	function methods:req(args)
	end

	session_m.__index = methods
end
|
|
|
|
--- Create a fresh account, log it in, and return a request that carries the
-- resulting session cookie.
-- On first use this requires the login_post and claim_post endpoints and the
-- spec.fuzzgen module, then calls the global configure().
-- The body stored by the claim endpoint (claim_req.response) is fed to the
-- login request as its "pass" file.
-- @param tbl currently unused; kept so existing callers keep working
-- @return req table with host = "test.host" and cookies.session set
-- @return username the generated user name
-- Raises an error when the login response carries no usable session cookie.
function mock.session(tbl)
	if login_post == nil then
		login_post = require("endpoints.login_post")
		fuzzy = require("spec.fuzzgen")
		claim_post = require("endpoints.claim_post")
		configure()
	end
	local username = fuzzy.subdomain()
	local claim_req = {
		method = "POST",
		host = "test.host",
		path = "/_claim",
		args = {
			user = username
		}
	}
	claim_post(claim_req)
	local login_req = {
		method = "POST",
		host = "test.host",
		path = "/_login",
		args = {
			user = username
		},
		file = {
			pass = claim_req.response
		}
	}
	login_post(login_req)
	-- Fail with a message that names the actual problem. Previously a failed
	-- login surfaced as "attempt to index a nil value", and a cookie without
	-- a session id silently produced a request with no session at all.
	local headers = login_req.response_headers
	local cookie = headers and headers["set-cookie"]
	if cookie == nil then
		error("mock.session: login set no set-cookie header (response code "
			.. tostring(login_req.responsecode) .. ")")
	end
	local sessionid = cookie:match("session=([^;]+)")
	if sessionid == nil then
		error("mock.session: no session id found in cookie " .. tostring(cookie))
	end
	local req = {
		host = "test.host",
		cookies = {
			session = sessionid
		}
	}
	return req, username
end
|
|
|
|
-- Module result: the helper table carrying env, olds, setup, mockdb,
-- teardown and session.
return mock
|