diff --git a/automation/include/aegisub/re.moon b/automation/include/aegisub/re.moon index b571b74bd..1b278a16e 100644 --- a/automation/include/aegisub/re.moon +++ b/automation/include/aegisub/re.moon @@ -17,32 +17,79 @@ next = next select = select type = type --- Get the boost::regex binding +bit = require 'bit' +ffi = require 'ffi' +ffi_util = require 'aegisub.ffi' + +ffi.cdef[[ + typedef struct agi_re_flag { + const char *name; + int value; + } agi_re_flag; +]] +regex_flag = ffi.typeof 'agi_re_flag' + +-- Get the boost::eegex binding regex = require 'aegisub.__re_impl' +-- Wrappers to convert returned values from C types to Lua types +search = (re, str, start) -> + return unless start <= str\len() + res = regex.search re, str, str\len(), start + return unless res != nil + first, last = res[0], res[1] + ffi.C.free res + first, last + +replace = (re, replacement, str, max_count) -> + ffi_util.string regex.replace re, replacement, str, str\len(), max_count + +match = (re, str, start) -> + assert start <= str\len() + m = regex.match re, str, str\len(), start + return unless m != nil + ffi.gc m, regex.match_free + +get_match = (m, idx) -> + res = regex.get_match m, idx + return unless res != nil + res[0], res[1] -- Result buffer is owned by match so no need to free + +err_buff = ffi.new 'char *[1]' +compile = (pattern, flags) -> + err_buff[0] = nil + re = regex.compile pattern, flags, err_buff + if err_buff[0] != nil + return ffi.string err_buff[0] + ffi.gc re, regex.regex_free + -- Return the first n elements from ... select_first = (n, a, ...) -> if n == 0 then return a, select_first n - 1, ... +-- Bitwise-or together regex flags passed as arguments to a function +process_flags = (...) -> + flags = 0 + for i = 1, select '#', ... + v = select i, ... + if not ffi.istype regex_flag, v + error 'Flags must follow all non-flag arguments', 3 + flags = bit.bor flags, v.value + flags + -- Extract the flags from ..., bitwise OR them together, and move them to the -- front of ... unpack_args = (...) -> - userdata_start = nil + flags_start = nil for i = 1, select '#', ... v = select i, ... - if type(v) == 'userdata' - userdata_start = i + if ffi.istype regex_flag, v + flags_start = i break - return 0, ... unless userdata_start - - flags = regex.process_flags select userdata_start, ... - if type(flags) == 'string' - error(flags, 3) - - flags, select_first userdata_start - 1, ... - + return 0, ... unless flags_start + process_flags(select flags_start, ...), select_first flags_start - 1, ... -- Typecheck a variable and throw an error if it fails check_arg = (arg, expected_type, argn, func_name, level) -> @@ -108,20 +155,19 @@ class RegEx new: (@_regex, @_level) => - start = 1 gsplit: (str, skip_empty, max_split) => @_check_self! check_arg str, 'string', 2, 'gsplit', @_level if not max_split or max_split <= 0 then max_split = str\len() - start = 1 + start = 0 prev = 1 do_split = () -> if not str or str\len() == 0 then return local first, last if max_split > 0 - first, last = regex.search @_regex, str, start + first, last = search @_regex, str, start if not first or first > str\len() ret = str\sub prev, str\len() @@ -131,7 +177,7 @@ class RegEx ret = str\sub prev, first - 1 prev = last + 1 - start = 1 + if start >= last then start else last + start = if start >= last then start + 1 else last if skip_empty and ret\len() == 0 do_split() @@ -150,12 +196,12 @@ class RegEx @_check_self! check_arg str, 'string', 2, 'gfind', @_level - start = 1 + start = 0 -> - first, last = regex.search(@_regex, str, start) + first, last = search(@_regex, str, start) return unless first - start = if last >= start then last + 1 else start + 1 + start = if last > start then last else start + 1 str\sub(first, last), first, last find: (str) => @@ -176,7 +222,7 @@ class RegEx if type(repl) == 'function' do_replace_fun @, repl, str, max_count elseif type(repl) == 'string' - regex.replace @_regex, repl, str, max_count + replace @_regex, repl, str, max_count else error "Argument 2 to sub should be a string or function, is '#{type(repl)}' (#{repl})", @_level @@ -185,11 +231,11 @@ class RegEx check_arg str, 'string', 2, 'gmatch', @_level start = if start then start - 1 else 0 - match = regex.match @_regex, str, start - i = 1 + m = match @_regex, str, start + i = 0 -> - return unless match - first, last = regex.get_match match, i + return unless m + first, last = get_match m, i return unless first i += 1 @@ -213,7 +259,7 @@ real_compile = (pattern, level, flags, stored_level) -> if pattern == '' error 'Regular expression must not be empty', level + 1 - re = regex.compile pattern, flags + re = compile pattern, flags if type(re) == 'string' error regex, level + 1 @@ -225,25 +271,31 @@ invoke = (str, pattern, fn, flags, ...) -> compiled_regex[fn](compiled_regex, str, ...) -- Generate a static version of a method with arg type checking -gen_wrapper = (impl_name) -> - (str, pattern, ...) -> - check_arg str, 'string', 1, impl_name, 2 - check_arg pattern, 'string', 2, impl_name, 2 - invoke str, pattern, impl_name, unpack_args ... +gen_wrapper = (impl_name) -> (str, pattern, ...) -> + check_arg str, 'string', 1, impl_name, 2 + check_arg pattern, 'string', 2, impl_name, 2 + invoke str, pattern, impl_name, unpack_args ... -- And now at last the actual public API -re = regex.init_flags(re) +do + re = { + compile: (pattern, ...) -> + check_arg pattern, 'string', 1, 'compile', 2 + real_compile pattern, 2, process_flags(...), 2 -re.compile = (pattern, ...) -> - check_arg pattern, 'string', 1, 'compile', 2 - real_compile pattern, 2, regex.process_flags(...), 2 + split: gen_wrapper 'split' + gsplit: gen_wrapper 'gsplit' + find: gen_wrapper 'find' + gfind: gen_wrapper 'gfind' + match: gen_wrapper 'match' + gmatch: gen_wrapper 'gmatch' + sub: gen_wrapper 'sub' + } -re.split = gen_wrapper 'split' -re.gsplit = gen_wrapper 'gsplit' -re.find = gen_wrapper 'find' -re.gfind = gen_wrapper 'gfind' -re.match = gen_wrapper 'match' -re.gmatch = gen_wrapper 'gmatch' -re.sub = gen_wrapper 'sub' + i = 0 + flags = regex.get_flags() + while flags[i].name != nil + re[ffi.string flags[i].name] = flags[i] + i += 1 -re + re diff --git a/libaegisub/lua/modules/re.cpp b/libaegisub/lua/modules/re.cpp index efc42edd2..03283e6aa 100644 --- a/libaegisub/lua/modules/re.cpp +++ b/libaegisub/lua/modules/re.cpp @@ -14,85 +14,72 @@ // // Aegisub Project http://www.aegisub.org/ -#include "libaegisub/lua/utils.h" +#include "libaegisub/lua/ffi.h" +#include "libaegisub/make_unique.h" #include -#include + +using boost::u32regex; +namespace { +// A cmatch with a match range attached to it so that we can return a pointer to +// an int pair without an extra heap allocation each time (LuaJIT can't compile +// ffi calls which return aggregates by value) +struct agi_re_match { + boost::cmatch m; + int range[2]; +}; + +struct agi_re_flag { + const char *name; + int value; +}; +} + +namespace agi { + AGI_DEFINE_TYPE_NAME(u32regex); + AGI_DEFINE_TYPE_NAME(agi_re_match); + AGI_DEFINE_TYPE_NAME(agi_re_flag); +} namespace { -using namespace agi::lua; - -boost::u32regex& get_regex(lua_State *L) { - return get(L, 1, "aegisub.regex"); +using match = agi_re_match; +bool search(u32regex& re, const char *str, size_t len, int start, boost::cmatch& result) { + return u32regex_search(str + start, str + len, result, re, + start > 0 ? boost::match_prev_avail | boost::match_not_bob : boost::match_default); } -boost::smatch& get_smatch(lua_State *L) { - return get(L, 1, "aegisub.smatch"); +match *regex_match(u32regex& re, const char *str, size_t len, int start) { + auto result = agi::make_unique(); + if (!search(re, str, len, start, result->m)) + return nullptr; + return result.release(); } -int regex_matches(lua_State *L) { - push_value(L, u32regex_match(check_string(L, 2), get_regex(L))); - return 1; +int *regex_get_match(match& match, size_t idx) { + if (idx > match.m.size() || !match.m[idx].matched) + return nullptr; + match.range[0] = std::distance(match.m.prefix().first, match.m[idx].first + 1); + match.range[1] = std::distance(match.m.prefix().first, match.m[idx].second); + return match.range; } -int regex_match(lua_State *L) { - auto re = get_regex(L); - std::string str = check_string(L, 2); - int start = lua_tointeger(L, 3); +int *regex_search(u32regex& re, const char *str, size_t len, size_t start) { + boost::cmatch result; + if (!search(re, str, len, start, result)) + return nullptr; - auto result = make(L, "aegisub.smatch"); - if (!u32regex_search(str.cbegin() + start, str.cend(), *result, re, - start > 0 ? boost::match_prev_avail | boost::match_not_bob : boost::match_default)) - { - lua_pop(L, 1); - lua_pushnil(L); - } - - return 1; + auto ret = static_cast(malloc(sizeof(int) * 2)); + ret[0] = start + result.position() + 1; + ret[1] = start + result.position() + result.length(); + return ret; } -int regex_get_match(lua_State *L) { - auto& match = get_smatch(L); - auto idx = check_uint(L, 2) - 1; - if (idx > match.size() || !match[idx].matched) { - lua_pushnil(L); - return 1; - } - - push_value(L, distance(match.prefix().first, match[idx].first + 1)); - push_value(L, distance(match.prefix().first, match[idx].second)); - return 2; -} - -int regex_search(lua_State *L) { - auto& re = get_regex(L); - auto str = check_string(L, 2); - auto start = check_uint(L, 3) - 1; - argcheck(L, start <= str.size(), 3, "out of bounds"); - boost::smatch result; - if (!u32regex_search(str.cbegin() + start, str.cend(), result, re, - start > 0 ? boost::match_prev_avail | boost::match_not_bob : boost::match_default)) - { - lua_pushnil(L); - return 1; - } - - push_value(L, start + result.position() + 1); - push_value(L, start + result.position() + result.length()); - return 2; -} - -int regex_replace(lua_State *L) { - auto& re = get_regex(L); - const auto replacement = check_string(L, 2); - const auto str = check_string(L, 3); - int max_count = check_int(L, 4); - +char *regex_replace(u32regex& re, const char *replacement, const char *str, size_t len, int max_count) { // Can't just use regex_replace here since it can only do one or infinite replacements - auto match = boost::u32regex_iterator(begin(str), end(str), re); - auto end_it = boost::u32regex_iterator(); + auto match = boost::u32regex_iterator(str, str + len, re); + auto end_it = boost::u32regex_iterator(); - auto suffix = begin(str); + auto suffix = str; std::string ret; auto out = back_inserter(ret); @@ -104,95 +91,51 @@ int regex_replace(lua_State *L) { --max_count; } - copy(suffix, end(str), out); - - push_value(L, ret); - return 1; + ret += suffix; + return strdup(ret.c_str()); } -int regex_compile(lua_State *L) { - auto pattern(check_string(L, 1)); - int flags = check_int(L, 2); - auto re = make(L, "aegisub.regex"); - +u32regex *regex_compile(const char *pattern, int flags, char **err) { + auto re = agi::make_unique(); try { *re = boost::make_u32regex(pattern, boost::u32regex::perl | flags); + return re.release(); } catch (std::exception const& e) { - lua_pop(L, 1); - push_value(L, e.what()); - return 1; - // Do the actual triggering of the error in the Lua code as that code - // can report the original call site + *err = strdup(e.what()); + return nullptr; } - - return 1; } -int regex_gc(lua_State *L) { - using boost::u32regex; - get_regex(L).~u32regex(); - return 0; +void regex_free(u32regex *re) { delete re; } +void match_free(match *m) { delete m; } + +const agi_re_flag *get_regex_flags() { + static const agi_re_flag flags[] = { + {"ICASE", boost::u32regex::icase}, + {"NOSUB", boost::u32regex::nosubs}, + {"COLLATE", boost::u32regex::collate}, + {"NEWLINE_ALT", boost::u32regex::newline_alt}, + {"NO_MOD_M", boost::u32regex::no_mod_m}, + {"NO_MOD_S", boost::u32regex::no_mod_s}, + {"MOD_S", boost::u32regex::mod_s}, + {"MOD_X", boost::u32regex::mod_x}, + {"NO_EMPTY_SUBEXPRESSIONS", boost::u32regex::no_empty_expressions}, + {nullptr, 0} + }; + return flags; } - -int smatch_gc(lua_State *L) { - using boost::smatch; - get_smatch(L).~smatch(); - return 0; -} - -int regex_process_flags(lua_State *L) { - int ret = 0; - int nargs = lua_gettop(L); - for (int i = 1; i <= nargs; ++i) { - if (!lua_islightuserdata(L, i)) { - push_value(L, "Flags must follow all non-flag arguments"); - return 1; - } - ret |= (int)(intptr_t)lua_touserdata(L, i); - } - - push_value(L, ret); - return 1; -} - -int regex_init_flags(lua_State *L) { - lua_createtable(L, 0, 9); - - set_field(L, "ICASE", (void*)boost::u32regex::icase); - set_field(L, "NOSUB", (void*)boost::u32regex::nosubs); - set_field(L, "COLLATE", (void*)boost::u32regex::collate); - set_field(L, "NEWLINE_ALT", (void*)boost::u32regex::newline_alt); - set_field(L, "NO_MOD_M", (void*)boost::u32regex::no_mod_m); - set_field(L, "NO_MOD_S", (void*)boost::u32regex::no_mod_s); - set_field(L, "MOD_S", (void*)boost::u32regex::mod_s); - set_field(L, "MOD_X", (void*)boost::u32regex::mod_x); - set_field(L, "NO_EMPTY_SUBEXPRESSIONS", (void*)boost::u32regex::no_empty_expressions); - - return 1; -} - } extern "C" int luaopen_re_impl(lua_State *L) { - if (luaL_newmetatable(L, "aegisub.regex")) { - set_field(L, "__gc"); - lua_pop(L, 1); - } - - if (luaL_newmetatable(L, "aegisub.smatch")) { - set_field(L, "__gc"); - lua_pop(L, 1); - } - - lua_createtable(L, 0, 8); - set_field(L, "matches"); - set_field(L, "search"); - set_field(L, "match"); - set_field(L, "get_match"); - set_field(L, "replace"); - set_field(L, "compile"); - set_field(L, "process_flags"); - set_field(L, "init_flags"); + agi::lua::register_lib_table(L, {"agi_re_match", "u32regex"}, + "search", regex_search, + "match", regex_match, + "get_match", regex_get_match, + "replace", regex_replace, + "compile", regex_compile, + "get_flags", get_regex_flags, + "match_free", match_free, + "regex_free", regex_free); return 1; }