diff --git a/__tests__/TestSqlCommonRaw.ml b/__tests__/TestSqlCommonRaw.ml index 9912541..7104c72 100644 --- a/__tests__/TestSqlCommonRaw.ml +++ b/__tests__/TestSqlCommonRaw.ml @@ -155,4 +155,41 @@ describe "Raw SQL Query Test Sequence" (fun () -> |> finish ) ); -); + + testAsync "Rollback batch insert on duplicate key" (fun finish -> + let id_0 = string_of_int (Js.Math.random_int 0 (Js.Int.max - 1)) in + let code_0 = {j|unique-value-$id_0|j} in + let sql = "INSERT INTO test.simple (id, code) VALUES (?, ?)" in + let params = `Positional (Json.Encode.array Json.Encode.string [| id_0; code_0 |]) in + Sql.mutate conn ~sql ~params (fun res -> + match res with + | `Error e -> let _ = Js.log e in finish (fail "see log") + | `Mutation (_, _) -> + let id_1 = string_of_int (Js.Math.random_int 0 (Js.Int.max - 1)) in + let code_1 = {j|unique-value-$id_1|j} in + let batch_size = 1 in + let table = "test.simple" in + let columns = Belt.Array.map [|"id"; "code"|] Json.Encode.string in + (* order is important here *) + let rows = Belt.Array.map [| + Json.Encode.array Json.Encode.string [| id_1; code_1 |]; + Json.Encode.array Json.Encode.string [| id_0; code_0 |]; + |] (fun a -> a) in + Sql.mutate_batch conn ~batch_size ~table ~columns ~rows (fun res -> + match res with + | `Mutation (rows, id) -> let _ = Js.log3 "mutation should have failed" rows id in + finish (fail "see log") + | `Error _ -> Sql.query conn ~sql:{j|SELECT * FROM test.simple WHERE code='$(code_1)'|j} (fun res -> + match res with + | `Error e -> let _ = Js.log e in finish (fail "see log") + | `Select (rows, _) -> + rows + |> Expect.expect + |> Expect.toHaveLength 0 + |> finish + ) + ) + ); + ); +) + diff --git a/__tests__/TestUtil.ml b/__tests__/TestUtil.ml index 0af31a9..f10f1f8 100644 --- a/__tests__/TestUtil.ml +++ b/__tests__/TestUtil.ml @@ -1,6 +1,15 @@ +let env = Node.Process.process##env +let host = Belt.Option.getWithDefault (Js.Dict.get env "MYSQL_HOST") "localhost" +let port = int_of_string (Belt.Option.getWithDefault (Js.Dict.get env "MYSQL_PORT") "3306") +let user = Belt.Option.getWithDefault (Js.Dict.get env "MYSQL_USER") "root" +let password = Belt.Option.getWithDefault (Js.Dict.get env "MYSQL_PASSWORD") "password" +let database = Belt.Option.getWithDefault (Js.Dict.get env "MYSQL_DATABASE") "test" + let connect _ = MySql2.connect - ~host:"127.0.0.1" - ~port:3306 - ~user:"root" + ~host + ~port + ~user + ~password + ~database () diff --git a/src/SqlCommonBatchInsert.ml b/src/SqlCommonBatchInsert.ml index fe27006..8ec57fd 100644 --- a/src/SqlCommonBatchInsert.ml +++ b/src/SqlCommonBatchInsert.ml @@ -67,7 +67,12 @@ let insert execute ?batch_size ~table ~columns ~rows user_cb = | None -> 1000 | Some(s) -> s in - let fail = (fun e -> user_cb (`Error e)) in + let fail = (fun err -> rollback ~execute + ~fail:(fun err -> user_cb (`Error err)) + ~ok:(fun _ _ -> user_cb (`Error err)) + () + ) + in let complete = (fun count id -> let ok = (fun _ _ -> user_cb (`Mutation (count, id))) in