Convert the re module over to the LuaJIT ffi

This commit is contained in:
Thomas Goyne 2014-07-19 21:48:58 -07:00
parent a01a84fb4f
commit 74a215f642
2 changed files with 178 additions and 183 deletions

View file

@ -17,32 +17,79 @@ next = next
select = select
type = type
-- Get the boost::regex binding
bit = require 'bit'
ffi = require 'ffi'
ffi_util = require 'aegisub.ffi'
ffi.cdef[[
typedef struct agi_re_flag {
const char *name;
int value;
} agi_re_flag;
]]
regex_flag = ffi.typeof 'agi_re_flag'
-- Get the boost::eegex binding
regex = require 'aegisub.__re_impl'
-- Wrappers to convert returned values from C types to Lua types
search = (re, str, start) ->
return unless start <= str\len()
res = regex.search re, str, str\len(), start
return unless res != nil
first, last = res[0], res[1]
ffi.C.free res
first, last
replace = (re, replacement, str, max_count) ->
ffi_util.string regex.replace re, replacement, str, str\len(), max_count
match = (re, str, start) ->
assert start <= str\len()
m = regex.match re, str, str\len(), start
return unless m != nil
ffi.gc m, regex.match_free
get_match = (m, idx) ->
res = regex.get_match m, idx
return unless res != nil
res[0], res[1] -- Result buffer is owned by match so no need to free
err_buff = ffi.new 'char *[1]'
compile = (pattern, flags) ->
err_buff[0] = nil
re = regex.compile pattern, flags, err_buff
if err_buff[0] != nil
return ffi.string err_buff[0]
ffi.gc re, regex.regex_free
-- Return the first n elements from ...
select_first = (n, a, ...) ->
if n == 0 then return
a, select_first n - 1, ...
-- Bitwise-or together regex flags passed as arguments to a function
process_flags = (...) ->
flags = 0
for i = 1, select '#', ...
v = select i, ...
if not ffi.istype regex_flag, v
error 'Flags must follow all non-flag arguments', 3
flags = bit.bor flags, v.value
flags
-- Extract the flags from ..., bitwise OR them together, and move them to the
-- front of ...
unpack_args = (...) ->
userdata_start = nil
flags_start = nil
for i = 1, select '#', ...
v = select i, ...
if type(v) == 'userdata'
userdata_start = i
if ffi.istype regex_flag, v
flags_start = i
break
return 0, ... unless userdata_start
flags = regex.process_flags select userdata_start, ...
if type(flags) == 'string'
error(flags, 3)
flags, select_first userdata_start - 1, ...
return 0, ... unless flags_start
process_flags(select flags_start, ...), select_first flags_start - 1, ...
-- Typecheck a variable and throw an error if it fails
check_arg = (arg, expected_type, argn, func_name, level) ->
@ -108,20 +155,19 @@ class RegEx
new: (@_regex, @_level) =>
start = 1
gsplit: (str, skip_empty, max_split) =>
@_check_self!
check_arg str, 'string', 2, 'gsplit', @_level
if not max_split or max_split <= 0 then max_split = str\len()
start = 1
start = 0
prev = 1
do_split = () ->
if not str or str\len() == 0 then return
local first, last
if max_split > 0
first, last = regex.search @_regex, str, start
first, last = search @_regex, str, start
if not first or first > str\len()
ret = str\sub prev, str\len()
@ -131,7 +177,7 @@ class RegEx
ret = str\sub prev, first - 1
prev = last + 1
start = 1 + if start >= last then start else last
start = if start >= last then start + 1 else last
if skip_empty and ret\len() == 0
do_split()
@ -150,12 +196,12 @@ class RegEx
@_check_self!
check_arg str, 'string', 2, 'gfind', @_level
start = 1
start = 0
->
first, last = regex.search(@_regex, str, start)
first, last = search(@_regex, str, start)
return unless first
start = if last >= start then last + 1 else start + 1
start = if last > start then last else start + 1
str\sub(first, last), first, last
find: (str) =>
@ -176,7 +222,7 @@ class RegEx
if type(repl) == 'function'
do_replace_fun @, repl, str, max_count
elseif type(repl) == 'string'
regex.replace @_regex, repl, str, max_count
replace @_regex, repl, str, max_count
else
error "Argument 2 to sub should be a string or function, is '#{type(repl)}' (#{repl})", @_level
@ -185,11 +231,11 @@ class RegEx
check_arg str, 'string', 2, 'gmatch', @_level
start = if start then start - 1 else 0
match = regex.match @_regex, str, start
i = 1
m = match @_regex, str, start
i = 0
->
return unless match
first, last = regex.get_match match, i
return unless m
first, last = get_match m, i
return unless first
i += 1
@ -213,7 +259,7 @@ real_compile = (pattern, level, flags, stored_level) ->
if pattern == ''
error 'Regular expression must not be empty', level + 1
re = regex.compile pattern, flags
re = compile pattern, flags
if type(re) == 'string'
error regex, level + 1
@ -225,25 +271,31 @@ invoke = (str, pattern, fn, flags, ...) ->
compiled_regex[fn](compiled_regex, str, ...)
-- Generate a static version of a method with arg type checking
gen_wrapper = (impl_name) ->
(str, pattern, ...) ->
gen_wrapper = (impl_name) -> (str, pattern, ...) ->
check_arg str, 'string', 1, impl_name, 2
check_arg pattern, 'string', 2, impl_name, 2
invoke str, pattern, impl_name, unpack_args ...
-- And now at last the actual public API
re = regex.init_flags(re)
re.compile = (pattern, ...) ->
do
re = {
compile: (pattern, ...) ->
check_arg pattern, 'string', 1, 'compile', 2
real_compile pattern, 2, regex.process_flags(...), 2
real_compile pattern, 2, process_flags(...), 2
re.split = gen_wrapper 'split'
re.gsplit = gen_wrapper 'gsplit'
re.find = gen_wrapper 'find'
re.gfind = gen_wrapper 'gfind'
re.match = gen_wrapper 'match'
re.gmatch = gen_wrapper 'gmatch'
re.sub = gen_wrapper 'sub'
split: gen_wrapper 'split'
gsplit: gen_wrapper 'gsplit'
find: gen_wrapper 'find'
gfind: gen_wrapper 'gfind'
match: gen_wrapper 'match'
gmatch: gen_wrapper 'gmatch'
sub: gen_wrapper 'sub'
}
re
i = 0
flags = regex.get_flags()
while flags[i].name != nil
re[ffi.string flags[i].name] = flags[i]
i += 1
re

