smr/src/lua/util.lua

243 lines
6.1 KiB
Lua

--[[ md
@name lua/util
Various utilities that aren't big enough for their own module, but are still
used in more than one place.
]]
local config = require("config")
local db = require("db")
local queries = require("queries")
local util = {}
-- Prepared statement used by util.get_comments(). It is prepared inside the
-- configure() hook rather than at module load time (presumably because
-- db.conn is only usable once configuration runs - confirm).
local stmnt_comments
-- Chain onto whatever global configure() existed before this module loaded:
-- prepare our statement first, then delegate to the previous hook.
local previous_configure = configure
configure = function(...)
	stmnt_comments = assert(db.conn:prepare(queries.select_comments))
	return previous_configure(...)
end
--[[ md
@name doc/url_spec
URLs generated from smr use letters and numbers to encode a monotonically
increasing post id into a url that can easily be shared (and ends up
considerably shorter). The characters used in url generation are:
[a-z][A-Z][0-9], and numbers are encoded to use the second available 1-character
permutation, then the first available 2-character permutation, and so on.
For example, the first post is encoded as 'b', the second as 'c', the third
as 'd', and so on. Digits are written least-significant first, so post 62 is
encoded as 'ab'. The off-by-one nature is to simplify implementation of
2-character and 3-character combinations with Lua's 1-indexed arrays.
See https://perishablepress.com/stop-using-unsafe-characters-in-urls/ for
background on which characters are unsafe in URLs. There is no underscore,
because we use that for our operative pages.
A set of legacy characters that are no longer in use (because they were invalid
to use in URLs) is also defined, but unused as long as
{{config/legacy_url_cutoff}} is set to 0.
]]
-- Alphabet for current ids: lowercase, uppercase, then digits (62 characters).
-- Order matters: a character's position in this string is its digit value.
local url_characters = table.concat({
	"abcdefghijklmnopqrstuvwxyz",
	"ABCDEFGHIJKLMNOPQRSTUVWXYZ",
	"0123456789",
})
-- Legacy alphabet: the current one followed by punctuation that is unsafe in
-- URLs (71 characters total). Only used for ids below the legacy cutoff.
local url_characters_legacy = url_characters .. "$-+!*'(),"
-- Build a reverse lookup table mapping each character of `str` to its
-- 1-based position in `str`. If a character repeats, its last position wins.
local function str2set(str)
	local lookup = {}
	local position = 0
	for ch in str:gmatch(".") do
		position = position + 1
		lookup[ch] = position
	end
	return lookup
