diff --git a/CHANGES.md b/CHANGES.md index b179d9c173..25ddb1a34e 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -3,6 +3,7 @@ ## Features/Changes * Compiler/wasm: omit code pointer from closures when not used (#2059, #2093) * Compiler/wasm: number unboxing (#2069, #2101) +* Compiler/wasm: specialization of number comparisons and bigarray operations (#1954) * Compiler/wasm: make the type of some Wasm primitives more precise (#2100) * Compiler: reference unboxing (#1958) diff --git a/compiler/lib-wasm/code_generation.ml b/compiler/lib-wasm/code_generation.ml index 96786d1a07..d9e3335d19 100644 --- a/compiler/lib-wasm/code_generation.ml +++ b/compiler/lib-wasm/code_generation.ml @@ -373,6 +373,7 @@ module Arith = struct (match e, e' with | W.Const (I32 n), W.Const (I32 n') when Int32.(n' < 31l) -> W.Const (I32 (Int32.shift_left n (Int32.to_int n'))) + | _, W.Const (I32 0l) -> e | _ -> W.BinOp (I32 Shl, e, e')) let ( lsr ) = binary (Shr U) diff --git a/compiler/lib-wasm/gc_target.ml b/compiler/lib-wasm/gc_target.ml index a2d9a10168..df63676c68 100644 --- a/compiler/lib-wasm/gc_target.ml +++ b/compiler/lib-wasm/gc_target.ml @@ -430,6 +430,38 @@ module Type = struct } ]) }) + + let int_array_type = + register_type "int_array" (fun () -> + return + { supertype = None + ; final = true + ; typ = W.Array { mut = true; typ = Value I32 } + }) + + let bigarray_type = + register_type "bigarray" (fun () -> + let* custom_operations = custom_operations_type in + let* int_array = int_array_type in + let* custom = custom_type in + return + { supertype = Some custom + ; final = true + ; typ = + W.Struct + [ { mut = false + ; typ = Value (Ref { nullable = false; typ = Type custom_operations }) + } + ; { mut = true; typ = Value (Ref { nullable = false; typ = Extern }) } + ; { mut = true; typ = Value (Ref { nullable = false; typ = Extern }) } + ; { mut = false + ; typ = Value (Ref { nullable = false; typ = Type int_array }) + } + ; { mut = false; typ = Packed I8 } + ; { mut = false; typ = Packed I8 } + ; { mut = false; typ = Packed I8 } + ] + }) end module Value = struct @@ -1373,6 +1405,235 @@ module Math = struct let exp2 x = power (return (W.Const (F64 2.))) x end +module Bigarray = struct + let dimension n a = + let* ty = Type.bigarray_type in + Memory.wasm_array_get + ~ty:Type.int_array_type + (Memory.wasm_struct_get ty (Memory.wasm_cast ty a) 3) + (Arith.const (Int32.of_int n)) + + let get_at_offset ~(kind : Typing.Bigarray.kind) a i = + let name, (typ : Wasm_ast.value_type), size, box = + match kind with + | Float32 -> + ( "dv_get_f32" + , F32 + , 2 + , fun x -> + let* x = x in + return (W.F64PromoteF32 x) ) + | Float64 -> "dv_get_f64", F64, 3, Fun.id + | Int8_signed -> "dv_get_i8", I32, 0, Fun.id + | Int8_unsigned -> "dv_get_ui8", I32, 0, Fun.id + | Int16_signed -> "dv_get_i16", I32, 1, Fun.id + | Int16_unsigned -> "dv_get_ui16", I32, 1, Fun.id + | Int32 -> "dv_get_i32", I32, 2, Fun.id + | Nativeint -> "dv_get_i32", I32, 2, Fun.id + | Int64 -> "dv_get_i64", I64, 3, Fun.id + | Int -> "dv_get_i32", I32, 2, Fun.id + | Float16 -> + ( "dv_get_i16" + , I32 + , 1 + , fun x -> + let* conv = + register_import + ~name:"caml_float16_to_double" + (Fun { W.params = [ I32 ]; result = [ F64 ] }) + in + let* x = x in + return (W.Call (conv, [ x ])) ) + | Complex32 -> + ( "dv_get_f32" + , F32 + , 3 + , fun x -> + let* x = x in + return (W.F64PromoteF32 x) ) + | Complex64 -> "dv_get_f64", F64, 4, Fun.id + in + let* little_endian = + register_import + ~import_module:"bindings" + ~name:"littleEndian" + (Global { mut = false; typ = I32 }) + in + let* f = + register_import + ~import_module:"bindings" + ~name + (Fun + { W.params = + Ref { nullable = true; typ = Extern } + :: I32 + :: (if size = 0 then [] else [ I32 ]) + ; result = [ typ ] + }) + in + let* ty = Type.bigarray_type in + let* ta = Memory.wasm_struct_get ty (Memory.wasm_cast ty a) 2 in + let* ofs = Arith.(i lsl const (Int32.of_int size)) in + match kind with + | Float32 + | Float64 + | Int8_signed + | Int8_unsigned + | Int16_signed + | Int16_unsigned + | Int32 + | Int64 + | Int + | Nativeint + | Float16 -> + box + (return + (W.Call + (f, ta :: ofs :: (if size = 0 then [] else [ W.GlobalGet little_endian ])))) + | Complex32 | Complex64 -> + let delta = Int32.shift_left 1l (size - 1) in + let* ofs' = Arith.(return ofs + const delta) in + let* x = box (return (W.Call (f, [ ta; ofs; W.GlobalGet little_endian ]))) in + let* y = box (return (W.Call (f, [ ta; ofs'; W.GlobalGet little_endian ]))) in + let* ty = Type.float_array_type in + return (W.ArrayNewFixed (ty, [ x; y ])) + + let set_at_offset ~kind a i v = + let name, (typ : Wasm_ast.value_type), size, unbox = + match (kind : Typing.Bigarray.kind) with + | Float32 -> + ( "dv_set_f32" + , F32 + , 2 + , fun x -> + let* x = x in + return (W.F32DemoteF64 x) ) + | Float64 -> "dv_set_f64", F64, 3, Fun.id + | Int8_signed | Int8_unsigned -> "dv_set_i8", I32, 0, Fun.id + | Int16_signed | Int16_unsigned -> "dv_set_i16", I32, 1, Fun.id + | Int32 -> "dv_set_i32", I32, 2, Fun.id + | Nativeint -> "dv_set_i32", I32, 2, Fun.id + | Int64 -> "dv_set_i64", I64, 3, Fun.id + | Int -> "dv_set_i32", I32, 2, Fun.id + | Float16 -> + ( "dv_set_i16" + , I32 + , 1 + , fun x -> + let* conv = + register_import + ~name:"caml_double_to_float16" + (Fun { W.params = [ F64 ]; result = [ I32 ] }) + in + let* x = x in + return (W.Call (conv, [ x ])) ) + | Complex32 -> + ( "dv_set_f32" + , F32 + , 3 + , fun x -> + let* x = x in + return (W.F32DemoteF64 x) ) + | Complex64 -> "dv_set_f64", F64, 4, Fun.id + in + let* ty = Type.bigarray_type in + let* ta = Memory.wasm_struct_get ty (Memory.wasm_cast ty a) 2 in + let* ofs = Arith.(i lsl const (Int32.of_int size)) in + let* little_endian = + register_import + ~import_module:"bindings" + ~name:"littleEndian" + (Global { mut = false; typ = I32 }) + in + let* f = + register_import + ~import_module:"bindings" + ~name + (Fun + { W.params = + Ref { nullable = true; typ = Extern } + :: I32 + :: typ + :: (if size = 0 then [] else [ I32 ]) + ; result = [] + }) + in + match kind with + | Float32 + | Float64 + | Int8_signed + | Int8_unsigned + | Int16_signed + | Int16_unsigned + | Int32 + | Int64 + | Int + | Nativeint + | Float16 -> + let* v = unbox v in + instr + (W.CallInstr + ( f + , ta :: ofs :: v :: (if size = 0 then [] else [ W.GlobalGet little_endian ]) + )) + | Complex32 | Complex64 -> + let delta = Int32.shift_left 1l (size - 1) in + let* ofs' = Arith.(return ofs + const delta) in + let ty = Type.float_array_type in + let* x = unbox (Memory.wasm_array_get ~ty v (Arith.const 0l)) in + let* () = instr (W.CallInstr (f, [ ta; ofs; x; W.GlobalGet little_endian ])) in + let* y = unbox (Memory.wasm_array_get ~ty v (Arith.const 1l)) in + instr (W.CallInstr (f, [ ta; ofs'; y; W.GlobalGet little_endian ])) + + let offset ~bound_error_index ~(layout : Typing.Bigarray.layout) ta ~indices = + let l = + List.mapi + ~f:(fun pos i -> + let i = + match layout with + | C -> i + | Fortran -> Arith.(i - const 1l) + in + let i' = Code.Var.fresh () in + let dim = Code.Var.fresh () in + ( (let* () = store ~typ:I32 i' i in + let* () = store ~typ:I32 dim (dimension pos ta) in + let* cond = Arith.uge (load i') (load dim) in + instr (W.Br_if (bound_error_index, cond))) + , i' + , dim )) + indices + in + let l = + match layout with + | C -> l + | Fortran -> List.rev l + in + match l with + | (instrs, i', _) :: rem -> + List.fold_left + ~f:(fun (instrs, ofs) (instrs', i', dim) -> + let ofs' = Code.Var.fresh () in + ( (let* () = instrs in + let* () = instrs' in + store ~typ:I32 ofs' Arith.((ofs * load dim) + load i')) + , load ofs' )) + ~init:(instrs, load i') + rem + | [] -> return (), Arith.const 0l + + let get ~bound_error_index ~kind ~layout ta ~indices = + let instrs, ofs = offset ~bound_error_index ~layout ta ~indices in + seq instrs (get_at_offset ~kind ta ofs) + + let set ~bound_error_index ~kind ~layout ta ~indices v = + let instrs, ofs = offset ~bound_error_index ~layout ta ~indices in + seq + (let* () = instrs in + set_at_offset ~kind ta ofs v) + Value.unit +end + module JavaScript = struct let anyref = W.Ref { nullable = true; typ = Any } diff --git a/compiler/lib-wasm/generate.ml b/compiler/lib-wasm/generate.ml index b84189d7cb..5c51a05321 100644 --- a/compiler/lib-wasm/generate.ml +++ b/compiler/lib-wasm/generate.ml @@ -115,13 +115,6 @@ module Generate (Target : Target_sig.S) = struct ; "caml_erf_float", (`Pure, [ Float ], Float) ; "caml_erfc_float", (`Pure, [ Float ], Float) ; "caml_float_compare", (`Pure, [ Float; Float ], Int) - ; "caml_greaterthan", (`Mutator, [ Value; Value ], Int) - ; "caml_greaterequal", (`Mutator, [ Value; Value ], Int) - ; "caml_lessthan", (`Mutator, [ Value; Value ], Int) - ; "caml_lessequal", (`Mutator, [ Value; Value ], Int) - ; "caml_equal", (`Mutator, [ Value; Value ], Int) - ; "caml_notequal", (`Mutator, [ Value; Value ], Int) - ; "caml_compare", (`Mutator, [ Value; Value ], Int) ]; h @@ -226,7 +219,8 @@ module Generate (Target : Target_sig.S) = struct (if negate then Value.phys_neq else Value.phys_eq) (transl_prim_arg ctx ~typ:Top x) (transl_prim_arg ctx ~typ:Top y) - | (Int _ | Number _ | Tuple _), _ | _, (Int _ | Number _ | Tuple _) -> + | (Int _ | Number _ | Tuple _ | Bigarray _), _ + | _, (Int _ | Number _ | Tuple _ | Bigarray _) -> (* Only Top may contain JavaScript values *) (if negate then Value.phys_neq else Value.phys_eq) (transl_prim_arg ctx ~typ:Top x) @@ -290,6 +284,39 @@ module Generate (Target : Target_sig.S) = struct (transl_prim_arg ctx ?typ:tz z) | _ -> invalid_arity name l ~expected:3) + let register_comparison name cmp_int cmp_boxed_int cmp_float = + register_prim name `Mutator (fun ctx _ l -> + match l with + | [ x; y ] -> ( + match get_type ctx x, get_type ctx y with + | Int _, Int _ -> cmp_int ctx x y + | Number (Int32, _), Number (Int32, _) -> + let x = transl_prim_arg ctx ~typ:(Number (Int32, Unboxed)) x in + let y = transl_prim_arg ctx ~typ:(Number (Int32, Unboxed)) y in + int32_bin_op cmp_boxed_int x y + | Number (Nativeint, _), Number (Nativeint, _) -> + let x = transl_prim_arg ctx ~typ:(Number (Nativeint, Unboxed)) x in + let y = transl_prim_arg ctx ~typ:(Number (Nativeint, Unboxed)) y in + nativeint_bin_op cmp_boxed_int x y + | Number (Int64, _), Number (Int64, _) -> + let x = transl_prim_arg ctx ~typ:(Number (Int64, Unboxed)) x in + let y = transl_prim_arg ctx ~typ:(Number (Int64, Unboxed)) y in + int64_bin_op cmp_boxed_int x y + | Number (Float, _), Number (Float, _) -> + let x = transl_prim_arg ctx ~typ:(Number (Float, Unboxed)) x in + let y = transl_prim_arg ctx ~typ:(Number (Float, Unboxed)) y in + float_bin_op cmp_float x y + | _ -> + let* f = + register_import + ~name + (Fun { W.params = [ Type.value; Type.value ]; result = [ I32 ] }) + in + let* x = transl_prim_arg ctx x in + let* y = transl_prim_arg ctx y in + return (W.Call (f, [ x; y ]))) + | _ -> invalid_arity name l ~expected:2) + let () = register_bin_prim "caml_floatarray_unsafe_get" @@ -1092,7 +1119,215 @@ module Generate (Target : Target_sig.S) = struct ~ty:(Int Normalized) (fun i j -> Arith.((j < i) - (i < j))); register_prim "%js_array" `Pure (fun ctx _ l -> - Memory.allocate ~tag:0 (expression_list (fun x -> transl_prim_arg ctx x) l)) + Memory.allocate ~tag:0 (expression_list (fun x -> transl_prim_arg ctx x) l)); + register_comparison + "caml_greaterthan" + (fun ctx x y -> translate_int_comparison ctx (fun y x -> Arith.(x < y)) x y) + (Gt S) + Gt; + register_comparison + "caml_greaterequal" + (fun ctx x y -> translate_int_comparison ctx (fun y x -> Arith.(x <= y)) x y) + (Ge S) + Ge; + register_comparison + "caml_lessthan" + (fun ctx x y -> translate_int_comparison ctx Arith.( < ) x y) + (Lt S) + Lt; + register_comparison + "caml_lessequal" + (fun ctx x y -> translate_int_comparison ctx Arith.( <= ) x y) + (Le S) + Le; + register_comparison + "caml_equal" + (fun ctx x y -> translate_int_equality ctx ~negate:false x y) + Eq + Eq; + register_comparison + "caml_notequal" + (fun ctx x y -> translate_int_equality ctx ~negate:true x y) + Ne + Ne; + register_prim "caml_compare" `Mutator (fun ctx _ l -> + match l with + | [ x; y ] -> ( + match get_type ctx x, get_type ctx y with + | Int _, Int _ -> + let x' = transl_prim_arg ctx ~typ:(Int Normalized) x in + let y' = transl_prim_arg ctx ~typ:(Int Normalized) y in + Arith.((y' < x') - (x' < y')) + | Number (Int32, _), Number (Int32, _) + | Number (Nativeint, _), Number (Nativeint, _) -> + let* f = + register_import + ~name:"caml_int32_compare" + (Fun { W.params = [ I32; I32 ]; result = [ I32 ] }) + in + let* x' = transl_prim_arg ctx ~typ:(Number (Int32, Unboxed)) x in + let* y' = transl_prim_arg ctx ~typ:(Number (Int32, Unboxed)) y in + return (W.Call (f, [ x'; y' ])) + | Number (Int64, _), Number (Int64, _) -> + let* f = + register_import + ~name:"caml_int64_compare" + (Fun { W.params = [ I64; I64 ]; result = [ I32 ] }) + in + let* x' = transl_prim_arg ctx ~typ:(Number (Int64, Unboxed)) x in + let* y' = transl_prim_arg ctx ~typ:(Number (Int64, Unboxed)) y in + return (W.Call (f, [ x'; y' ])) + | Number (Float, _), Number (Float, _) -> + let* f = + register_import + ~name:"caml_float_compare" + (Fun { W.params = [ F64; F64 ]; result = [ I32 ] }) + in + let* x' = transl_prim_arg ctx ~typ:(Number (Float, Unboxed)) x in + let* y' = transl_prim_arg ctx ~typ:(Number (Float, Unboxed)) y in + return (W.Call (f, [ x'; y' ])) + | _ -> + let* f = + register_import + ~name:"caml_compare" + (Fun { W.params = [ Type.value; Type.value ]; result = [ I32 ] }) + in + let* x' = transl_prim_arg ctx x in + let* y' = transl_prim_arg ctx y in + return (W.Call (f, [ x'; y' ]))) + | _ -> invalid_arity "caml_compare" l ~expected:2); + let bigarray_generic_access ~ctx ta indices = + match + ( get_type ctx ta + , match indices with + | Pv indices -> Some (indices, ctx.global_flow_info.info_defs.(Var.idx indices)) + | Pc _ -> None ) + with + | Bigarray { kind; layout }, Some (indices, Expr (Block (_, l, _, _))) -> + Some + ( kind + , layout + , List.mapi + ~f:(fun i _ -> + Value.int_val + (Memory.array_get (load indices) (Arith.const (Int32.of_int (i + 1))))) + (Array.to_list l) ) + | _, None | _, Some (_, (Expr _ | Phi _)) -> None + in + let caml_ba_get ~ctx ~context ~kind ~layout ta indices = + let ta' = transl_prim_arg ctx ta in + Bigarray.get + ~bound_error_index:(label_index context bound_error_pc) + ~kind + ~layout + ta' + ~indices + in + let caml_ba_get_n ~ctx ~context ta indices = + match get_type ctx ta with + | Bigarray { kind; layout } -> + let indices = + List.map ~f:(fun i -> transl_prim_arg ctx ~typ:(Int Normalized) i) indices + in + caml_ba_get ~ctx ~context ~kind ~layout ta indices + | _ -> + let n = List.length indices in + let* f = + register_import + ~name:(Printf.sprintf "caml_ba_get_%d" n) + (Fun (Type.primitive_type (n + 1))) + in + let* ta' = transl_prim_arg ctx ta in + let* indices' = expression_list (transl_prim_arg ctx) indices in + return (W.Call (f, ta' :: indices')) + in + register_prim "caml_ba_get_1" `Mutator (fun ctx context l -> + match l with + | [ ta; i ] -> caml_ba_get_n ~ctx ~context ta [ i ] + | _ -> invalid_arity "caml_ba_get_1" l ~expected:2); + register_prim "caml_ba_get_2" `Mutator (fun ctx context l -> + match l with + | [ ta; i; j ] -> caml_ba_get_n ~ctx ~context ta [ i; j ] + | _ -> invalid_arity "caml_ba_get_2" l ~expected:3); + register_prim "caml_ba_get_3" `Mutator (fun ctx context l -> + match l with + | [ ta; i; j; k ] -> caml_ba_get_n ~ctx ~context ta [ i; j; k ] + | _ -> invalid_arity "caml_ba_get_3" l ~expected:4); + register_prim "caml_ba_get_generic" `Mutator (fun ctx context l -> + match l with + | [ ta; indices ] -> ( + match bigarray_generic_access ~ctx ta indices with + | Some (kind, layout, indices) -> + caml_ba_get ~ctx ~context ~kind ~layout ta indices + | _ -> + let* f = + register_import + ~name:"caml_ba_get_generic" + (Fun (Type.primitive_type 2)) + in + let* ta' = transl_prim_arg ctx ta in + let* indices' = transl_prim_arg ctx indices in + return (W.Call (f, [ ta'; indices' ]))) + | _ -> invalid_arity "caml_ba_get_generic" l ~expected:2); + let caml_ba_set ~ctx ~context ~kind ~layout ta indices v = + let ta' = transl_prim_arg ctx ta in + let v' = transl_prim_arg ctx ~typ:(Typing.bigarray_element_type kind) v in + Bigarray.set + ~bound_error_index:(label_index context bound_error_pc) + ~kind + ~layout + ta' + ~indices + v' + in + let caml_ba_set_n ~ctx ~context ta indices v = + match get_type ctx ta with + | Bigarray { kind; layout } -> + let indices = + List.map ~f:(fun i -> transl_prim_arg ctx ~typ:(Int Normalized) i) indices + in + caml_ba_set ~ctx ~context ~kind ~layout ta indices v + | _ -> + let n = List.length indices in + let* f = + register_import + ~name:(Printf.sprintf "caml_ba_set_%d" n) + (Fun (Type.primitive_type (n + 2))) + in + let* ta' = transl_prim_arg ctx ta in + let* indices' = expression_list (transl_prim_arg ctx) indices in + let* v' = transl_prim_arg ctx v in + return (W.Call (f, ta' :: (indices' @ [ v' ]))) + in + register_prim "caml_ba_set_1" `Mutator (fun ctx context l -> + match l with + | [ ta; i; v ] -> caml_ba_set_n ~ctx ~context ta [ i ] v + | _ -> invalid_arity "caml_ba_set_1" l ~expected:3); + register_prim "caml_ba_set_2" `Mutator (fun ctx context l -> + match l with + | [ ta; i; j; v ] -> caml_ba_set_n ~ctx ~context ta [ i; j ] v + | _ -> invalid_arity "caml_ba_set_2" l ~expected:4); + register_prim "caml_ba_set_3" `Mutator (fun ctx context l -> + match l with + | [ ta; i; j; k; v ] -> caml_ba_set_n ~ctx ~context ta [ i; j; k ] v + | _ -> invalid_arity "caml_ba_set_3" l ~expected:5); + register_prim "caml_ba_set_generic" `Mutator (fun ctx context l -> + match l with + | [ ta; indices; v ] -> ( + match bigarray_generic_access ~ctx ta indices with + | Some (kind, layout, indices) -> + caml_ba_set ~ctx ~context ~kind ~layout ta indices v + | _ -> + let* f = + register_import + ~name:"caml_ba_set_generic" + (Fun (Type.primitive_type 3)) + in + let* ta' = transl_prim_arg ctx ta in + let* indices' = transl_prim_arg ctx indices in + let* v' = transl_prim_arg ctx v in + return (W.Call (f, [ ta'; indices'; v' ]))) + | _ -> invalid_arity "caml_ba_set_generic" l ~expected:3) let unboxed_type ty : W.value_type option = match ty with @@ -1414,7 +1649,15 @@ module Generate (Target : Target_sig.S) = struct | "caml_bytes_set" | "caml_check_bound" | "caml_check_bound_gen" - | "caml_check_bound_float" ) + | "caml_check_bound_float" + | "caml_ba_get_1" + | "caml_ba_get_2" + | "caml_ba_get_3" + | "caml_ba_get_generic" + | "caml_ba_set_1" + | "caml_ba_set_2" + | "caml_ba_set_3" + | "caml_ba_set_generic" ) , _ ) ) -> fst n, true | Let ( _ diff --git a/compiler/lib-wasm/target_sig.ml b/compiler/lib-wasm/target_sig.ml index f661a7fbe6..a0fc5e8ce9 100644 --- a/compiler/lib-wasm/target_sig.ml +++ b/compiler/lib-wasm/target_sig.ml @@ -254,6 +254,25 @@ module type S = sig val round : expression -> expression end + module Bigarray : sig + val get : + bound_error_index:int + -> kind:Typing.Bigarray.kind + -> layout:Typing.Bigarray.layout + -> expression + -> indices:expression list + -> expression + + val set : + bound_error_index:int + -> kind:Typing.Bigarray.kind + -> layout:Typing.Bigarray.layout + -> expression + -> indices:expression list + -> expression + -> expression + end + val internal_primitives : (string * Primitive.kind diff --git a/compiler/lib-wasm/typing.ml b/compiler/lib-wasm/typing.ml index 913e49d4ac..8785c88dfc 100644 --- a/compiler/lib-wasm/typing.ml +++ b/compiler/lib-wasm/typing.ml @@ -44,6 +44,82 @@ type boxed_status = | Boxed | Unboxed +module Bigarray = struct + type kind = + | Float16 + | Float32 + | Float64 + | Int8_signed + | Int8_unsigned + | Int16_signed + | Int16_unsigned + | Int32 + | Int64 + | Int + | Nativeint + | Complex32 + | Complex64 + + type layout = + | C + | Fortran + + type t = + { kind : kind + ; layout : layout + } + + let make ~kind ~layout = + { kind = + (match kind with + | 0 -> Float32 + | 1 -> Float64 + | 2 -> Int8_signed + | 3 -> Int8_unsigned + | 4 -> Int16_signed + | 5 -> Int16_unsigned + | 6 -> Int32 + | 7 -> Int64 + | 8 -> Int + | 9 -> Nativeint + | 10 -> Complex32 + | 11 -> Complex64 + | 12 -> Int8_unsigned + | 13 -> Float16 + | _ -> assert false) + ; layout = + (match layout with + | 0 -> C + | 1 -> Fortran + | _ -> assert false) + } + + let print f { kind; layout } = + Format.fprintf + f + "bigarray{%s,%s}" + (match kind with + | Float32 -> "float32" + | Float64 -> "float64" + | Int8_signed -> "sint8" + | Int8_unsigned -> "uint8" + | Int16_signed -> "sint16" + | Int16_unsigned -> "uint16" + | Int32 -> "int32" + | Int64 -> "int64" + | Int -> "int" + | Nativeint -> "nativeint" + | Complex32 -> "complex32" + | Complex64 -> "complex64" + | Float16 -> "float16") + (match layout with + | C -> "C" + | Fortran -> "Fortran") + + let equal { kind; layout } { kind = kind'; layout = layout' } = + phys_equal kind kind' && phys_equal layout layout' +end + type typ = | Top | Int of Integer.kind @@ -52,6 +128,7 @@ type typ = (** This value is a block or an integer; if it's an integer, an overapproximation of the possible values of each of its fields is given by the array of types *) + | Bigarray of Bigarray.t | Bot module Domain = struct @@ -81,8 +158,9 @@ module Domain = struct if i < l then if i < l' then join t.(i) t'.(i) else t.(i) else t'.(i))) | Int _, Tuple _ -> t' | Tuple _, Int _ -> t + | Bigarray b, Bigarray b' when Bigarray.equal b b' -> t | Top, _ | _, Top -> Top - | (Int _ | Number _ | Tuple _), _ -> Top + | (Int _ | Number _ | Tuple _ | Bigarray _), _ -> Top let join_set ?(others = false) f s = if others then Top else Var.Set.fold (fun x a -> join (f x) a) s Bot @@ -94,7 +172,8 @@ module Domain = struct | Number (t, b), Number (t', b') -> Poly.equal t t' && Poly.equal b b' | Tuple t, Tuple t' -> Array.length t = Array.length t' && Array.for_all2 ~f:equal t t' - | (Top | Tuple _ | Int _ | Number _ | Bot), _ -> false + | Bigarray b, Bigarray b' -> Bigarray.equal b b' + | (Top | Tuple _ | Int _ | Number _ | Bigarray _ | Bot), _ -> false let bot = Bot @@ -102,12 +181,12 @@ module Domain = struct let rec depth t = match t with - | Top | Bot | Number _ | Int _ -> 0 + | Top | Bot | Number _ | Int _ | Bigarray _ -> 0 | Tuple l -> 1 + Array.fold_left ~f:(fun acc t' -> max (depth t') acc) l ~init:0 let rec truncate depth t = match t with - | Top | Bot | Number _ | Int _ -> t + | Top | Bot | Number _ | Int _ | Bigarray _ -> t | Tuple l -> if depth = 0 then Top @@ -145,6 +224,7 @@ module Domain = struct (match b with | Boxed -> "boxed" | Unboxed -> "unboxed") + | Bigarray b -> Bigarray.print f b | Tuple t -> Format.fprintf f @@ -160,7 +240,18 @@ let update_deps st { blocks; _ } = List.iter block.body ~f:(fun i -> match i with | Let (x, Block (_, lst, _, _)) -> Array.iter ~f:(fun y -> add_dep st x y) lst - | Let (x, Prim (Extern ("%int_and" | "%int_or" | "%int_xor"), lst)) -> + | Let + ( x + , Prim + ( Extern + ( "%int_and" + | "%int_or" + | "%int_xor" + | "caml_ba_get_1" + | "caml_ba_get_2" + | "caml_ba_get_3" + | "caml_ba_get_generic" ) + , lst ) ) -> (* The return type of these primitives depend on the input type *) List.iter ~f:(fun p -> @@ -206,7 +297,23 @@ let arg_type ~approx arg = | Pc c -> constant_type c | Pv x -> Var.Tbl.get approx x -let prim_type ~approx prim args = +let bigarray_element_type (kind : Bigarray.kind) = + match kind with + | Float16 | Float32 | Float64 -> Number (Float, Unboxed) + | Int8_signed | Int8_unsigned | Int16_signed | Int16_unsigned -> Int Normalized + | Int -> Int Unnormalized + | Int32 -> Number (Int32, Unboxed) + | Int64 -> Number (Int64, Unboxed) + | Nativeint -> Number (Nativeint, Unboxed) + | Complex32 | Complex64 -> Tuple [| Number (Float, Boxed); Number (Float, Boxed) |] + +let bigarray_type ~approx ba = + match arg_type ~approx ba with + | Bot -> Bot + | Bigarray { kind; _ } -> bigarray_element_type kind + | _ -> Top + +let prim_type ~st ~approx prim args = match prim with | "%int_add" | "%int_sub" | "%int_mul" | "%direct_int_mul" | "%int_lsl" | "%int_neg" -> Int Unnormalized @@ -366,6 +473,25 @@ let prim_type ~approx prim args = | "caml_nativeint_to_int" -> Int Unnormalized | "caml_nativeint_of_int" -> Number (Nativeint, Unboxed) | "caml_int_compare" -> Int Normalized + | "caml_ba_create" -> ( + match args with + | [ Pc (Int kind); Pc (Int layout); _ ] -> + Bigarray + (Bigarray.make + ~kind:(Targetint.to_int_exn kind) + ~layout:(Targetint.to_int_exn layout)) + | _ -> Top) + | "caml_ba_get_1" | "caml_ba_get_2" | "caml_ba_get_3" -> ( + match args with + | ba :: _ -> bigarray_type ~approx ba + | [] -> Top) + | "caml_ba_get_generic" -> ( + match args with + | ba :: Pv indices :: _ -> ( + match st.global_flow_state.defs.(Var.idx indices) with + | Expr (Block _) -> bigarray_type ~approx ba + | _ -> Top) + | [] | [ _ ] | _ :: Pc _ :: _ -> Top) | _ -> Top let propagate st approx x : Domain.t = @@ -424,7 +550,7 @@ let propagate st approx x : Domain.t = | Top -> Top) | Prim (Array_get, _) -> Top | Prim ((Vectlength | Not | IsInt | Eq | Neq | Lt | Le | Ult), _) -> Int Normalized - | Prim (Extern prim, args) -> prim_type ~approx prim args + | Prim (Extern prim, args) -> prim_type ~st ~approx prim args | Special _ -> Top | Apply { f; args; _ } -> ( match Var.Tbl.get st.global_flow_info.info_approximation f with @@ -437,7 +563,35 @@ let propagate st approx x : Domain.t = when List.length args = List.length params -> let res = Domain.join_set - (fun y -> Var.Tbl.get approx y) + (fun y -> + match st.global_flow_state.defs.(Var.idx y) with + | Expr + (Prim (Extern "caml_ba_create", [ Pv kind; Pv layout; _ ])) + -> ( + let m = + List.fold_left2 + ~f:(fun m p a -> Var.Map.add p a m) + ~init:Var.Map.empty + params + args + in + try + match + ( st.global_flow_state.defs.(Var.idx + (Var.Map.find kind m)) + , st.global_flow_state.defs.(Var.idx + (Var.Map.find layout m)) + ) + with + | ( Expr (Constant (Int kind)) + , Expr (Constant (Int layout)) ) -> + Bigarray + (Bigarray.make + ~kind:(Targetint.to_int_exn kind) + ~layout:(Targetint.to_int_exn layout)) + | _ -> raise Not_found + with Not_found -> Var.Tbl.get approx y) + | _ -> Var.Tbl.get approx y) (Var.Map.find g st.global_flow_state.return_values) in if can_unbox_return_value st.fun_info g then res else Domain.box res @@ -590,6 +744,43 @@ let primitives_with_unboxed_parameters = ]; h +let type_specialized_primitive types global_flow_state name args = + match name with + | "caml_greaterthan" + | "caml_greaterequal" + | "caml_lessthan" + | "caml_lessequal" + | "caml_equal" + | "caml_notequal" + | "caml_compare" -> ( + match List.map ~f:(arg_type ~approx:types) args with + | [ Int _; Int _ ] + | [ Number (Int32, _); Number (Int32, _) ] + | [ Number (Int64, _); Number (Int64, _) ] + | [ Number (Nativeint, _); Number (Nativeint, _) ] + | [ Number (Float, _); Number (Float, _) ] -> true + | _ -> false) + | "caml_ba_get_1" + | "caml_ba_get_2" + | "caml_ba_get_3" + | "caml_ba_set_1" + | "caml_ba_set_2" + | "caml_ba_set_3" -> ( + match args with + | Pv x :: _ -> ( + match Var.Tbl.get types x with + | Bigarray _ -> true + | _ -> false) + | _ -> false) + | "caml_ba_get_generic" | "caml_ba_set_generic" -> ( + match args with + | Pv x :: Pv indices :: _ -> ( + match Var.Tbl.get types x, global_flow_state.defs.(Var.idx indices) with + | Bigarray _, Expr (Block _) -> true + | _ -> false) + | _ -> false) + | _ -> false + let box_numbers p st types = (* We box numbers eagerly if the boxed value is ever used. *) let should_box = Var.ISet.empty () in @@ -614,7 +805,7 @@ let box_numbers p st types = Var.Set.iter box s) | Expr _ -> () | Phi { known; _ } -> Var.Set.iter box known) - | Number (_, Boxed) | Int _ | Tuple _ | Bot -> ()) + | Number (_, Boxed) | Int _ | Tuple _ | Bigarray _ | Bot -> ()) in Code.fold_closures p @@ -636,7 +827,11 @@ let box_numbers p st types = then List.iter ~f:box args | Block (tag, lst, _, _) -> if tag <> 254 then Array.iter ~f:box lst | Prim (Extern s, args) -> - if not (String.Hashtbl.mem primitives_with_unboxed_parameters s) + if + not + (String.Hashtbl.mem primitives_with_unboxed_parameters s + || type_specialized_primitive types st.global_flow_state s args + ) then List.iter ~f:(fun a -> @@ -667,6 +862,13 @@ let box_numbers p st types = ()) () +let print_opt types global_flow_state f e = + match e with + | Prim (Extern name, args) + when type_specialized_primitive types global_flow_state name args -> + Format.fprintf f " OPT" + | _ -> () + type t = { types : typ Var.Tbl.t ; return_types : typ Var.Hashtbl.t @@ -696,7 +898,13 @@ let f ~global_flow_state ~global_flow_info ~fun_info ~deadcode_sentinal p = Format.err_formatter (fun _ i -> match i with - | Instr (Let (x, _)) -> Format.asprintf "{%a}" Domain.print (Var.Tbl.get types x) + | Instr (Let (x, e)) -> + Format.asprintf + "{%a}%a" + Domain.print + (Var.Tbl.get types x) + (print_opt types global_flow_state) + e | _ -> "") p); let return_types = Var.Hashtbl.create 128 in diff --git a/compiler/lib-wasm/typing.mli b/compiler/lib-wasm/typing.mli index be9c2aff32..5ea4e7da51 100644 --- a/compiler/lib-wasm/typing.mli +++ b/compiler/lib-wasm/typing.mli @@ -15,17 +15,46 @@ type boxed_status = | Boxed | Unboxed +module Bigarray : sig + type kind = + | Float16 + | Float32 + | Float64 + | Int8_signed + | Int8_unsigned + | Int16_signed + | Int16_unsigned + | Int32 + | Int64 + | Int + | Nativeint + | Complex32 + | Complex64 + + type layout = + | C + | Fortran + + type t = + { kind : kind + ; layout : layout + } +end + type typ = | Top | Int of Integer.kind | Number of boxed_number * boxed_status | Tuple of typ array + | Bigarray of Bigarray.t | Bot val constant_type : Code.constant -> typ val can_unbox_parameters : Call_graph_analysis.t -> Code.Var.t -> bool +val bigarray_element_type : Bigarray.kind -> typ + type t val var_type : t -> Code.Var.t -> typ diff --git a/compiler/lib/inline.ml b/compiler/lib/inline.ml index 365b7445c6..0412d0d19c 100644 --- a/compiler/lib/inline.ml +++ b/compiler/lib/inline.ml @@ -237,17 +237,39 @@ let sum ~context f pc = blocks 0 -let rec block_size ~recurse ~context { branch; body; _ } = +let rec block_size ~inline_comparisons ~recurse ~context { branch; body; _ } = List.fold_left ~f:(fun n i -> match i with | Event _ -> n + | Let + ( _ + , Prim + ( Extern + ( "caml_lessthan" + | "caml_lessequal" + | "caml_greaterthan" + | "caml_greaterequal" + | "caml_equal" + | "caml_notequal" ) + , _ ) ) + when inline_comparisons -> + (* Bias toward inlining functions containing polymorphic + comparisons, such as min and max, in the hope that + polymorphic comparisons can be specialized. *) + n - 1 | Let (f, Closure (_, (pc, _), _)) -> if recurse then match Var.Map.find f context.env with - | exception Not_found -> size ~recurse ~context pc + n + 1 - | info -> cache ~info info.full_size (size ~recurse:true ~context) + n + 1 + | exception Not_found -> size ~inline_comparisons ~recurse ~context pc + n + 1 + | info -> + cache + ~info + info.full_size + (size ~inline_comparisons ~recurse:true ~context) + + n + + 1 else n + 1 | _ -> n + 1) ~init: @@ -257,13 +279,21 @@ let rec block_size ~recurse ~context { branch; body; _ } = | _ -> 0) body -and size ~recurse ~context = sum ~context (block_size ~recurse ~context) +and size ~inline_comparisons ~recurse ~context = + sum ~context (block_size ~inline_comparisons ~recurse ~context) (** Size of the function body *) -let body_size ~context info = cache ~info info.body_size (size ~recurse:false ~context) +let body_size ~context info = + let inline_comparisons = + match Config.target () with + | `JavaScript -> false + | `Wasm -> true + in + cache ~info info.body_size (size ~inline_comparisons ~recurse:false ~context) (** Size of the function, including the size of the closures it contains *) -let full_size ~context info = cache ~info info.full_size (size ~recurse:true ~context) +let full_size ~context info = + cache ~info info.full_size (size ~inline_comparisons:false ~recurse:true ~context) let closure_count_uncached ~context = sum ~context (fun { body; _ } -> diff --git a/runtime/wasm/bigarray.wat b/runtime/wasm/bigarray.wat index 35fd72c452..0fe421874e 100644 --- a/runtime/wasm/bigarray.wat +++ b/runtime/wasm/bigarray.wat @@ -184,7 +184,8 @@ (field $ba_kind i8) ;; kind (field $ba_layout i8)))) ;; layout - (func $double_to_float16 (param $f f64) (result i32) + (func $double_to_float16 (export "caml_double_to_float16") + (param $f f64) (result i32) (local $x i32) (local $sign i32) (local $o i32) (local.set $x (i32.reinterpret_f32 (f32.demote_f64 (local.get $f)))) (local.set $sign (i32.and (local.get $x) (i32.const 0x80000000))) @@ -214,7 +215,8 @@ (i32.const 13))))))) (i32.or (local.get $o) (i32.shr_u (local.get $sign) (i32.const 16)))) - (func $float16_to_double (param $d i32) (result f64) + (func $float16_to_double (export "caml_float16_to_double") + (param $d i32) (result f64) (local $f f32) (local.set $f (f32.mul