diff --git a/src/lua/cache.lua b/src/lua/cache.lua index 244d6cf..8600129 100644 --- a/src/lua/cache.lua +++ b/src/lua/cache.lua @@ -9,6 +9,7 @@ local sql = require("lsqlite3") local queries = require("queries") local util = require("util") +local db = require("db") local ret = {} @@ -16,11 +17,11 @@ local stmnt_cache, stmnt_insert_cache, stmnt_dirty_cache local oldconfigure = configure function configure(...) - local cache = util.sqlassert(sql.open_memory()) + local cache = db.sqlassert(sql.open_memory()) ret.cache = cache -- Expose db for testing --A cache table to store rendered pages that do not need to be --rerendered. In theory this could OOM the program eventually and start - --swapping to disk. TODO: fixme + --swapping to disk. TODO assert(cache:exec([[ CREATE TABLE IF NOT EXISTS cache ( path TEXT PRIMARY KEY, @@ -55,7 +56,7 @@ end --Render a page, with cacheing. If you need to dirty a cache, call dirty_cache() function ret.render(pagename,callback) stmnt_cache:bind_names{path=pagename} - local err = util.do_sql(stmnt_cache) + local err = db.do_sql(stmnt_cache) if err == sql.DONE then stmnt_cache:reset() --page is not cached @@ -73,7 +74,7 @@ function ret.render(pagename,callback) path=pagename, data=text, } - err = util.do_sql(stmnt_insert_cache) + err = db.do_sql(stmnt_insert_cache) if err == sql.ERROR or err == sql.MISUSE then error("Failed to update cache for page " .. pagename) end @@ -81,11 +82,13 @@ function ret.render(pagename,callback) return text end +-- Dirty a cached page, causing it to be re-rendered the next time it's +-- requested. Doesn't actually delete it or anything, just sets it's dirty bit function ret.dirty(url) stmnt_dirty_cache:bind_names{ path = url } - util.do_sql(stmnt_dirty_cache) + db.do_sql(stmnt_dirty_cache) stmnt_dirty_cache:reset() end diff --git a/src/lua/db.lua b/src/lua/db.lua index 304e20f..6d4564c 100644 --- a/src/lua/db.lua +++ b/src/lua/db.lua @@ -9,8 +9,85 @@ local util = require("util") local config = require("config") local db = {} + +--[[ +Runs an sql query and receives the 3 arguments back, prints a nice error +message on fail, and returns true on success. +]] +function db.sqlassert(r, errcode, err) + if not r then + error(string.format("%d: %s",errcode, err)) + end + return r +end + +--[[ +Continuously tries to perform an sql statement until it goes through +]] +function db.do_sql(stmnt) + if not stmnt then error("No statement",2) end + local err + local i = 0 + repeat + err = stmnt:step() + if err == sql.BUSY then + i = i + 1 + coroutine.yield() + end + until(err ~= sql.BUSY or i > 10) + assert(i < 10, "Database busy") + return err +end + +--[[ +Provides an iterator that loops over results in an sql statement +or throws an error, then resets the statement after the loop is done. +]] +function db.sql_rows(stmnt) + if not stmnt then error("No statement",2) end + local err + return function() + err = stmnt:step() + if err == sql.BUSY then + coroutine.yield() + elseif err == sql.ROW then + return unpack(stmnt:get_values()) + elseif err == sql.DONE then + stmnt:reset() + return nil + else + stmnt:reset() + local msg = string.format( + "SQL Iteration failed: %s : %s\n%s", + tostring(err), + db.conn:errmsg(), + debug.traceback() + ) + log(LOG_CRIT,msg) + error(msg) + end + end +end + +--[[ +Binds an argument to as statement with nice error reporting on failure +stmnt :: sql.stmnt - the prepared sql statemnet +call :: string - a string "bind" or "bind_blob" +position :: number - the argument position to bind to +data :: string - The data to bind +]] +function db.sqlbind(stmnt,call,position,data) + assert(call == "bind" or call == "bind_blob","Bad bind call, call was:" .. call) + local f = stmnt[call](stmnt,position,data) + if f ~= sql.OK then + error(string.format("Failed call %s(%d,%q): %s", call, position, data, db.conn:errmsg()),2) + end +end + + + local oldconfigure = configure -db.conn = util.sqlassert(sql.open(config.db)) +db.conn = db.sqlassert(sql.open(config.db)) function configure(...) --Create sql tables diff --git a/src/lua/endpoints/api_get.lua b/src/lua/endpoints/api_get.lua index b9f9796..98a8174 100644 --- a/src/lua/endpoints/api_get.lua +++ b/src/lua/endpoints/api_get.lua @@ -8,7 +8,7 @@ local stmnt_tags_get local oldconfigure = configure function configure(...) - stmnt_tags_get = util.sqlassert(db.conn:prepare(queries.select_suggest_tags)) + stmnt_tags_get = db.sqlassert(db.conn:prepare(queries.select_suggest_tags)) return oldconfigure(...) end diff --git a/src/lua/endpoints/bio_get.lua b/src/lua/endpoints/bio_get.lua index 632099d..18fac1d 100644 --- a/src/lua/endpoints/bio_get.lua +++ b/src/lua/endpoints/bio_get.lua @@ -37,7 +37,7 @@ local function bio_edit_get(req) stmnt_bio:bind_names{ authorid = authorid } - local err = util.do_sql(stmnt_bio) + local err = db.do_sql(stmnt_bio) if err == sql.DONE then --No rows, we're logged in but an author with our id doesn't --exist? Something has gone wrong. @@ -56,10 +56,13 @@ found, please report this error. end assert(err == sql.ROW) local data = stmnt_bio:get_values() - local bio = zlib.decompress(data[1]) + local bio_text = data[1] + if data[1] ~= "" then + bio_text = zlib.decompress(data[1]) + end stmnt_bio:reset() ret = pages.edit_bio{ - text = bio, + text = bio_text, user = author, domain = config.domain, } diff --git a/src/lua/endpoints/bio_post.lua b/src/lua/endpoints/bio_post.lua index a81b8dd..2147356 100644 --- a/src/lua/endpoints/bio_post.lua +++ b/src/lua/endpoints/bio_post.lua @@ -26,18 +26,19 @@ local function edit_bio(req) local author, author_id = session.get(req) http_request_populate_post(req) - local text = assert(http_argument_get_string(req,"text")) + local text = http_argument_get_string(req,"text") or "" local parsed = parsers.plain(text) -- Make sure the plain parser can deal with it, even though we don't store this result. local compr_raw = zlib.compress(text) local compr = zlib.compress(parsed) - assert(stmnt_update_bio:bind_blob(1,compr_raw) == sql.OK) - assert(stmnt_update_bio:bind(2, author_id) == sql.OK) - if util.do_sql(stmnt_update_bio) ~= sql.DONE then + db.sqlbind(stmnt_update_bio, "bind_blob", 1,compr_raw) + db.sqlbind(stmnt_update_bio, "bind", 2, author_id) + if db.do_sql(stmnt_update_bio) ~= sql.DONE then stmnt_update_bio:reset() error("Faled to update biography") end + stmnt_update_bio:reset() local loc = string.format("https://%s.%s",author,config.domain) -- Dirty the cache for the author's index, the only place where the bio is displayed. cache.dirty(string.format("%s.%s",author,config.domain)) diff --git a/src/lua/endpoints/claim_post.lua b/src/lua/endpoints/claim_post.lua index a83df10..b83f524 100644 --- a/src/lua/endpoints/claim_post.lua +++ b/src/lua/endpoints/claim_post.lua @@ -16,7 +16,7 @@ local stmnt_author_create local oldconfigure = configure function configure(...) - stmnt_author_create = util.sqlassert(db.conn:prepare(queries.insert_author)) + stmnt_author_create = db.sqlassert(db.conn:prepare(queries.insert_author)) return oldconfigure(...) end @@ -46,7 +46,7 @@ local function claim_post(req) } stmnt_author_create:bind_blob(2,salt) stmnt_author_create:bind_blob(3,hash) - local err = util.do_sql(stmnt_author_create) + local err = db.do_sql(stmnt_author_create) if err == sql.DONE then log(LOG_INFO,"Account creation successful:" .. name) --We sucessfully made the new author diff --git a/src/lua/endpoints/delete_post.lua b/src/lua/endpoints/delete_post.lua index f548b18..d0f8b44 100644 --- a/src/lua/endpoints/delete_post.lua +++ b/src/lua/endpoints/delete_post.lua @@ -37,7 +37,7 @@ local function delete_post(req) postid = storyid, authorid = authorid } - local err = util.do_sql(stmnt_delete) + local err = db.do_sql(stmnt_delete) if err ~= sql.DONE then log(LOG_DEBUG,string.format("Failed to delete: %d:%s",err, db.conn:errmsg())) http_response(req,500,pages.error{ diff --git a/src/lua/endpoints/download_get.lua b/src/lua/endpoints/download_get.lua index 1c63bd3..c2b1af2 100644 --- a/src/lua/endpoints/download_get.lua +++ b/src/lua/endpoints/download_get.lua @@ -25,7 +25,7 @@ local function download_get(req) stmnt_download:bind_names{ postid = story_id } - local err = util.do_sql(stmnt_download) + local err = db.do_sql(stmnt_download) if err == sql.DONE then --No rows, story not found http_response(req,404,pages.nostory{path=story}) diff --git a/src/lua/endpoints/edit_get.lua b/src/lua/endpoints/edit_get.lua index 48c6602..7f0a174 100644 --- a/src/lua/endpoints/edit_get.lua +++ b/src/lua/endpoints/edit_get.lua @@ -32,7 +32,7 @@ local function edit_get(req) postid = story_id, authorid = authorid } - local err = util.do_sql(stmnt_edit) + local err = db.do_sql(stmnt_edit) if err == sql.DONE then --No rows, we're probably not the owner (it might --also be because there's no such story) diff --git a/src/lua/endpoints/edit_post.lua b/src/lua/endpoints/edit_post.lua index b3505ee..f604a19 100644 --- a/src/lua/endpoints/edit_post.lua +++ b/src/lua/endpoints/edit_post.lua @@ -38,7 +38,7 @@ local function edit_post(req) stmnt_author_of:bind_names{ id = storyid } - local err = util.do_sql(stmnt_author_of) + local err = db.do_sql(stmnt_author_of) if err ~= sql.ROW then stmnt_author_of:reset() local msg = string.format("No author found for story: %d", storyid) @@ -66,14 +66,14 @@ local function edit_post(req) assert(stmnt_update_raw:bind_blob(1,compr_raw) == sql.OK) assert(stmnt_update_raw:bind(2,markup) == sql.OK) assert(stmnt_update_raw:bind(3,storyid) == sql.OK) - assert(util.do_sql(stmnt_update_raw) == sql.DONE, "Failed to update raw") + assert(db.do_sql(stmnt_update_raw) == sql.DONE, "Failed to update raw") stmnt_update_raw:reset() assert(stmnt_update:bind(1,title) == sql.OK) assert(stmnt_update:bind_blob(2,compr) == sql.OK) assert(stmnt_update:bind(3,pasteas == "anonymous" and 1 or 0) == sql.OK) assert(stmnt_update:bind(4,unlisted) == sql.OK) assert(stmnt_update:bind(5,storyid) == sql.OK) - assert(util.do_sql(stmnt_update) == sql.DONE, "Failed to update text") + assert(db.do_sql(stmnt_update) == sql.DONE, "Failed to update text") stmnt_update:reset() tagslib.set(storyid,tags) local id_enc = util.encode_id(storyid) @@ -81,7 +81,7 @@ local function edit_post(req) local loc = string.format("https://%s/%s",config.domain,id_enc) if unlisted then stmnt_hash:bind_names{id=storyid} - local err = util.do_sql(stmnt_hash) + local err = db.do_sql(stmnt_hash) if err ~= sql.ROW then error("Failed to get a post's hash while trying to make it unlisted") end diff --git a/src/lua/endpoints/index_get.lua b/src/lua/endpoints/index_get.lua index 88b16ae..479db20 100644 --- a/src/lua/endpoints/index_get.lua +++ b/src/lua/endpoints/index_get.lua @@ -9,6 +9,7 @@ local pages = require("pages") local libtags = require("tags") local session = require("session") local parsers = require("parsers") +local zlib = require("zlib") local stmnt_index, stmnt_author, stmnt_author_bio @@ -27,7 +28,7 @@ local function get_site_home(req, loggedin) log(LOG_DEBUG,"Cache miss, rendering site index") stmnt_index:bind_names{} local latest = {} - for idr, title, iar, dater, author, hits in util.sql_rows(stmnt_index) do + for idr, title, iar, dater, author, hits in db.sql_rows(stmnt_index) do table.insert(latest,{ url = util.encode_id(idr), title = title, @@ -45,12 +46,11 @@ local function get_site_home(req, loggedin) } end local function get_author_home(req, loggedin) - --print("Looking at author home...") local host = http_request_get_host(req) local subdomain = host:match("([^\\.]+)") stmnt_author_bio:bind_names{author=subdomain} - local err = util.do_sql(stmnt_author_bio) local author, authorid = session.get(req) + local err = db.do_sql(stmnt_author_bio) if err == sql.DONE then log(LOG_INFO,"No such author:" .. subdomain) stmnt_author_bio:reset() @@ -63,12 +63,15 @@ local function get_author_home(req, loggedin) error(string.format("Failed to get author %q error: %q",subdomain, tostring(err))) end local data = stmnt_author_bio:get_values() - local bio = parsers.plain(zlib.decompress(data[1])) + local bio_text = data[1] + if data[1] ~= "" then + bio_text = zlib.decompress(data[1]) + end + local bio = parsers.plain(bio_text) stmnt_author_bio:reset() stmnt_author:bind_names{author=subdomain} local stories = {} - for id, title, time, hits, unlisted, hash in util.sql_rows(stmnt_author) do - --print("Looking at:",id,title,time,hits,unlisted) + for id, title, time, hits, unlisted, hash in db.sql_rows(stmnt_author) do if unlisted == 1 and author == subdomain then local url = util.encode_id(id) .. "?pwd=" .. util.encode_unlisted(hash) table.insert(stories,{ diff --git a/src/lua/endpoints/login_post.lua b/src/lua/endpoints/login_post.lua index ae5b750..ea97864 100644 --- a/src/lua/endpoints/login_post.lua +++ b/src/lua/endpoints/login_post.lua @@ -27,7 +27,7 @@ local function login_post(req) name = name } local text - local err = util.do_sql(stmnt_author_acct) + local err = db.do_sql(stmnt_author_acct) if err == sql.ROW then local id, salt, passhash = unpack(stmnt_author_acct:get_values()) stmnt_author_acct:reset() diff --git a/src/lua/endpoints/paste_post.lua b/src/lua/endpoints/paste_post.lua index b51b7a5..d9698c6 100644 --- a/src/lua/endpoints/paste_post.lua +++ b/src/lua/endpoints/paste_post.lua @@ -44,14 +44,14 @@ local function anon_paste(req,ps) log(LOG_DEBUG,string.format("new story: %q, length: %d",ps.title,string.len(ps.text))) local textsha3 = sha3(ps.text .. get_random_bytes(32)) - util.sqlbind(stmnt_paste,"bind_blob",1,ps.text) - util.sqlbind(stmnt_paste,"bind",2,ps.title) - util.sqlbind(stmnt_paste,"bind",3,-1) - util.sqlbind(stmnt_paste,"bind",4,true) - util.sqlbind(stmnt_paste,"bind_blob",5,"") - util.sqlbind(stmnt_paste,"bind",6,ps.unlisted) - util.sqlbind(stmnt_paste,"bind_blob",7,textsha3) - local err = util.do_sql(stmnt_paste) + db.sqlbind(stmnt_paste,"bind_blob",1,ps.text) + db.sqlbind(stmnt_paste,"bind",2,ps.title) + db.sqlbind(stmnt_paste,"bind",3,-1) + db.sqlbind(stmnt_paste,"bind",4,true) + db.sqlbind(stmnt_paste,"bind_blob",5,"") + db.sqlbind(stmnt_paste,"bind",6,ps.unlisted) + db.sqlbind(stmnt_paste,"bind_blob",7,textsha3) + local err = db.do_sql(stmnt_paste) stmnt_paste:reset() if err == sql.DONE then local rowid = stmnt_paste:last_insert_rowid() @@ -62,7 +62,7 @@ local function anon_paste(req,ps) assert(stmnt_raw:bind(1,rowid) == sql.OK) assert(stmnt_raw:bind_blob(2,ps.raw) == sql.OK) assert(stmnt_raw:bind(3,ps.markup) == sql.OK) - err = util.do_sql(stmnt_raw) + err = db.do_sql(stmnt_raw) stmnt_raw:reset() if err ~= sql.DONE then local msg = string.format( @@ -112,9 +112,9 @@ local function author_paste(req,ps) assert(stmnt_paste:bind(3,authorid) == sql.OK) assert(stmnt_paste:bind(4,asanon == "anonymous") == sql.OK) assert(stmnt_paste:bind_blob(5,"") == sql.OK) - util.sqlbind(stmnt_paste,"bind",6,ps.unlisted) - util.sqlbind(stmnt_paste,"bind_blob",7,textsha3) - local err = util.do_sql(stmnt_paste) + db.sqlbind(stmnt_paste,"bind",6,ps.unlisted) + db.sqlbind(stmnt_paste,"bind_blob",7,textsha3) + local err = db.do_sql(stmnt_paste) stmnt_paste:reset() if err == sql.DONE then local rowid = stmnt_paste:last_insert_rowid() @@ -125,7 +125,7 @@ local function author_paste(req,ps) assert(stmnt_raw:bind(1,rowid) == sql.OK) assert(stmnt_raw:bind_blob(2,ps.raw) == sql.OK) assert(stmnt_raw:bind(3,ps.markup) == sql.OK) - err = util.do_sql(stmnt_raw) + err = db.do_sql(stmnt_raw) stmnt_raw:reset() if err ~= sql.DONE then local msg = string.format( diff --git a/src/lua/endpoints/read_get.lua b/src/lua/endpoints/read_get.lua index 67ac4a7..31a78fe 100644 --- a/src/lua/endpoints/read_get.lua +++ b/src/lua/endpoints/read_get.lua @@ -28,7 +28,7 @@ local function add_view(storyid) stmnt_update_views:bind_names{ id = storyid } - local err = util.do_sql(stmnt_update_views) + local err = db.do_sql(stmnt_update_views) assert(err == sql.DONE, "Failed to update view counter:"..tostring(err)) stmnt_update_views:reset() end @@ -42,7 +42,7 @@ local function populate_ps_story(req,ps) stmnt_read:bind_names{ id = ps.storyid, } - local err = util.do_sql(stmnt_read) + local err = db.do_sql(stmnt_read) if err == sql.DONE then --We got no story stmnt_read:reset() @@ -81,7 +81,7 @@ local function get_comments(req,ps) id = ps.storyid } local comments = {} - for com_author, com_isanon, com_text in util.sql_rows(stmnt_comments) do + for com_author, com_isanon, com_text in db.sql_rows(stmnt_comments) do table.insert(comments,{ author = com_author, isanon = com_isanon == 1, --int to boolean diff --git a/src/lua/endpoints/read_post.lua b/src/lua/endpoints/read_post.lua index fed4730..135189a 100644 --- a/src/lua/endpoints/read_post.lua +++ b/src/lua/endpoints/read_post.lua @@ -38,7 +38,7 @@ local function read_post(req) isanon = isanon, comment_text = comment_text, } - local err = util.do_sql(stmnt_comment_insert) + local err = db.do_sql(stmnt_comment_insert) stmnt_comment_insert:reset() if err ~= sql.DONE then http_response(req,500,"Internal error, failed to post comment. Go back and try again.") diff --git a/src/lua/global.lua b/src/lua/global.lua new file mode 100644 index 0000000..2d99571 --- /dev/null +++ b/src/lua/global.lua @@ -0,0 +1,9 @@ +-- Various global functions to cause less typing. + + +function assertf(bool, fmt, ...) + fmt = fmt or "Assetion Failed" + if not bool then + error(string.format(fmt,...),2) + end +end diff --git a/src/lua/session.lua b/src/lua/session.lua index 45ed06c..e97834d 100644 --- a/src/lua/session.lua +++ b/src/lua/session.lua @@ -28,7 +28,7 @@ function session.get(req) stmnt_get_session:bind_names{ key = sessionid } - local err = util.do_sql(stmnt_get_session) + local err = db.do_sql(stmnt_get_session) if err ~= sql.ROW then stmnt_get_session:reset() return nil, "No such session by logged in users" @@ -57,7 +57,7 @@ function session.start(who) sessionid = session, authorid = who } - local err = util.do_sql(stmnt_insert_session) + local err = db.do_sql(stmnt_insert_session) stmnt_insert_session:reset() assert(err == sql.DONE) return session @@ -71,7 +71,7 @@ function session.finish(who,sessionid) authorid = who, sessionid = sessionid } - local err = util.do_sql(stmnt_delete_session) + local err = db.do_sql(stmnt_delete_session) stmnt_delete_session:reset() assert(err == sql.DONE) return true diff --git a/src/lua/tags.lua b/src/lua/tags.lua index b922579..ee8e54e 100644 --- a/src/lua/tags.lua +++ b/src/lua/tags.lua @@ -41,13 +41,13 @@ end function tags.set(storyid,tags) assert(stmnt_drop_tags:bind_names{postid = storyid} == sql.OK) - util.do_sql(stmnt_drop_tags) + db.do_sql(stmnt_drop_tags) stmnt_drop_tags:reset() local err for _,tag in pairs(tags) do assert(stmnt_ins_tag:bind(1,storyid) == sql.OK) assert(stmnt_ins_tag:bind(2,tag) == sql.OK) - err = util.do_sql(stmnt_ins_tag) + err = db.do_sql(stmnt_ins_tag) stmnt_ins_tag:reset() end if err ~= sql.DONE then diff --git a/src/lua/util.lua b/src/lua/util.lua index 602246b..786e343 100644 --- a/src/lua/util.lua +++ b/src/lua/util.lua @@ -4,80 +4,6 @@ local config = require("config") local types = require("types") local util = {} ---[[ -Runs an sql query and receives the 3 arguments back, prints a nice error -message on fail, and returns true on success. -]] -function util.sqlassert(r, errcode, err) - if not r then - error(string.format("%d: %s",errcode, err)) - end - return r -end - ---[[ -Continuously tries to perform an sql statement until it goes through -]] -function util.do_sql(stmnt) - if not stmnt then error("No statement",2) end - local err - local i = 0 - repeat - err = stmnt:step() - if err == sql.BUSY then - i = i + 1 - coroutine.yield() - end - until(err ~= sql.BUSY or i > 10) - assert(i < 10, "Database busy") - return err -end - ---[[ -Provides an iterator that loops over results in an sql statement -or throws an error, then resets the statement after the loop is done. -]] -function util.sql_rows(stmnt) - if not stmnt then error("No statement",2) end - local err - return function() - err = stmnt:step() - if err == sql.BUSY then - coroutine.yield() - elseif err == sql.ROW then - return unpack(stmnt:get_values()) - elseif err == sql.DONE then - stmnt:reset() - return nil - else - stmnt:reset() - local msg = string.format( - "SQL Iteration failed: %s : %s\n%s", - tostring(err), - db.conn:errmsg(), - debug.traceback() - ) - log(LOG_CRIT,msg) - error(msg) - end - end -end - ---[[ -Binds an argument to as statement with nice error reporting on failure -stmnt :: sql.stmnt - the prepared sql statemnet -call :: string - a string "bind" or "bind_blob" -position :: number - the argument position to bind to -data :: string - The data to bind -]] -function util.sqlbind(stmnt,call,position,data) - assert(call == "bind" or call == "bind_blob","Bad bind call, call was:" .. call) - local f = stmnt[call](stmnt,position,data) - if f ~= sql.OK then - error(string.format("Failed to %s at %d with %q: %s", call, position, data, db.conn:errmsg()),2) - end -end - --see https://perishablepress.com/stop-using-unsafe-characters-in-urls/ --no underscore because we use that for our operative pages local url_characters = diff --git a/src/pages/edit_bio.etlua b/src/pages/edit_bio.etlua index eca377e..28de155 100644 --- a/src/pages/edit_bio.etlua +++ b/src/pages/edit_bio.etlua @@ -33,7 +33,9 @@

- +
+ +