end
-- Reverse lookups (character -> 1-based digit position) used by the decoders.
local url_characters_rev_legacy = str2set(url_characters_legacy)
local url_characters_rev = str2set(url_characters)
--[[ md
@name lua/util/encode_id
Encode a number to a shorter HTML-safe url path. Url paths are generated
according to the {{doc/url_spec}}.
]]
-- Encode a non-negative integer id using the current url alphabet.
-- The id is converted to base #url_characters, emitting the
-- least-significant digit first; 0 encodes as "a".
-- Negative input would never reach 0 under floor division (the loop would
-- spin forever), so it is rejected up front along with non-integers.
function util.encode_id(number)
	if type(number) ~= "number" or number < 0 or number ~= math.floor(number) then
		error("encode_id expects a non-negative integer, got: " .. tostring(number), 2)
	end
	local base = #url_characters
	local result = {}
	repeat
		local pos = (number % base) + 1
		number = math.floor(number / base)
		result[#result + 1] = string.sub(url_characters, pos, pos)
	until number == 0
	return table.concat(result)
end
--[[
Legacy wrapper around the current encoder: ids below
config.legacy_url_cutoff keep their encoding in the legacy alphabet (which
contains characters that are unsafe in URLs); ids at or above the cutoff are
handed to the current encoder.
]]
local new_encode = util.encode_id
function util.encode_id(number)
	-- Negative ids fall below any non-negative cutoff and would loop forever
	-- in the legacy branch, so validate before dispatching.
	if type(number) ~= "number" or number < 0 or number ~= math.floor(number) then
		error("encode_id expects a non-negative integer, got: " .. tostring(number), 2)
	end
	if number >= config.legacy_url_cutoff then
		return new_encode(number)
	end
	-- Same least-significant-digit-first base conversion as the current
	-- encoder, over the 71-character legacy alphabet.
	local base = #url_characters_legacy
	local result = {}
	repeat
		local pos = (number % base) + 1
		number = math.floor(number / base)
		result[#result + 1] = string.sub(url_characters_legacy, pos, pos)
	until number == 0
	return table.concat(result)
end
--[[
Given a short HTML-safe url path, convert it to a storyid.
The first character is the least-significant digit. Returns the id, or
false plus an error message if `s` is not a string (or number) or contains
a character outside the url alphabet. An empty string decodes to 0.
]]
function util.decode_id(s)
	-- Numbers were accepted before (string.len/string.sub coerce them).
	if type(s) == "number" then
		s = tostring(s)
	end
	if type(s) ~= "string" then
		return false, "Failed to decode id:" .. tostring(s)
	end
	local base = #url_characters
	local n = 0
	for i = 1, #s do
		local digit = url_characters_rev[string.sub(s, i, i)]
		-- Report bad characters explicitly instead of catching the resulting
		-- nil-arithmetic error with pcall, which would mask real errors too.
		if digit == nil then
			return false, "Failed to decode id:" .. s
		end
		-- `^` rather than math.pow, which is deprecated since Lua 5.3 and
		-- absent from builds without the compatibility math library.
		n = n + (digit - 1) * base ^ (i - 1)
	end
	return n
end
--[[
Legacy wrapper around the current decoder: decode with the legacy alphabet
first. If the result is above config.legacy_url_cutoff, the path is treated
as a current-format id and decoded again with the current alphabet.
Returns the id, or false plus an error message on invalid input.
]]
local new_decode = util.decode_id
function util.decode_id(s)
	-- Numbers were accepted before (string.len/string.sub coerce them).
	if type(s) == "number" then
		s = tostring(s)
	end
	if type(s) ~= "string" then
		return false, "Failed to decode id:" .. tostring(s)
	end
	local base = #url_characters_legacy
	local id = 0
	for i = 1, #s do
		local digit = url_characters_rev_legacy[string.sub(s, i, i)]
		-- Explicit check instead of pcall, so unrelated errors are not
		-- silently reported as decode failures.
		if digit == nil then
			return false, "Failed to decode id:" .. s
		end
		-- `^` rather than the deprecated math.pow.
		id = id + (digit - 1) * base ^ (i - 1)
	end
	if id > config.legacy_url_cutoff then
		return new_decode(s)
	end
	return id
end
--[[
Encode arbitrary data as a lowercase hex string, two hex digits per byte.
Raises an error if `str` is not a string.
]]
function util.encode_unlisted(str)
	assert(type(str) == "string","Tried to encode something not a string:" .. type(str))
	-- "." matches every byte, so each byte is replaced by its hex pair.
	local hex = str:gsub(".", function(ch)
		return string.format("%02x", string.byte(ch))
	end)
	return hex
end
--[[
Decode a hex string back into raw bytes. Each pair of hex digits found by a
left-to-right scan becomes one byte; characters that are not part of such a
pair (including an odd trailing digit) are skipped.
]]
function util.decode_unlisted(str)
	local bytes = {}
	for pair in str:gmatch("%x%x") do
		local code = tonumber(pair, 16)
		bytes[#bytes + 1] = string.char(code)
	end
	return table.concat(bytes)
end
--[[
Parses a semicolon separated string into its parts:
1. separates by semicolon
2. trims leading and trailing whitespace
3. lowercases
4. capitalizes the first letter.
Returns an array of zero or more strings.
There is no blank tag, parsing "one;two;;three" will yield
{"One","Two","Three"}
]]
function util.parse_tags(str)
	local tags = {}
	for tag in string.gmatch(str, "[^;]+") do
		-- Anchored, lazy capture: a greedy "(.*)" would swallow trailing
		-- whitespace, leaving e.g. "a ;b" as {"A ", "B"}.
		local trimmed = tag:match("^%s*(.-)%s*$")
		-- Assignment keeps only gsub's first result (the string).
		local tag_fmt = trimmed:lower():gsub("^.", string.upper)
		-- Whitespace-only segments trim to "" and are dropped.
		if #tag_fmt > 0 then
			tags[#tags + 1] = tag_fmt
		end
	end
	return tags
end
--[[
Get the comments for a story.
Binds `sid` as :id on the prepared select_comments statement and collects
every returned row, in order, into an array of tables with the structure:
comment :: table {
	author :: string - The author's text name
	isanon :: boolean - True if the author is anon (author string will be "Anonymous")
	text :: string - The text of the comment
}
]]
function util.get_comments(sid)
	stmnt_comments:bind_names{id = sid}
	local comments = {}
	for author, isanon, text in db.sql_rows(stmnt_comments) do
		comments[#comments + 1] = {
			author = author,
			-- the column is an integer flag; expose it as a boolean
			isanon = (isanon == 1),
			text = text,
		}
	end
	return comments
end
--[[
checktypes(var1, "type1", check1, var2, "type2", check2, ...)
checktypes{var1, "type1", check1, ...}
Debug-only argument validation. Arguments come in triplets of
<variable>, <lua type name>, <optional check function>. Each variable must
satisfy type(var) == <lua type name>; when a check function is given, all of
its results are passed to assert (so a second result becomes the message).
When config.debugging is off, checktypes is a no-op.
]]
if config.debugging then
	function util.checktypes(...)
		-- table.pack records the true argument count, so nil variables or
		-- omitted check functions do not confuse the length operator.
		local args = table.pack(...)
		local count = args.n
		if count == 1 and type(args[1]) == "table" then
			-- Single-table calling form: checktypes{...}
			args = args[1]
			count = args.n or #args
		end
		assert(
			count % 3 == 0,
			"Arguments to checktypes() must be triplets of " ..
			"<variable>, <lua type>, <type check function> "
		)
		for i = 1, count, 3 do
			local var, ltype, veri_f = args[i], args[i + 1], args[i + 2]
			-- Build the message only on failure; it used to be formatted
			-- eagerly with missing arguments and a fractional %d, which
			-- raised on every call.
			if type(var) ~= ltype then
				error(string.format(
					"Expected argument %d (%s) to be type %s, but was %s",
					math.floor(i / 3) + 1, tostring(var), tostring(ltype), type(var)
				), 2)
			end
			if veri_f then
				assert(veri_f(var))
			end
		end
	end
else
	function util.checktypes()
	end
end
-- Turn one two-digit hex capture (e.g. "41") into the character it encodes.
local function decodeentity(hex)
	local code = tonumber(hex, 16)
	return string.char(code)
end
--[[
Decode %XX percent-escapes in `str` (e.g. "a%20b" -> "a b").
Like string.gsub, returns the decoded string and the number of escapes
that were replaced.
]]
function util.decodeentities(str)
	local decoded, count = string.gsub(str, "%%(%x%x)", decodeentity)
	return decoded, count
end
return util