smr/src/lua/util.lua

179 lines
4.5 KiB
Lua

local sql = require("lsqlite3")
local util = {}
--[[
Checks the result of an sql call: receives the (result, error code, error
message) values it returned, raises an error with a readable message when
the result is falsy, and returns the result otherwise.
]]
function util.sqlassert(...)
	-- Takes the (result, errcode, errmsg) values produced by an sql call.
	-- Returns the first value unchanged when it is truthy; otherwise raises
	-- an error whose message has the form "<errcode>: <errmsg>".
	local r, errcode, err = ...
	if not r then
		-- tostring() keeps the real failure visible even when the code or
		-- the message is missing; formatting a nil with "%d" (or with "%s"
		-- on Lua 5.1) would raise an unrelated "bad argument" error instead.
		error(string.format("%s: %s", tostring(errcode), tostring(err)))
	end
	return r
end
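--[[
Steps an sql statement, yielding the current coroutine and retrying while
the database reports busy. Raises "Database busy" after roughly 10 busy
attempts; otherwise returns the result code of the final step.
]]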
function util.do_sql(stmnt)
	-- Steps `stmnt` once, retrying while the step reports sql.BUSY.
	-- Between retries the running coroutine is yielded, so this has to be
	-- called from inside a coroutine whenever the database may be busy.
	--   stmnt: prepared statement exposing a :step() method
	-- Returns the result code of the last step.
	-- Raises "No statement" (blamed on the caller) when stmnt is missing,
	-- and "Database busy" when the statement is still busy after all retries.
	if not stmnt then error("No statement",2) end
	local max_busy_retries = 10
	local busy_count = 0
	local err
	repeat
		err = stmnt:step()
		if err == sql.BUSY then
			busy_count = busy_count + 1
			coroutine.yield()
		end
	until err ~= sql.BUSY or busy_count > max_busy_retries
	-- Decide on the actual outcome of the last step rather than on the retry
	-- counter: a step that finally succeeded on the last permitted attempt
	-- has already run and must not be reported as a failure.
	assert(err ~= sql.BUSY, "Database busy")
	return err
end
--[[
Provides an iterator that loops over results in an sql statement
or throws an error, then resets the statement after the loop is done.
]]
function util.sql_rows(stmnt)
	-- Returns an iterator for use in a generic `for` loop. Each call steps
	-- `stmnt` and returns the column values of the next row.
	--   stmnt: prepared statement exposing :step(), :get_values(), :reset()
	-- The statement is reset when the rows are exhausted or when a step fails.
	-- Raises "No statement" (blamed on the caller) when stmnt is missing, and
	-- a logged "SQL Iteration failed" error for any unexpected step result.
	-- NOTE(review): `db`, `log` and `LOG_CRIT` are not defined in this file;
	-- they are presumably globals provided elsewhere - confirm.
	if not stmnt then error("No statement",2) end
	-- `unpack` moved to `table.unpack` in Lua 5.2+; accept either.
	local unpack_values = unpack or table.unpack
	local max_busy_retries = 10
	return function()
		local err = stmnt:step()
		local busy_count = 0
		-- A busy database is not the end of the result set: yield and step
		-- again. Returning nothing here would silently end the caller's loop
		-- with rows missing and the statement left un-reset.
		while err == sql.BUSY and busy_count < max_busy_retries do
			busy_count = busy_count + 1
			coroutine.yield()
			err = stmnt:step()
		end
		if err == sql.ROW then
			return unpack_values(stmnt:get_values())
		elseif err == sql.DONE then
			stmnt:reset()
			return nil
		end
		-- Anything else (including still being busy after all retries) is
		-- treated as a failure.
		stmnt:reset()
		local msg = string.format(
			"SQL Iteration failed: %s : %s\n%s",
			tostring(err),
			db.conn:errmsg(),
			debug.traceback()
		)
		log(LOG_CRIT,msg)
		error(msg)
	end
end
--[[
Binds an argument to a statement with nice error reporting on failure
stmnt :: sql.stmnt - the prepared sql statement
call :: string - a string "bind" or "bind_blob"
position :: number - the argument position to bind to
data :: string - The data to bind
]]
function util.sqlbind(stmnt,call,position,data)
	-- Calls stmnt:bind(position, data) or stmnt:bind_blob(position, data),
	-- selected by `call`, and raises a descriptive error (blamed on the
	-- caller) when the result is not sql.OK.
	--   stmnt:    prepared statement
	--   call:     "bind" or "bind_blob"
	--   position: argument position to bind to
	--   data:     value to bind
	-- NOTE(review): `db` is not defined in this file; presumably a global
	-- provided elsewhere - confirm.
	if call ~= "bind" and call ~= "bind_blob" then
		-- tostring() so that a nil `call` produces this message instead of
		-- a "attempt to concatenate a nil value" error.
		error("Bad bind call, call was:" .. tostring(call), 2)
	end
	local f = stmnt[call](stmnt,position,data)
	if f ~= sql.OK then
		-- Every argument is stringified first so that reporting the failure
		-- cannot itself fail on a nil or non-string value.
		error(string.format(
			"Failed to %s at %s with %q: %s",
			call,
			tostring(position),
			tostring(data),
			tostring(db.conn:errmsg())
		),2)
	end
end
--see https://perishablepress.com/stop-using-unsafe-characters-in-urls/
--no underscore because we use that for our operative pages
-- Alphabet used for short ids: lowercase letters, uppercase letters,
-- digits, then a handful of punctuation characters.
local url_characters = table.concat({
	[[abcdefghijklmnopqrstuvwxyz]],
	[[ABCDEFGHIJKLMNOPQRSTUVWXYZ]],
	[[0123456789]],
	[[$-+!*'(),]],
})
-- Reverse lookup: character -> its 1-based position in url_characters.
local url_characters_rev = {}
do
	local position = 0
	for character in url_characters:gmatch(".") do
		position = position + 1
		url_characters_rev[character] = position
	end
end
--[[
Encode a number to a shorter HTML-safe url path
]]
function util.encode_id(number)
	-- Writes `number` in base #url_characters, least significant digit
	-- first, using url_characters as the digit alphabet.
	--   number: a non-negative, finite number (numeric strings are accepted)
	-- Returns the encoded string; 0 encodes to the first alphabet character.
	-- Raises (blamed on the caller) when the value cannot be encoded.
	local n = tonumber(number)
	-- Negative, NaN and infinite values never reach 0 under repeated floor
	-- division, so the loop below would spin forever while growing `result`.
	if not (n and n >= 0 and n < math.huge) then
		error("Cannot encode id: " .. tostring(number), 2)
	end
	local result = {}
	local charlen = string.len(url_characters)
	repeat
		local pos = (n % charlen) + 1
		n = math.floor(n / charlen)
		result[#result + 1] = string.sub(url_characters,pos,pos)
	until n == 0
	return table.concat(result)
end
--[[
Given a short HTML-safe url path, convert it to a storyid
]]
function util.decode_id(s)
	-- Reads `s` as a number written in base #url_characters, least
	-- significant digit first, using url_characters_rev to map each
	-- character back to its digit value.
	--   s: the encoded id (a number is treated as its string form)
	-- Returns the decoded number; the empty string decodes to 0.
	-- Raises "Failed to decode id:<s>" when s is not a string/number or
	-- contains a character outside the alphabet.
	local kind = type(s)
	if kind ~= "string" and kind ~= "number" then
		-- tostring() so that reporting a nil/table argument cannot itself
		-- fail with a concatenation error.
		error("Failed to decode id:" .. tostring(s))
	end
	local charlen = string.len(url_characters)
	local n = 0
	-- Weight of the current digit: charlen^(i-1), built up by repeated
	-- multiplication so math.pow (absent from newer Lua versions) is not needed.
	local place = 1
	for i = 1,string.len(s) do
		local pos = url_characters_rev[string.sub(s,i,i)]
		if not pos then
			error("Failed to decode id:" .. tostring(s))
		end
		n = n + (pos - 1) * place
		place = place * charlen
	end
	return n
end
--arbitrary data to hex encoded string
function util.encode_unlisted(str)
	-- Turns every byte of `str` into two lowercase hex digits and returns
	-- the digits joined into one string (twice as long as the input).
	local hex_pairs = {}
	for index = 1, #str do
		hex_pairs[index] = ("%02x"):format(str:byte(index))
	end
	return table.concat(hex_pairs)
end
--hex encoded string to arbitrary data
function util.decode_unlisted(str)
	-- Converts each pair of hex digits found in `str` (either case) back
	-- into the byte it encodes and returns the bytes joined into a string.
	-- Anything that is not part of a hex-digit pair, such as a trailing odd
	-- digit, is skipped rather than reported.
	-- (The leftover debug print of the input was removed: it wrote to stdout
	-- on every call.)
	local output = {}
	for byte in str:gmatch("%x%x") do
		output[#output + 1] = string.char(tonumber(byte,16))
	end
	return table.concat(output)
end
--[[
Parses a semicolon separated string into its parts:
1. separates by semicolon
2. trims whitespace
3. lowercases
4. capitalizes the first letter.
Returns an array of zero or more strings.
There is no blank tag, parsing "one;two;;three" will yield
{"One","Two","Three"}
]]
function util.parse_tags(str)
	-- Splits `str` on semicolons and normalizes each piece: surrounding
	-- whitespace is removed, the text is lowercased, and its first character
	-- is uppercased. Pieces that end up empty are dropped.
	--   str: the semicolon-separated tag string
	-- Returns an array (possibly empty) of the normalized tags, in order.
	local tags = {}
	for tag in string.gmatch(str,"([^;]+)") do
		-- Anchored pattern with a lazy capture: with an unanchored greedy
		-- "(.*)" the capture swallows trailing whitespace, so it was never
		-- trimmed. This pattern matches every string, so the result is never nil.
		local tag_trimmed = string.match(tag,"^%s*(.-)%s*$")
		local tag_lower = string.lower(tag_trimmed)
		-- gsub also returns a replacement count; only the string is kept.
		local tag_capitalized = string.gsub(tag_lower,"^.",string.upper)
		if string.len(tag_capitalized) > 0 then
			tags[#tags + 1] = tag_capitalized
		end
	end
	return tags
end
local function decodeentity(capture)
	-- Read the captured text as a base-16 number, then return the
	-- single-character string with that byte value.
	local byte_value = tonumber(capture, 16)
	return string.char(byte_value)
end
function util.decodeentities(str)
	-- Replaces every "%" followed by two hex digits in `str` with the
	-- character decodeentity produces for those digits.
	-- Returns two values, exactly as string.gsub does: the resulting string
	-- and the number of replacements made.
	local decoded, replacements = string.gsub(str, "%%(%x%x)", decodeentity)
	return decoded, replacements
end
return util