diff --git a/ext/tiny_tds/client.c b/ext/tiny_tds/client.c index 9ca2c868..830b69e1 100644 --- a/ext/tiny_tds/client.c +++ b/ext/tiny_tds/client.c @@ -24,7 +24,7 @@ VALUE opt_escape_regex, opt_escape_dblquote; // Lib Backend (Helpers) -static VALUE rb_tinytds_raise_error(DBPROCESS *dbproc, int cancel, char *error, char *source, int severity, int dberr, int oserr) { +VALUE rb_tinytds_raise_error(DBPROCESS *dbproc, int cancel, char *error, char *source, int severity, int dberr, int oserr) { GET_CLIENT_USERDATA(dbproc); if (cancel && !dbdead(dbproc) && userdata && !userdata->closed) { userdata->dbsqlok_sent = 1; @@ -92,14 +92,54 @@ int tinytds_err_handler(DBPROCESS *dbproc, int severity, int dberr, int oserr, c break; } } - rb_tinytds_raise_error(dbproc, cancel, dberrstr, source, severity, dberr, oserr); + /* + When in non-blocking mode we need to store the exception data to throw it + once the blocking call returns, otherwise we will segfault ruby since part + of the contract of the ruby non-blocking indicator is that you do not call + any of the ruby C API. + */ + if (userdata && userdata->nonblocking) { + /* + If we've already captured an error message, don't overwrite it. This is + here because FreeTDS sends a generic "General SQL Server error" message + that will overwrite the real message. This is not normally a problem + because a ruby exception is normally thrown and we bail before the + generic message can be sent. + */ + if (!userdata->nonblocking_error.is_set) { + userdata->nonblocking_error.cancel = cancel; + strcpy(userdata->nonblocking_error.error, dberrstr); + strcpy(userdata->nonblocking_error.source, source); + userdata->nonblocking_error.severity = severity; + userdata->nonblocking_error.dberr = dberr; + userdata->nonblocking_error.oserr = oserr; + userdata->nonblocking_error.is_set = 1; + } + } else { + rb_tinytds_raise_error(dbproc, cancel, dberrstr, source, severity, dberr, oserr); + } return return_value; } int tinytds_msg_handler(DBPROCESS *dbproc, DBINT msgno, int msgstate, int severity, char *msgtext, char *srvname, char *procname, int line) { static char *source = "message"; - if (severity > 10) - rb_tinytds_raise_error(dbproc, 1, msgtext, source, severity, msgno, msgstate); + GET_CLIENT_USERDATA(dbproc); + if (severity > 10) { + // See tinytds_err_handler() for info about why we do this + if (userdata && userdata->nonblocking) { + if (!userdata->nonblocking_error.is_set) { + userdata->nonblocking_error.cancel = 1; + strcpy(userdata->nonblocking_error.error, msgtext); + strcpy(userdata->nonblocking_error.source, source); + userdata->nonblocking_error.severity = severity; + userdata->nonblocking_error.dberr = msgno; + userdata->nonblocking_error.oserr = msgstate; + userdata->nonblocking_error.is_set = 1; + } + } else { + rb_tinytds_raise_error(dbproc, 1, msgtext, source, severity, msgno, msgstate); + } + } return 0; } @@ -108,6 +148,8 @@ static void rb_tinytds_client_reset_userdata(tinytds_client_userdata *userdata) userdata->dbsql_sent = 0; userdata->dbsqlok_sent = 0; userdata->dbcancel_sent = 0; + userdata->nonblocking = 0; + userdata->nonblocking_error.is_set = 0; } static void rb_tinytds_client_mark(void *ptr) { diff --git a/ext/tiny_tds/client.h b/ext/tiny_tds/client.h index 3dc99c8d..b53cce9c 100644 --- a/ext/tiny_tds/client.h +++ b/ext/tiny_tds/client.h @@ -4,6 +4,16 @@ void init_tinytds_client(); +typedef struct { + short int is_set; + int cancel; + char error[1024]; + char source[1024]; + int severity; + int dberr; + int oserr; +} tinytds_errordata; + typedef struct { short int closed; short int timing_out; @@ -11,6 +21,8 @@ typedef struct { short int dbsqlok_sent; RETCODE dbsqlok_retcode; short int dbcancel_sent; + short int nonblocking; + tinytds_errordata nonblocking_error; } tinytds_client_userdata; typedef struct { diff --git a/ext/tiny_tds/result.c b/ext/tiny_tds/result.c index a5adc5b9..1767df58 100644 --- a/ext/tiny_tds/result.c +++ b/ext/tiny_tds/result.c @@ -79,35 +79,73 @@ VALUE rb_tinytds_new_result_obj(tinytds_client_wrapper *cwrap) { #define NOGVL_DBCALL(_dbfunction, _client) ( \ (RETCODE)rb_thread_blocking_region( \ - (rb_blocking_function_t*)nogvl_ ## _dbfunction, _client, \ + (rb_blocking_function_t*)_dbfunction, _client, \ (rb_unblock_function_t*)dbcancel_ubf, _client ) \ ) +static void dbcancel_ubf(DBPROCESS *client) { + GET_CLIENT_USERDATA(client); + dbcancel(client); + userdata->dbcancel_sent = 1; + userdata->dbsql_sent = 0; +} + +static void nogvl_setup(DBPROCESS *client) { + GET_CLIENT_USERDATA(client); + userdata->nonblocking = 1; +} + +static void nogvl_cleanup(DBPROCESS *client) { + GET_CLIENT_USERDATA(client); + userdata->nonblocking = 0; + /* + Now that the blocking operation is done, we can finally throw any + exceptions based on errors from SQL Server. + */ + if (userdata->nonblocking_error.is_set) { + userdata->nonblocking_error.is_set = 0; + rb_tinytds_raise_error(client, + userdata->nonblocking_error.cancel, + &userdata->nonblocking_error.error, + &userdata->nonblocking_error.source, + userdata->nonblocking_error.severity, + userdata->nonblocking_error.dberr, + userdata->nonblocking_error.oserr); + } +} + static RETCODE nogvl_dbsqlok(DBPROCESS *client) { int retcode = FAIL; GET_CLIENT_USERDATA(client); - retcode = dbsqlok(client); + nogvl_setup(client); + retcode = NOGVL_DBCALL(dbsqlok, client); + nogvl_cleanup(client); userdata->dbsqlok_sent = 1; return retcode; } static RETCODE nogvl_dbsqlexec(DBPROCESS *client) { - return dbsqlexec(client); + int retcode = FAIL; + nogvl_setup(client); + retcode = NOGVL_DBCALL(dbsqlexec, client); + nogvl_cleanup(client); + return retcode; } static RETCODE nogvl_dbresults(DBPROCESS *client) { - return dbresults(client); + int retcode = FAIL; + nogvl_setup(client); + retcode = NOGVL_DBCALL(dbresults, client); + nogvl_cleanup(client); + return retcode; } static RETCODE nogvl_dbnextrow(DBPROCESS * client) { - return dbnextrow(client); -} - -static void dbcancel_ubf(DBPROCESS *client) { - GET_CLIENT_USERDATA(client); - dbcancel(client); - userdata->dbcancel_sent = 1; - userdata->dbsql_sent = 0; + int retcode = FAIL; + nogvl_setup(client); + retcode = NOGVL_DBCALL(dbnextrow, client); + nogvl_cleanup(client); + return retcode; } // Lib Backend (Helpers) @@ -118,7 +156,7 @@ static RETCODE rb_tinytds_result_dbresults_retcode(VALUE self) { RETCODE db_rc; ruby_rc = rb_ary_entry(rwrap->dbresults_retcodes, rwrap->number_of_results); if (NIL_P(ruby_rc)) { - db_rc = NOGVL_DBCALL(dbresults, rwrap->client); + db_rc = nogvl_dbresults(rwrap->client); ruby_rc = INT2FIX(db_rc); rb_ary_store(rwrap->dbresults_retcodes, rwrap->number_of_results, ruby_rc); } else { @@ -130,7 +168,7 @@ static RETCODE rb_tinytds_result_dbresults_retcode(VALUE self) { static RETCODE rb_tinytds_result_ok_helper(DBPROCESS *client) { GET_CLIENT_USERDATA(client); if (userdata->dbsqlok_sent == 0) { - userdata->dbsqlok_retcode = NOGVL_DBCALL(dbsqlok, client); + userdata->dbsqlok_retcode = nogvl_dbsqlok(client); } return userdata->dbsqlok_retcode; } @@ -373,7 +411,7 @@ static VALUE rb_tinytds_result_each(int argc, VALUE * argv, VALUE self) { /* Create rows for this result set. */ unsigned long rowi = 0; VALUE result = rb_ary_new(); - while (NOGVL_DBCALL(dbnextrow, rwrap->client) != NO_MORE_ROWS) { + while (nogvl_dbnextrow(rwrap->client) != NO_MORE_ROWS) { VALUE row = rb_tinytds_result_fetch_row(self, timezone, symbolize_keys, as_array); if (cache_rows) rb_ary_store(result, rowi, row); @@ -406,7 +444,7 @@ static VALUE rb_tinytds_result_each(int argc, VALUE * argv, VALUE self) { } else { // If we do not find results, side step the rb_tinytds_result_dbresults_retcode helper and // manually populate its memoized array while nullifing any memoized fields too before loop. - dbresults_rc = NOGVL_DBCALL(dbresults, rwrap->client); + dbresults_rc = nogvl_dbresults(rwrap->client); rb_ary_store(rwrap->dbresults_retcodes, rwrap->number_of_results, INT2FIX(dbresults_rc)); rb_ary_store(rwrap->fields_processed, rwrap->number_of_results, Qnil); } @@ -466,10 +504,10 @@ static VALUE rb_tinytds_result_insert(VALUE self) { rb_tinytds_result_cancel_helper(rwrap->client); VALUE identity = Qnil; dbcmd(rwrap->client, rwrap->cwrap->identity_insert_sql); - if (NOGVL_DBCALL(dbsqlexec, rwrap->client) != FAIL - && NOGVL_DBCALL(dbresults, rwrap->client) != FAIL + if (nogvl_dbsqlexec(rwrap->client) != FAIL + && nogvl_dbresults(rwrap->client) != FAIL && DBROWS(rwrap->client) != FAIL) { - while (NOGVL_DBCALL(dbnextrow, rwrap->client) != NO_MORE_ROWS) { + while (nogvl_dbnextrow(rwrap->client) != NO_MORE_ROWS) { int col = 1; BYTE *data = dbdata(rwrap->client, col); DBINT data_len = dbdatlen(rwrap->client, col); diff --git a/test/thread_test.rb b/test/thread_test.rb index 13c6070e..02bb39f9 100644 --- a/test/thread_test.rb +++ b/test/thread_test.rb @@ -38,6 +38,26 @@ class ThreadTest < TinyTds::TestCase assert x > mintime, "#{x} is not slower than #{mintime} seconds" end + it 'should not crash on error in parallel' do + threads = [] + @numthreads.times do |i| + start = Time.new + threads << Thread.new do + @pool.with do |client| + begin + result = client.execute "select dbname()" + result.each { |r| puts r } + rescue Exception => e + # We are throwing an error on purpose here since 0.6.1 would + # segfault on errors thrown in threads + end + end + end + end + threads.each { |t| t.join } + assert true + end + end end