diff --git a/aegisub/automation/include/re.lua b/aegisub/automation/include/re.lua new file mode 100644 index 000000000..7461eb378 --- /dev/null +++ b/aegisub/automation/include/re.lua @@ -0,0 +1,303 @@ +-- Copyright (c) 2012, Thomas Goyne +-- +-- Permission to use, copy, modify, and distribute this software for any +-- purpose with or without fee is hereby granted, provided that the above +-- copyright notice and this permission notice appear in all copies. +-- +-- THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES +-- WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +-- MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR +-- ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +-- WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +-- ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF +-- OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. +-- +-- $Id$ + +-- Get the wxRegex binding +local regex = aegisub.__init_regex() + +-- Compiled regular expression type protoype +local re_proto = {} +local re_proto_mt = { __index = re_proto } + +-- Convert an iterator to an array +local function to_table(...) + local arr = {} + local i = 1 + for v in ... do + arr[i] = v + i = i + 1 + end + return arr +end + +-- Return the first n elements from ... +local function select_first(n, a, ...) + if n == 0 then return end + return a, select_first(n - 1, ...) +end + +-- Extract the flags from ..., bitwise OR them together, and move them to the +-- front of ... +local function unpack_args(...) + local n = select('#', ...) + local userdata_start = nil + for i = 1, n do + local v = select(i, ...) + if type(v) == "userdata" then + userdata_start = i + break + end + end + + if not userdata_start then + return 0, ... + end + + flags = regex.process_flags(select(userdata_start, ...)) + if type(flags) == "string" then + error(flags, 3) + end + + return flags, select_first(userdata_start - 1, ...) +end + +-- Verify that a valid value was passed for self +local function check_self(self) + if getmetatable(self) ~= re_proto_mt then + error("re method called with invalid self. You probably used . when : is needed.", 3) + end +end + +-- Typecheck a variable and throw an error if it fails +local function check_arg(arg, expected_type, argn, func_name, level) + if type(arg) ~= expected_type then + error( + string.format("Argument %d to %s should be a '%s', is '%s' (%s)", + argn, func_name, expected_type, type(arg), tostring(arg)), + level + 1) + end +end + +function re_proto.gsplit(self, str, skip_empty, max_split) + check_self(self) + check_arg(str, "string", 2, "gsplit", self._level) + if not max_split or max_split <= 0 then max_split = str:len() end + + local function do_split() + if not str or str:len() == 0 then + return nil + end + + if max_split == 0 or not regex.matches(self._regex, str) then + local ret = str + str = nil + return ret + end + + local first, last = regex.get_match(self._regex, str, 0) + local ret = str:sub(1, first - 1) + str = str:sub(last + 1) + + if skip_empty and ret:len() == 0 then + return do_split() + else + max_split = max_split - 1 + return ret + end + end + + return do_split +end + +function re_proto.split(self, str, skip_empty, max_split) + check_self(self) + check_arg(str, "string", 2, "split", self._level) + return to_table(self:gsplit(str, skip_empty, max_split)) +end + +function re_proto.gfind(self, str) + check_self(self) + check_arg(str, "string", 2, "gfind", self._level) + + local offset = 0 + return function() + local has_matches = regex.matches(self._regex, str) + if not has_matches then return end + + local first, last = regex.get_match(self._regex, str, 0) + local ret = str:sub(first, last) + str = str:sub(last + 1) + + last = last + offset + offset = offset + first + return ret, offset, last + end +end + +function re_proto.find(self, str) + check_self(self) + check_arg(str, "string", 2, "find", self._level) + + local i = 1 + local ret = {} + for s, f, l in self:gfind(str) do + ret[i] = { + str = s, + first = f, + last = l + } + i = i + 1 + end + return ret +end + +-- Replace a match with the value returned from func when passed the match +local function replace_match(match, func, str, last, acc) + if last < match.last then + acc[#acc + 1] = str:sub(last, match.first - 1) + end + + local ret = func(match.str, match.first, match.last) + if type(ret) == "string" then + acc[#acc + 1] = ret + else + -- If it didn't return a string just leave the old value + acc[#acc + 1] = match.str + end + + return match.last + 1 +end + +-- Replace all matches from a single iteration of the regexp +local function do_single_replace_fun(re, func, str, acc) + local matches = re:match(str) + + -- No more matches so just return what we have so far + if not matches then + return str + end + + -- One match means no capturing groups, so pass the entire thing to + -- the replace function + if #matches == 1 then + local rest = replace_match(matches[1], func, str, 1, acc) + return str:sub(rest), true + end + + -- Multiple matches means there were capture groups, so skip the first one + -- and pass the rest to the replace function + local last = 1 + for i = 2, #matches do + last = replace_match(matches[i], func, str, last, acc) + end + + return str:sub(last), true +end + +local function do_replace_fun(re, func, str, max) + local acc = {} + local i + for i = 1, max do + str, continue = do_single_replace_fun(re, func, str, acc) + if not continue then max = i end + end + return table.concat(acc, "") .. str, max +end + +function re_proto.sub(self, str, repl, count) + check_self(self) + check_arg(str, "string", 2, "sub", self._level) + if count ~= nil then + check_arg(count, "number", 4, "sub", self._level) + end + + if not count or count == 0 then count = str:len() end + + if type(repl) == "function" then + return do_replace_fun(self, repl, str, count) + elseif type(repl) == "string" then + return regex.replace(self._regex, repl, str, count) + else + error( + string.format("Argument 2 to sub should be a string or function, is '%s' (%s)", + type(repl), tostring(repl)), + self._level) + end +end + +function re_proto.gmatch(self, str) + check_self(self) + check_arg(str, "string", 2, "gmatch", self._level) + + local match_count = regex.match_count(self._regex, str) + local i = 0 + return function() + if i == match_count then return end + i = i + 1 + local first, last = regex.get_match(self._regex, str, i - 1) + return { + str = str:sub(first, last), + first = first, + last = last + } + end +end + +function re_proto.match(self, str) + check_self(self) + check_arg(str, "string", 2, "match", self._level) + + local ret = to_table(self:gmatch(str)) + -- Return nil rather than a empty table so that if re.match(...) works + if next(ret) == nil then return end + return ret +end + +-- Create a wxRegExp object from a pattern, flags, and error depth +local function real_compile(pattern, level, flags, stored_level) + local regex = regex.compile(pattern, flags) + if not regex then + error("Bad syntax in regular expression", level + 1) + end + return setmetatable({ + _regex = regex, + _level = stored_level or level + 1 + }, + re_proto_mt) +end + +-- Compile a pattern then invoke a method on it +local function invoke(str, pattern, fn, flags, ...) + local comp = real_compile(pattern, 3, flags) + return comp[fn](comp, str, ...) +end + +-- Generate a static version of a method with arg type checking +local function gen_wrapper(impl_name) + return function(str, pattern, ...) + check_arg(str, "string", 1, impl_name, 2) + check_arg(pattern, "string", 2, impl_name, 2) + return invoke(str, pattern, impl_name, unpack_args(...)) + end +end + +-- And now at last the actual public API +local re = regex.init_flags(re) + +function re.compile(pattern, ...) + check_arg(pattern, "string", 1, "compile", 2) + return real_compile(pattern, 2, regex.process_flags(...), 2) +end + +re.split = gen_wrapper("split") +re.gsplit = gen_wrapper("gsplit") +re.find = gen_wrapper("find") +re.gfind = gen_wrapper("gfind") +re.match = gen_wrapper("match") +re.gmatch = gen_wrapper("gmatch") +re.sub = gen_wrapper("sub") + +_G.re = re +return _G.re diff --git a/aegisub/src/auto4_lua.cpp b/aegisub/src/auto4_lua.cpp index 8c99f9169..2d689e241 100644 --- a/aegisub/src/auto4_lua.cpp +++ b/aegisub/src/auto4_lua.cpp @@ -89,6 +89,11 @@ namespace { lua_pushinteger(L, n); } + inline void push_value(lua_State *L, void *p) + { + lua_pushlightuserdata(L, p); + } + template inline void set_field(lua_State *L, const char *name, T value) { @@ -134,6 +139,125 @@ namespace { lua_pop(L, 1); return c; } + + inline wxRegEx *get_regex(lua_State *L) + { + return static_cast(luaL_checkudata(L, 1, "aegisub.regex")); + } + + int regex_matches(lua_State *L) + { + lua_pushboolean(L, get_regex(L)->Matches(check_wxstring(L, 2))); + return 1; + } + + int regex_match_count(lua_State *L) + { + wxRegEx *re = get_regex(L); + if (re->Matches(check_wxstring(L, 2))) + lua_pushinteger(L, re->GetMatchCount()); + else + lua_pushinteger(L, 0); + return 1; + } + + size_t utf8_len(wxString const& w) + { +#if wxUSE_UNICODE_UTF8 + return w.utf8_length(); +#else + return w.utf8_str().length(); +#endif + } + + int regex_get_match(lua_State *L) + { + wxString str(check_wxstring(L, 2)); + size_t start, len; + get_regex(L)->GetMatch(&start, &len, luaL_checkinteger(L, 3)); + lua_pushinteger(L, utf8_len(str.Left(start)) + 1); + lua_pushinteger(L, utf8_len(str.Left(start + len))); + return 2; + } + + int regex_replace(lua_State *L) + { + wxString str(check_wxstring(L, 3)); + int reps = get_regex(L)->Replace(&str, check_wxstring(L, 2), luaL_checkinteger(L, 4)); + lua_pushstring(L, str.utf8_str()); + lua_pushinteger(L, reps); + return 2; + } + + int regex_compile(lua_State *L) + { + wxString pattern(check_wxstring(L, 1)); + int flags = luaL_checkinteger(L, 2); + wxRegEx *re = static_cast(lua_newuserdata(L, sizeof(wxRegEx))); + new(re) wxRegEx(pattern, wxRE_ADVANCED | flags); + + luaL_getmetatable(L, "aegisub.regex"); + lua_setmetatable(L, -2); + + // return nil and handle the error in lua as it's a bit easier to + // report the actual call site from there + if (!re->IsValid()) { + lua_pop(L, 1); + lua_pushnil(L); + } + + return 1; + } + + int regex_gc(lua_State *L) { + get_regex(L)->~wxRegEx(); + return 0; + } + + int regex_process_flags(lua_State *L) + { + int ret = 0; + int nargs = lua_gettop(L); + for (int i = 1; i <= nargs; ++i) { + if (!lua_islightuserdata(L, i)) { + lua_pushstring(L, "Flags must follow all non-flag arguments"); + return 1; + } + ret |= (int)lua_touserdata(L, i); + } + + lua_pushinteger(L, ret); + return 1; + } + + int regex_init_flags(lua_State *L) + { + lua_newtable(L); + + set_field(L, "ICASE", (void*)wxRE_ICASE); + set_field(L, "NOSUB", (void*)wxRE_NOSUB); + set_field(L, "NEWLINE", (void*)wxRE_NEWLINE); + + return 1; + } + + int regex_init(lua_State *L) + { + if (luaL_newmetatable(L, "aegisub.regex")) { + set_field(L, "__gc", regex_gc); + lua_pop(L, 1); + } + + lua_newtable(L); + set_field(L, "matches", regex_matches); + set_field(L, "match_count", regex_match_count); + set_field(L, "get_match", regex_get_match); + set_field(L, "replace", regex_replace); + set_field(L, "compile", regex_compile); + set_field(L, "process_flags", regex_process_flags); + set_field(L, "init_flags", regex_init_flags); + return 1; + } } // LuaStackcheck @@ -265,6 +389,7 @@ namespace Automation4 { set_field(L, "decode_path", LuaDecodePath); set_field(L, "cancel", LuaCancel); set_field(L, "lua_automation_version", 4); + set_field(L, "__init_regex", regex_init); // store aegisub table to globals lua_settable(L, LUA_GLOBALSINDEX);