131 lines
3.9 KiB
Lua
131 lines
3.9 KiB
Lua
local cache = require("cache")
|
|
local sql = require("lsqlite3")
|
|
local db = require("db")
|
|
local queries = require("queries")
|
|
local util = require("util")
|
|
local tags = require("tags")
|
|
require("global")
|
|
|
|
-- Prepared statements shared by the endpoints in this file. They stay nil
-- until configure() has run.
local stmnt_tags_get, stmnt_stories_get

-- Whatever configure() was installed before this file was loaded; the
-- replacement below chains to it.
local oldconfigure = configure

--[[
Replacement for the global configure(): prepares this file's statements on
db.conn and then hands all arguments on to the previous configure(),
returning whatever it returns.
]]
function configure(...)
	-- The prepare call is passed straight into db.sqlassert so that every
	-- value prepare returns reaches it.
	local function prepared(query)
		return db.sqlassert(db.conn:prepare(query))
	end
	stmnt_tags_get = prepared(queries.select_suggest_tags)
	stmnt_stories_get = prepared(queries.select_site_index)
	return oldconfigure(...)
end
|
|
|
|
--[[
Tag suggestions for the "tags" editbox on the story posting form: given what
the user has typed so far, answer with tags they may want to include.

req  - the request to answer
data - the text typed so far

The response is text/plain: a ";"-separated list whose first entry is `data`
itself, followed by the first column of every row the suggestion query
returned.
]]
local function suggest_tags(req,data)
	--[[
	The typed text becomes part of a pattern (a '%' is appended to it below).
	Only letters, digits, commas, whitespace and '-' are let through, so a
	malicious user can not inject '%' into the string we're searching for and
	cause a DoS with a sufficiently backtrack-ey search/tag combination.
	]]
	local is_clean = data:match("^[a-zA-Z0-9,%s-]+$")
	assert(is_clean,string.format("Bad characters in tag: %q",data))
	stmnt_tags_get:bind_names{match = data .. "%"}
	-- Named "suggestions" rather than "tags" so the tags module required at
	-- the top of the file is not shadowed.
	local suggestions = {data}
	for row in stmnt_tags_get:rows() do
		suggestions[#suggestions + 1] = row[1]
	end
	stmnt_tags_get:reset()
	http_response_header(req,"Content-Type","text/plain")
	http_response(req,200,table.concat(suggestions,";"))
end
|
|
|
|
--[[
A poor man's json builder, since I don't need one big enough to pull in a
dependency for it (yet).

Appends the fragments of one JSON object describing `ltbl` to the array
`builder`, followed by a separate "," fragment so objects can be written back
to back. The caller is expected to remove that final "," before closing the
enclosing list, and to table.concat the builder when done. Returns nothing.

builder - array of string fragments, appended to in place
ltbl    - table with string keys. Values may be booleans, numbers, strings,
          or arrays; every item of an array is written as a JSON string.
          Nested objects are not supported. An empty table is written
          as {}.

Raises an error if a key is not a string or a value has an unsupported type.
]]
local function poor_json(builder, ltbl)
	-- Characters that have a short escape in JSON. Lua's %q can not be used
	-- for strings: it writes a newline as backslash-newline and may write
	-- control characters as decimal escapes, neither of which is valid JSON.
	local short_escapes = {
		['"'] = '\\"',
		["\\"] = "\\\\",
		["\b"] = "\\b",
		["\f"] = "\\f",
		["\n"] = "\\n",
		["\r"] = "\\r",
		["\t"] = "\\t",
	}
	local function escape_char(c)
		return short_escapes[c] or string.format("\\u%04x",string.byte(c))
	end
	local function write_bool(out,bool)
		table.insert(out,bool and "true" or "false")
	end
	local function write_number(out,num)
		if num ~= num or num == math.huge or num == -math.huge then
			-- JSON has no way to spell NaN or infinity
			table.insert(out,"null")
		elseif num % 1 == 0 then
			table.insert(out,string.format("%d",num))
		else
			table.insert(out,string.format("%f",num))
		end
	end
	local function write_string(out,s)
		-- Escape quotes, backslashes and control characters; everything
		-- else (including non-ASCII bytes) is passed through untouched.
		local escaped = string.gsub(tostring(s),'[%c"\\]',escape_char)
		table.insert(out,'"' .. escaped .. '"')
	end
	local function write_array(out,tbl)
		table.insert(out,"[")
		for i,item in ipairs(tbl) do
			if i > 1 then
				table.insert(out,",")
			end
			write_string(out,item)
		end
		table.insert(out,"]")
	end
	local lua_to_json = {
		boolean = write_bool,
		number = write_number,
		string = write_string,
		table = write_array
	}
	table.insert(builder,"{")
	-- Separators are written before every field but the first, so nothing
	-- has to be removed afterwards and an empty table can not eat the "{".
	local first = true
	for k,v in pairs(ltbl) do
		assert(type(k) == "string", "Field was not a string, was: " .. type(k))
		local writer = lua_to_json[type(v)]
		assert(writer, "Unknown type for json:" .. type(v) .. " at " .. k)
		if not first then
			table.insert(builder,",")
		end
		first = false
		write_string(builder,k)
		table.insert(builder,":")
		writer(builder,v)
	end
	table.insert(builder,"}")
	-- Kept as its own fragment: the caller removes the last "," but must
	-- not lose the "}" with it.
	table.insert(builder,",")
end
|
|
|
|
--[[
Answer with the stories returned by the site index query, as JSON.

req  - the request to answer
data - string holding a number; it is bound to the query's "offset"
       parameter

The response body has the form {"stories":[ ... ]} with one object per row
(url, title, isanon, posted, author, tags, hits, ncomments). When the query
returns no rows the list is empty.

Raises an error if `data` is not a number.
]]
local function get_stories(req,data)
	local nstories = assert(
		tonumber(data),
		string.format("Bad story offset: %q",tostring(data))
	)
	stmnt_stories_get:bind_names{offset=nstories}
	local builder = {'{"stories":['}
	for id, title, anon, time, author, hits, ncomments in db.sql_rows(stmnt_stories_get) do
		local story = {
			url = util.encode_id(id),
			title = title,
			isanon = tonumber(anon) == 1,
			posted = os.date("%B %d %Y",tonumber(time)),
			author = author,
			tags = tags.get(id),
			hits = hits,
			ncomments = ncomments
		}
		poor_json(builder,story)
	end
	-- poor_json leaves a "," fragment after every object; drop the one that
	-- follows the last story. With no stories the only fragment is the
	-- opening text, which has to stay.
	if #builder > 1 then
		table.remove(builder)
	end
	table.insert(builder,"]}")
	stmnt_stories_get:reset()
	-- NOTE(review): the body is JSON but is labelled text/plain. Left as is
	-- because clients may depend on it; confirm before changing.
	http_response_header(req,"Content-Type","text/plain")
	http_response(req,200,table.concat(builder))
end
|
|
|
|
-- Handlers for the api, keyed by the name a request passes as its "call"
-- argument. Each handler is invoked as handler(req,data).
local api_points = {}

-- Make `func` the handler for the endpoint named `call`, replacing any
-- handler registered under that name before.
local function register_api(call,func)
	api_points[call] = func
end

register_api("suggest",suggest_tags)
register_api("stories",get_stories)
|
|
--[[
Entry point for api requests. Reads the "call" and "data" arguments from the
query string and passes the request to the handler registered under `call`.

Raises an error when either argument is missing or when no handler is
registered for `call`.
]]
local function api_get(req)
	http_request_populate_qs(req)
	local call = assert(http_argument_get_string(req,"call"))
	local data = assert(http_argument_get_string(req,"data"))
	local handler = api_points[call]
	assertf(handler, "Unknown api endpoint: %s", call)
	handler(req,data)
end

return api_get
|