diff --git a/ext/cgi/escape/escape.c b/ext/cgi/escape/escape.c index 495ad83..4681308 100644 --- a/ext/cgi/escape/escape.c +++ b/ext/cgi/escape/escape.c @@ -33,20 +33,20 @@ preserve_original_state(VALUE orig, VALUE dest) } static inline long -escaped_length(VALUE str) +escaped_length(VALUE str, long escape_max_len) { const long len = RSTRING_LEN(str); - if (len >= LONG_MAX / HTML_ESCAPE_MAX_LEN) { - ruby_malloc_size_overflow(len, HTML_ESCAPE_MAX_LEN); + if (len >= LONG_MAX / escape_max_len) { + ruby_malloc_size_overflow(len, escape_max_len); } - return len * HTML_ESCAPE_MAX_LEN; + return len * escape_max_len; } static VALUE optimized_escape_html(VALUE str) { VALUE vbuf; - char *buf = ALLOCV_N(char, vbuf, escaped_length(str)); + char *buf = ALLOCV_N(char, vbuf, escaped_length(str, HTML_ESCAPE_MAX_LEN)); const char *cstr = RSTRING_PTR(str); const char *end = cstr + RSTRING_LEN(str); @@ -75,6 +75,86 @@ optimized_escape_html(VALUE str) return escaped; } +struct build_escape_table_args { + long max_escape_length; + VALUE *escape_table; +}; + +static int +build_escape_table_i(VALUE key, VALUE val, VALUE _arg) +{ + struct build_escape_table_args *arg = (struct build_escape_table_args *)_arg; + Check_Type(key, T_STRING); + Check_Type(val, T_STRING); + + if (RSTRING_LEN(key) != 1) { + rb_raise(rb_eArgError, "CGI.escapeHTML keys must be single ASCII characters"); + } + unsigned char c = RSTRING_PTR(key)[0]; + + if (c >= 0x80) { + rb_raise(rb_eArgError, "CGI.escapeHTML keys must be single ASCII characters"); + } + + arg->escape_table[c] = val; + + long escape_length = RSTRING_LEN(val); + if (arg->max_escape_length < escape_length) { + arg->max_escape_length = escape_length; + } + + return ST_CONTINUE; +} + +static long +build_escape_table(VALUE *escape_table, VALUE rb_escape_table) +{ + struct build_escape_table_args arg = { + .escape_table = escape_table, + }; + rb_hash_foreach(rb_escape_table, build_escape_table_i, (VALUE)&arg); + return arg.max_escape_length; +} + +static VALUE +dynamic_escape_html(VALUE str, VALUE rb_escape_table) +{ + VALUE escape_table[UCHAR_MAX+1] = {0}; + long max_escape_length = build_escape_table(escape_table, rb_escape_table); + + VALUE vbuf; + char *buf = ALLOCV_N(char, vbuf, escaped_length(str, max_escape_length)); + const char *cstr = RSTRING_PTR(str); + const char *end = cstr + RSTRING_LEN(str); + + char *dest = buf; + while (cstr < end) { + const unsigned char c = *cstr++; + VALUE escaped_character = escape_table[c]; + if (escaped_character) { + const char *ptr; + long len; + RSTRING_GETMEM(escaped_character, ptr, len); + MEMCPY(dest, ptr, char, len); + dest += len; + } + else { + *dest++ = c; + } + } + + VALUE escaped; + if (RSTRING_LEN(str) < (dest - buf)) { + escaped = rb_str_new(buf, dest - buf); + preserve_original_state(str, escaped); + } + else { + escaped = rb_str_dup(str); + } + ALLOCV_END(vbuf); + return escaped; +} + static VALUE optimized_unescape_html(VALUE str) { @@ -331,15 +411,24 @@ optimized_unescape(VALUE str, VALUE encoding, int unescape_plus) * */ static VALUE -cgiesc_escape_html(VALUE self, VALUE str) +cgiesc_escape_html(int argc, VALUE *argv, VALUE self) { + rb_check_arity(argc, 1, 2); + + VALUE str = argv[0]; StringValue(str); if (rb_enc_str_asciicompat_p(str)) { - return optimized_escape_html(str); + if (argc == 1) { + return optimized_escape_html(str); + } + else { + Check_Type(argv[1], T_HASH); + return dynamic_escape_html(str, argv[1]); + } } else { - return rb_call_super(1, &str); + return rb_call_super(argc, argv); } } @@ -474,7 +563,7 @@ InitVM_escape(void) rb_cCGI = rb_define_class("CGI", rb_cObject); rb_mEscape = rb_define_module_under(rb_cCGI, "Escape"); rb_mUtil = rb_define_module_under(rb_cCGI, "Util"); - rb_define_method(rb_mEscape, "escapeHTML", cgiesc_escape_html, 1); + rb_define_method(rb_mEscape, "escapeHTML", cgiesc_escape_html, -1); rb_define_method(rb_mEscape, "unescapeHTML", cgiesc_unescape_html, 1); rb_define_method(rb_mEscape, "escapeURIComponent", cgiesc_escape_uri_component, 1); rb_define_alias(rb_mEscape, "escape_uri_component", "escapeURIComponent"); diff --git a/lib/cgi/util.rb b/lib/cgi/util.rb index 5f12eae..8ff40de 100644 --- a/lib/cgi/util.rb +++ b/lib/cgi/util.rb @@ -74,7 +74,7 @@ def unescapeURIComponent(string, encoding = @@accept_charset) # Escape special characters in HTML, namely '&\"<> # CGI.escapeHTML('Usage: foo "bar" ') # # => "Usage: foo "bar" <baz>" - def escapeHTML(string) + def escapeHTML(string, escape_table = nil) enc = string.encoding unless enc.ascii_compatible? if enc.dummy? @@ -82,13 +82,27 @@ def escapeHTML(string) enc = Encoding::Converter.asciicompat_encoding(enc) string = enc ? string.encode(enc) : string.b end - table = Hash[TABLE_FOR_ESCAPE_HTML__.map {|pair|pair.map {|s|s.encode(enc)}}] - string = string.gsub(/#{"['&\"<>]".encode(enc)}/, table) + if escape_table + table = Hash[escape_table.map {|pair| pair.map {|s|s.encode(enc)}}] + pattern = "[".encode(enc) + escape_table.each_key do |key| + pattern << Regexp.escape(key).encode(enc) + end + pattern << "]".encode(enc) + string = string.gsub(/#{pattern}/, table) + else + table = Hash[TABLE_FOR_ESCAPE_HTML__.map {|pair|pair.map {|s|s.encode(enc)}}] + string = string.gsub(/#{"['&\"<>]".encode(enc)}/, table) + end string.encode!(origenc) if origenc string else string = string.b - string.gsub!(/['&\"<>]/, TABLE_FOR_ESCAPE_HTML__) + if escape_table + string.gsub!(p Regexp.union(escape_table.keys), escape_table) + else + string.gsub!(/['&\"<>]/, TABLE_FOR_ESCAPE_HTML__) + end string.force_encoding(enc) end end diff --git a/test/cgi/test_cgi_util.rb b/test/cgi/test_cgi_util.rb index bff77f7..4fb7435 100644 --- a/test/cgi/test_cgi_util.rb +++ b/test/cgi/test_cgi_util.rb @@ -135,6 +135,19 @@ def test_cgi_escapeHTML assert_equal("'&"><", CGI.escapeHTML("'&\"><")) end + def test_dynamic_cgi_escapeHTML + assert_equal("'&\"><", CGI.escapeHTML("'&\"><", { "<" => "<" })) + assert_equal("'\\u0026\"\\u003e\\u003c", CGI.escapeHTML("'&\"><", { + ">" => '\u003e', + "<" => '\u003c', + "&" => '\u0026', + })) + + assert_raise(ArgumentError) { CGI.escapeHTML(" ", { "12" => "<" }) } + assert_raise(ArgumentError) { CGI.escapeHTML(" ", { "€" => "<" }) } + assert_raise(ArgumentError) { CGI.escapeHTML(" ", { "" => "<" }) } + end + def test_cgi_escape_html_duplicated orig = "Ruby".dup.force_encoding("US-ASCII") str = CGI.escapeHTML(orig) @@ -215,6 +228,17 @@ def test_cgi_unescapeHTML_following_invalid_numeric define_method("test_cgi_unescapeHTML:#{enc.name}") do assert_equal(unescaped, CGI.unescapeHTML(escaped)) end + + define_method("test_cgi_dynamic_unescapeHTML:#{enc.name}") do + table = { + "'" => ''', + '&' => '&', + '"' => '"', + '<' => '<', + '>' => '>', + } + assert_equal(escaped, CGI.escapeHTML(unescaped, table)) + end end end