View file

@ -14,85 +14,72 @@
//
// Aegisub Project http://www.aegisub.org/
#include "libaegisub/lua/utils.h"
#include "libaegisub/lua/ffi.h"
#include "libaegisub/make_unique.h"
#include <boost/regex/icu.hpp>
#include <lauxlib.h>
using boost::u32regex;
namespace {
// A cmatch with a match range attached to it so that we can return a pointer to
// an int pair without an extra heap allocation each time (LuaJIT can't compile
// ffi calls which return aggregates by value)
struct agi_re_match {
boost::cmatch m;
int range[2];
};
struct agi_re_flag {
const char *name;
int value;
};
}
namespace agi {
AGI_DEFINE_TYPE_NAME(u32regex);
AGI_DEFINE_TYPE_NAME(agi_re_match);
AGI_DEFINE_TYPE_NAME(agi_re_flag);
}
namespace {
using namespace agi::lua;
boost::u32regex& get_regex(lua_State *L) {
return get<boost::u32regex>(L, 1, "aegisub.regex");
using match = agi_re_match;
bool search(u32regex& re, const char *str, size_t len, int start, boost::cmatch& result) {
return u32regex_search(str + start, str + len, result, re,
start > 0 ? boost::match_prev_avail | boost::match_not_bob : boost::match_default);
}
boost::smatch& get_smatch(lua_State *L) {
return get<boost::smatch>(L, 1, "aegisub.smatch");
match *regex_match(u32regex& re, const char *str, size_t len, int start) {
auto result = agi::make_unique<match>();
if (!search(re, str, len, start, result->m))
return nullptr;
return result.release();
}
int regex_matches(lua_State *L) {
push_value(L, u32regex_match(check_string(L, 2), get_regex(L)));
return 1;
int *regex_get_match(match& match, size_t idx) {
if (idx > match.m.size() || !match.m[idx].matched)
return nullptr;
match.range[0] = std::distance(match.m.prefix().first, match.m[idx].first + 1);
match.range[1] = std::distance(match.m.prefix().first, match.m[idx].second);
return match.range;
}
int regex_match(lua_State *L) {
auto re = get_regex(L);
std::string str = check_string(L, 2);
int start = lua_tointeger(L, 3);
int *regex_search(u32regex& re, const char *str, size_t len, size_t start) {
boost::cmatch result;
if (!search(re, str, len, start, result))
return nullptr;
auto result = make<boost::smatch>(L, "aegisub.smatch");
if (!u32regex_search(str.cbegin() + start, str.cend(), *result, re,
start > 0 ? boost::match_prev_avail | boost::match_not_bob : boost::match_default))
{
lua_pop(L, 1);
lua_pushnil(L);
}
return 1;
auto ret = static_cast<int *>(malloc(sizeof(int) * 2));
ret[0] = start + result.position() + 1;
ret[1] = start + result.position() + result.length();
return ret;
}
int regex_get_match(lua_State *L) {
auto& match = get_smatch(L);
auto idx = check_uint(L, 2) - 1;
if (idx > match.size() || !match[idx].matched) {
lua_pushnil(L);
return 1;
}
push_value(L, distance(match.prefix().first, match[idx].first + 1));
push_value(L, distance(match.prefix().first, match[idx].second));
return 2;
}
int regex_search(lua_State *L) {
auto& re = get_regex(L);
auto str = check_string(L, 2);
auto start = check_uint(L, 3) - 1;
argcheck(L, start <= str.size(), 3, "out of bounds");
boost::smatch result;
if (!u32regex_search(str.cbegin() + start, str.cend(), result, re,
start > 0 ? boost::match_prev_avail | boost::match_not_bob : boost::match_default))
{
lua_pushnil(L);
return 1;
}
push_value(L, start + result.position() + 1);
push_value(L, start + result.position() + result.length());
return 2;
}
int regex_replace(lua_State *L) {
auto& re = get_regex(L);
const auto replacement = check_string(L, 2);
const auto str = check_string(L, 3);
int max_count = check_int(L, 4);
char *regex_replace(u32regex& re, const char *replacement, const char *str, size_t len, int max_count) {
// Can't just use regex_replace here since it can only do one or infinite replacements
auto match = boost::u32regex_iterator<std::string::const_iterator>(begin(str), end(str), re);
auto end_it = boost::u32regex_iterator<std::string::const_iterator>();
auto match = boost::u32regex_iterator<const char *>(str, str + len, re);
auto end_it = boost::u32regex_iterator<const char *>();
auto suffix = begin(str);
auto suffix = str;
std::string ret;
auto out = back_inserter(ret);
@ -104,95 +91,51 @@ int regex_replace(lua_State *L) {
--max_count;
}
copy(suffix, end(str), out);
push_value(L, ret);
return 1;
ret += suffix;
return strdup(ret.c_str());
}
int regex_compile(lua_State *L) {
auto pattern(check_string(L, 1));
int flags = check_int(L, 2);
auto re = make<boost::u32regex>(L, "aegisub.regex");
u32regex *regex_compile(const char *pattern, int flags, char **err) {
auto re = agi::make_unique<u32regex>();
try {
*re = boost::make_u32regex(pattern, boost::u32regex::perl | flags);
return re.release();
}
catch (std::exception const& e) {
lua_pop(L, 1);
push_value(L, e.what());
return 1;
// Do the actual triggering of the error in the Lua code as that code
// can report the original call site
*err = strdup(e.what());
return nullptr;
}
return 1;
}
int regex_gc(lua_State *L) {
using boost::u32regex;
get_regex(L).~u32regex();
return 0;
void regex_free(u32regex *re) { delete re; }
void match_free(match *m) { delete m; }
const agi_re_flag *get_regex_flags() {
static const agi_re_flag flags[] = {
{"ICASE", boost::u32regex::icase},
{"NOSUB", boost::u32regex::nosubs},
{"COLLATE", boost::u32regex::collate},
{"NEWLINE_ALT", boost::u32regex::newline_alt},
{"NO_MOD_M", boost::u32regex::no_mod_m},
{"NO_MOD_S", boost::u32regex::no_mod_s},
{"MOD_S", boost::u32regex::mod_s},
{"MOD_X", boost::u32regex::mod_x},
{"NO_EMPTY_SUBEXPRESSIONS", boost::u32regex::no_empty_expressions},
{nullptr, 0}
};
return flags;
}
int smatch_gc(lua_State *L) {
using boost::smatch;
get_smatch(L).~smatch();
return 0;
}
int regex_process_flags(lua_State *L) {
int ret = 0;
int nargs = lua_gettop(L);
for (int i = 1; i <= nargs; ++i) {
if (!lua_islightuserdata(L, i)) {
push_value(L, "Flags must follow all non-flag arguments");
return 1;
}
ret |= (int)(intptr_t)lua_touserdata(L, i);
}
push_value(L, ret);
return 1;
}
int regex_init_flags(lua_State *L) {
lua_createtable(L, 0, 9);
set_field(L, "ICASE", (void*)boost::u32regex::icase);
set_field(L, "NOSUB", (void*)boost::u32regex::nosubs);
set_field(L, "COLLATE", (void*)boost::u32regex::collate);
set_field(L, "NEWLINE_ALT", (void*)boost::u32regex::newline_alt);
set_field(L, "NO_MOD_M", (void*)boost::u32regex::no_mod_m);
set_field(L, "NO_MOD_S", (void*)boost::u32regex::no_mod_s);
set_field(L, "MOD_S", (void*)boost::u32regex::mod_s);
set_field(L, "MOD_X", (void*)boost::u32regex::mod_x);
set_field(L, "NO_EMPTY_SUBEXPRESSIONS", (void*)boost::u32regex::no_empty_expressions);
return 1;
}
}
extern "C" int luaopen_re_impl(lua_State *L) {
if (luaL_newmetatable(L, "aegisub.regex")) {
set_field<regex_gc>(L, "__gc");
lua_pop(L, 1);
}
if (luaL_newmetatable(L, "aegisub.smatch")) {
set_field<smatch_gc>(L, "__gc");
lua_pop(L, 1);
}
lua_createtable(L, 0, 8);
set_field<regex_matches>(L, "matches");
set_field<regex_search>(L, "search");
set_field<regex_match>(L, "match");
set_field<regex_get_match>(L, "get_match");
set_field<regex_replace>(L, "replace");
set_field<regex_compile>(L, "compile");
set_field<regex_process_flags>(L, "process_flags");
set_field<regex_init_flags>(L, "init_flags");
agi::lua::register_lib_table(L, {"agi_re_match", "u32regex"},
"search", regex_search,
"match", regex_match,
"get_match", regex_get_match,
"replace", regex_replace,
"compile", regex_compile,
"get_flags", get_regex_flags,
"match_free", match_free,
"regex_free", regex_free);
return 1;
}