ONNX Runtime improvements #1306

Draft: wants to merge 21 commits into main
416 changes: 163 additions & 253 deletions package-lock.json

Large diffs are not rendered by default.

10 changes: 5 additions & 5 deletions package.json
@@ -27,7 +27,7 @@
"typegen": "tsc --build",
"dev": "webpack serve --no-client-overlay",
"build": "webpack && npm run typegen",
"test": "node --experimental-vm-modules --expose-gc node_modules/jest/bin/jest.js --verbose",
"test": "node --experimental-vm-modules --expose-gc node_modules/jest/bin/jest.js --verbose --logHeapUsage",
"readme": "python ./docs/scripts/build_readme.py",
"docs-api": "node ./docs/scripts/generate.js",
"docs-preview": "doc-builder preview transformers.js ./docs/source/ --not_python_module",
@@ -55,10 +55,10 @@
},
"homepage": "https://github.com/huggingface/transformers.js#readme",
"dependencies": {
"@huggingface/jinja": "^0.4.1",
"onnxruntime-node": "1.21.0",
"onnxruntime-web": "1.22.0-dev.20250409-89f8206ba4",
"sharp": "^0.34.1"
"@huggingface/jinja": "^0.5.0",
"onnxruntime-node": "1.23.0-dev.20250612-70f14d7670",
"onnxruntime-web": "1.23.0-dev.20250612-70f14d7670",
"sharp": "^0.34.2"
},
"devDependencies": {
"@types/jest": "^29.5.14",
6 changes: 3 additions & 3 deletions scripts/quantize.py
@@ -12,7 +12,7 @@
from onnxruntime.quantization import QuantType, QuantizationMode
from onnxruntime.quantization.onnx_quantizer import ONNXQuantizer
from onnxruntime.quantization.registry import IntegerOpsRegistry
- from onnxruntime.quantization.matmul_4bits_quantizer import MatMul4BitsQuantizer
+ from onnxruntime.quantization.matmul_nbits_quantizer import MatMulNBitsQuantizer
from onnxruntime.quantization.matmul_bnb4_quantizer import MatMulBnb4Quantizer

from . import float16
@@ -107,7 +107,7 @@ class QuantizationArguments:
},
)

- # MatMul4BitsQuantizer
+ # MatMulNBitsQuantizer
is_symmetric: bool = field(
default=True,
metadata={"help": "Indicate whether to quantize the model symmetrically"},
@@ -234,7 +234,7 @@ def quantize_q4(
Quantize the weights of the model from float32 to 4-bit int
"""

- quantizer = MatMul4BitsQuantizer(
+ quantizer = MatMulNBitsQuantizer(
model=model,
block_size=block_size,
is_symmetric=is_symmetric,
10 changes: 5 additions & 5 deletions scripts/requirements.txt
@@ -1,6 +1,6 @@
- transformers[torch]==4.49.0
- onnxruntime==1.20.1
- optimum@git+https://github.com/huggingface/optimum.git@b04feaea78cda58d79b8da67dca3fd0c4ab33435
- onnx==1.17.0
+ transformers[torch]==4.52.4
+ onnxruntime==1.22.0
+ optimum==1.26.1
+ onnx==1.18.0
tqdm==4.67.1
- onnxslim==0.1.48
+ onnxslim==0.1.52
26 changes: 19 additions & 7 deletions src/backends/onnx.js
@@ -38,6 +38,7 @@ const DEVICE_TO_EXECUTION_PROVIDER_MAPPING = Object.freeze({
webgpu: 'webgpu', // WebGPU
cuda: 'cuda', // CUDA
dml: 'dml', // DirectML
+ coreml: 'coreml', // CoreML

webnn: { name: 'webnn', deviceType: 'cpu' }, // WebNN (default)
'webnn-npu': { name: 'webnn', deviceType: 'npu' }, // WebNN NPU
@@ -63,13 +64,15 @@ if (ORT_SYMBOL in globalThis) {
} else if (apis.IS_NODE_ENV) {
ONNX = ONNX_NODE.default ?? ONNX_NODE;

- // Updated as of ONNX Runtime 1.20.1
+ // Updated as of ONNX Runtime 1.23.0-dev.20250612-70f14d7670
// The following table lists the supported versions of ONNX Runtime Node.js binding provided with pre-built binaries.
- // | EPs/Platforms | Windows x64 | Windows arm64 | Linux x64 | Linux arm64 | MacOS x64 | MacOS arm64 |
- // | ------------- | ----------- | ------------- | ----------------- | ----------- | --------- | ----------- |
- // | CPU | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |
- // | DirectML | ✔️ | ✔️ | ❌ | ❌ | ❌ | ❌ |
- // | CUDA | ❌ | ❌ | ✔️ (CUDA v11.8) | ❌ | ❌ | ❌ |
+ // | EPs/Platforms | Windows x64 | Windows arm64 | Linux x64 | Linux arm64 | MacOS x64 | MacOS arm64 |
+ // | --------------------- | ------------------ | ------------------ | ------------------ | ------------------ | ------------------ | ------------------ |
+ // | CPU | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |
+ // | WebGPU (experimental) | ✔️ | ✔️ | ✔️ | ❌ | ✔️ | ✔️ |
+ // | DirectML | ✔️ | ✔️ | ❌ | ❌ | ❌ | ❌ |
+ // | CUDA | ❌ | ❌ | ✔️ (CUDA v12) | ❌ | ❌ | ❌ |
+ // | CoreML | ❌ | ❌ | ❌ | ❌ | ✔️ | ✔️ |
switch (process.platform) {
case 'win32': // Windows x64 and Windows arm64
supportedDevices.push('dml');
@@ -80,9 +83,11 @@ if (ORT_SYMBOL in globalThis) {
}
break;
case 'darwin': // MacOS x64 and MacOS arm64
+ supportedDevices.push('coreml');
break;
}

+ supportedDevices.push('webgpu');
supportedDevices.push('cpu');
defaultDevices = ['cpu'];
} else {
@@ -180,9 +185,16 @@ if (ONNX_ENV?.wasm) {
if (
// @ts-ignore Cannot find name 'ServiceWorkerGlobalScope'.ts(2304)
!(typeof ServiceWorkerGlobalScope !== 'undefined' && self instanceof ServiceWorkerGlobalScope)
&& env.backends.onnx.versions?.web
&& !ONNX_ENV.wasm.wasmPaths
) {
- ONNX_ENV.wasm.wasmPaths = `https://cdn.jsdelivr.net/npm/@huggingface/transformers@${env.version}/dist/`;
+ const wasmPathPrefix = `https://cdn.jsdelivr.net/npm/onnxruntime-web@${env.backends.onnx.versions.web}/dist/`;
+
+ ONNX_ENV.wasm.wasmPaths = apis.IS_SAFARI ? {
+ "mjs": `${wasmPathPrefix}/ort-wasm-simd-threaded.mjs`,
+ "wasm": `${wasmPathPrefix}/ort-wasm-simd-threaded.wasm`,
+ }
+ : wasmPathPrefix;
}

// TODO: Add support for loading WASM files from cached buffer when we upgrade to [email protected]
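For context, the new CoreML execution provider is selected like any other device. A minimal sketch, assuming a macOS host with the updated onnxruntime-node binding (the model name here is illustrative):

import { pipeline } from "@huggingface/transformers";

// Request the CoreML EP (macOS only; see the support table above).
const classifier = await pipeline(
  "text-classification",
  "Xenova/distilbert-base-uncased-finetuned-sst-2-english",
  { device: "coreml" },
);

console.log(await classifier("Transformers.js can now target CoreML!"));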
45 changes: 38 additions & 7 deletions src/env.js
@@ -28,13 +28,6 @@ import url from 'node:url';

const VERSION = '3.5.2';

- // Check if various APIs are available (depends on environment)
- const IS_BROWSER_ENV = typeof window !== "undefined" && typeof window.document !== "undefined";
- const IS_WEBWORKER_ENV = typeof self !== "undefined" && self.constructor?.name === 'DedicatedWorkerGlobalScope';
- const IS_WEB_CACHE_AVAILABLE = typeof self !== "undefined" && 'caches' in self;
- const IS_WEBGPU_AVAILABLE = typeof navigator !== 'undefined' && 'gpu' in navigator;
- const IS_WEBNN_AVAILABLE = typeof navigator !== 'undefined' && 'ml' in navigator;

const IS_PROCESS_AVAILABLE = typeof process !== 'undefined';
const IS_NODE_ENV = IS_PROCESS_AVAILABLE && process?.release?.name === 'node';
Review comment (Contributor):

This will be true for Bun, which does not have webgpu support.

Reply (Collaborator, Author):

The native EP doesn't rely on navigator.gpu, so we can still run it in the same way:

import { pipeline } from "@huggingface/transformers";

// Create a feature-extraction pipeline
const extractor = await pipeline(
  "feature-extraction",
  "Xenova/all-MiniLM-L6-v2",
  { device: "webgpu" }
);

// Compute sentence embeddings
const sentences = ["Hello world", "This is an example sentence"];
const output = await extractor(sentences, { pooling: "mean", normalize: true });
console.log(output.tolist());

output:

$ bun run testing/bun.js 
dtype not specified for "model". Using the default dtype (fp32) for this device (webgpu).
2025-06-19 16:19:03.276 bun[51733:2036459] 2025-06-19 16:19:03.276206 [W:onnxruntime:, session_state.cc:1276 VerifyEachNodeIsAssignedToAnEp] Some nodes were not assigned to the preferred execution providers which may or may not have an negative impact on performance. e.g. ORT explicitly assigns shape related ops to CPU to improve perf.
2025-06-19 16:19:03.276 bun[51733:2036459] 2025-06-19 16:19:03.276244 [W:onnxruntime:, session_state.cc:1278 VerifyEachNodeIsAssignedToAnEp] Rerunning with verbose output on a non-minimal build will show node assignments.
[
  [
    -0.03447728604078293, 0.031023163348436356, 0.006734936963766813, ...
  ], [
    0.06765685230493546, 0.06349590420722961, 0.04871312156319618, ...
  ]
]

const IS_FS_AVAILABLE = !isEmpty(fs);
@@ -44,6 +37,41 @@ const IS_PATH_AVAILABLE = !isEmpty(path);
const IS_DENO_RUNTIME = typeof globalThis.Deno !== 'undefined';
const IS_BUN_RUNTIME = typeof globalThis.Bun !== 'undefined';

+ // Check if various APIs are available (depends on environment)
+ const IS_BROWSER_ENV = typeof window !== "undefined" && typeof window.document !== "undefined";
+ const IS_WEBWORKER_ENV = typeof self !== "undefined" && self.constructor?.name === 'DedicatedWorkerGlobalScope';
+ const IS_WEB_CACHE_AVAILABLE = typeof self !== "undefined" && 'caches' in self;
+ const IS_WEBGPU_AVAILABLE = IS_NODE_ENV || (typeof navigator !== 'undefined' && 'gpu' in navigator);
+ const IS_WEBNN_AVAILABLE = typeof navigator !== 'undefined' && 'ml' in navigator;
+
+ /**
+ * Check if the current environment is Safari browser.
+ * Works in both browser and web worker contexts.
+ * @returns {boolean} Whether the current environment is Safari.
+ */
+ const isSafari = () => {
+ // Check if we're in a browser environment
+ if (typeof navigator === 'undefined') {
+ return false;
+ }
+
+ const userAgent = navigator.userAgent;
+ const vendor = navigator.vendor || '';
+
+ // Safari has "Apple" in vendor string
+ const isAppleVendor = vendor.indexOf('Apple') > -1;
+
+ // Exclude Chrome on iOS (CriOS), Firefox on iOS (FxiOS),
+ // Edge on iOS (EdgiOS), and other browsers
+ const notOtherBrowser =
+ !userAgent.match(/CriOS|FxiOS|EdgiOS|OPiOS|mercury|brave/i) &&
+ !userAgent.includes('Chrome') &&
+ !userAgent.includes('Android');
+
+ return isAppleVendor && notOtherBrowser;
+ };
+ const IS_SAFARI = isSafari();

/**
* A read-only object containing information about the APIs available in the current environment.
*/
@@ -63,6 +91,9 @@ export const apis = Object.freeze({
/** Whether the WebNN API is available */
IS_WEBNN_AVAILABLE,

+ /** Whether we are running in a Safari browser */
+ IS_SAFARI,

/** Whether the Node.js process API is available */
IS_PROCESS_AVAILABLE,

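Downstream code can read the new flag through the apis object (env.js exports it, and the import below assumes the package root re-exports env.js as before). A small sketch:

import { apis } from "@huggingface/transformers";

// Safari is given explicit .mjs/.wasm file paths in src/backends/onnx.js above;
// all other browsers receive a plain path prefix.
if (apis.IS_SAFARI) {
  console.log("Safari detected: ort-wasm files resolved by explicit path.");
}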
35 changes: 22 additions & 13 deletions src/models.js
@@ -154,10 +154,11 @@ const MODEL_CLASS_TO_NAME_MAPPING = new Map();
* @param {string} pretrained_model_name_or_path The path to the directory containing the model file.
* @param {string} fileName The name of the model file.
* @param {import('./utils/hub.js').PretrainedModelOptions} options Additional options for loading the model.
+ * @param {boolean} [is_decoder=false] Whether the model is a decoder model.
* @returns {Promise<{buffer_or_path: Uint8Array|string, session_options: Object, session_config: Object}>} A Promise that resolves to the data needed to create an InferenceSession object.
* @private
*/
- async function getSession(pretrained_model_name_or_path, fileName, options) {
+ async function getSession(pretrained_model_name_or_path, fileName, options, is_decoder = false) {
let custom_config = options.config?.['transformers.js_config'] ?? {};

let device = options.device ?? custom_config.device;
@@ -218,7 +219,14 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {

if (!DEFAULT_DTYPE_SUFFIX_MAPPING.hasOwnProperty(selectedDtype)) {
throw new Error(`Invalid dtype: ${selectedDtype}. Should be one of: ${Object.keys(DATA_TYPES).join(', ')}`);
- } else if (selectedDtype === DATA_TYPES.fp16 && selectedDevice === 'webgpu' && !(await isWebGpuFp16Supported())) {
+ } else if (
+ selectedDevice === 'webgpu' && (
+ // NOTE: Currently, we assume that the Native WebGPU EP always supports fp16. In future, we will add a check for this.
+ !apis.IS_NODE_ENV
+ &&
+ (selectedDtype === DATA_TYPES.fp16 && !(await isWebGpuFp16Supported()))
+ )
+ ) {
throw new Error(`The device (${selectedDevice}) does not support fp16.`);
}

@@ -316,7 +324,7 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
}
}

- if (selectedDevice === 'webgpu') {
+ if (is_decoder && selectedDevice === 'webgpu') {
const shapes = getKeyValueShapes(options.config, {
prefix: 'present',
});
@@ -342,13 +350,14 @@
* @param {string} pretrained_model_name_or_path The path to the directory containing the model file.
* @param {Record<string, string>} names The names of the model files to load.
* @param {import('./utils/hub.js').PretrainedModelOptions} options Additional options for loading the model.
+ * @param {string} [decoder_name] The name of the decoder model, if any.
* @returns {Promise<Record<string, any>>} A Promise that resolves to a dictionary of InferenceSession objects.
* @private
*/
- async function constructSessions(pretrained_model_name_or_path, names, options) {
+ async function constructSessions(pretrained_model_name_or_path, names, options, decoder_name = undefined) {
return Object.fromEntries(await Promise.all(
Object.keys(names).map(async (name) => {
- const { buffer_or_path, session_options, session_config } = await getSession(pretrained_model_name_or_path, names[name], options);
+ const { buffer_or_path, session_options, session_config } = await getSession(pretrained_model_name_or_path, names[name], options, name === decoder_name);
const session = await createInferenceSession(buffer_or_path, session_options, session_config);
return [name, session];
})
@@ -1148,7 +1157,7 @@ export class PreTrainedModel extends Callable {
info = await Promise.all([
constructSessions(pretrained_model_name_or_path, {
model: options.model_file_name ?? 'model',
- }, options),
+ }, options, 'model'),
getOptionalConfigs(pretrained_model_name_or_path, {
generation_config: 'generation_config.json',
}, options),
@@ -1159,7 +1168,7 @@
constructSessions(pretrained_model_name_or_path, {
model: 'encoder_model',
decoder_model_merged: 'decoder_model_merged',
- }, options),
+ }, options, 'decoder_model_merged'),
getOptionalConfigs(pretrained_model_name_or_path, {
generation_config: 'generation_config.json',
}, options),
@@ -1178,7 +1187,7 @@
constructSessions(pretrained_model_name_or_path, {
model: 'encoder_model',
decoder_model_merged: 'decoder_model_merged',
- }, options),
+ }, options, 'decoder_model_merged'),
]);

} else if (modelType === MODEL_TYPES.ImageTextToText) {
@@ -1191,7 +1200,7 @@
sessions['model'] = 'encoder_model';
}
info = await Promise.all([
- constructSessions(pretrained_model_name_or_path, sessions, options),
+ constructSessions(pretrained_model_name_or_path, sessions, options, 'decoder_model_merged'),
getOptionalConfigs(pretrained_model_name_or_path, {
generation_config: 'generation_config.json',
}, options),
@@ -1204,7 +1213,7 @@
decoder_model_merged: 'decoder_model_merged',
}
info = await Promise.all([
- constructSessions(pretrained_model_name_or_path, sessions, options),
+ constructSessions(pretrained_model_name_or_path, sessions, options, 'decoder_model_merged'),
getOptionalConfigs(pretrained_model_name_or_path, {
generation_config: 'generation_config.json',
}, options),
@@ -1216,7 +1225,7 @@
model: 'text_encoder',
decoder_model_merged: 'decoder_model_merged',
encodec_decode: 'encodec_decode',
- }, options),
+ }, options, 'decoder_model_merged'),
getOptionalConfigs(pretrained_model_name_or_path, {
generation_config: 'generation_config.json',
}, options),
@@ -1231,7 +1240,7 @@
gen_head: 'gen_head',
gen_img_embeds: 'gen_img_embeds',
image_decode: 'image_decode',
- }, options),
+ }, options, 'model'),
getOptionalConfigs(pretrained_model_name_or_path, {
generation_config: 'generation_config.json',
}, options),
@@ -1243,7 +1252,7 @@
prepare_inputs_embeds: 'prepare_inputs_embeds',
model: 'model',
vision_encoder: 'vision_encoder',
- }, options),
+ }, options, 'model'),
getOptionalConfigs(pretrained_model_name_or_path, {
generation_config: 'generation_config.json',
}, options),
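One practical effect of the relaxed dtype check above: in Node.js, fp16 support on the native WebGPU EP is assumed rather than probed, so a request like the following should no longer throw. A sketch (model name illustrative; fp16 support is assumed, per the NOTE in the diff):

import { pipeline } from "@huggingface/transformers";

// In Node.js, the isWebGpuFp16Supported() probe is skipped for the native WebGPU EP.
const extractor = await pipeline("feature-extraction", "Xenova/all-MiniLM-L6-v2", {
  device: "webgpu",
  dtype: "fp16",
});

const output = await extractor(["Hello world"], { pooling: "mean", normalize: true });
console.log(output.dims);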
1 change: 1 addition & 0 deletions src/utils/devices.js
@@ -10,6 +10,7 @@ export const DEVICE_TYPES = Object.freeze({
webgpu: 'webgpu', // WebGPU
cuda: 'cuda', // CUDA
dml: 'dml', // DirectML
+ coreml: 'coreml', // CoreML

webnn: 'webnn', // WebNN (default)
'webnn-npu': 'webnn-npu', // WebNN NPU
6 changes: 3 additions & 3 deletions tests/models/florence2/test_modeling_florence2.js
@@ -46,7 +46,7 @@ export default () => {
{
const inputs = await processor(image, texts[0]);
const generate_ids = await model.generate({ ...inputs, max_new_tokens: 10 });
- expect(generate_ids.tolist()).toEqual([[2n, 0n, 48n, 48n, 48n, 48n, 48n, 48n, 48n, 48n, 2n]]);
+ expect(generate_ids.tolist()).toEqual([[2n, 0n, 0n, 0n, 1n, 0n, 0n, 2n]]);
}
},
MAX_TEST_EXECUTION_TIME,
@@ -68,8 +68,8 @@

const generate_ids = await model.generate({ ...inputs, max_new_tokens: 10 });
expect(generate_ids.tolist()).toEqual([
- [2n, 0n, 48n, 48n, 48n, 48n, 48n, 48n, 48n, 48n, 2n],
- [2n, 0n, 48n, 48n, 48n, 48n, 48n, 48n, 48n, 48n, 2n],
+ [2n, 0n, 0n, 0n, 1n, 0n, 0n, 2n],
+ [2n, 0n, 0n, 0n, 1n, 0n, 0n, 2n],
]);
}
},
(next test file; name not rendered)
@@ -32,7 +32,7 @@ export default () => {
expect(pred_boxes.dims).toEqual([1, num_queries, 4]);
expect(logits.max().item()).toBeCloseTo(56.237613677978516, 2);
expect(logits.min().item()).toEqual(-Infinity);
- expect(pred_boxes.mean().item()).toEqual(0.2500016987323761);
+ expect(pred_boxes.mean().item()).toBeCloseTo(0.2500016987323761, 4);
},
MAX_TEST_EXECUTION_TIME,
);
6 changes: 3 additions & 3 deletions tests/pipelines/test_pipelines_depth_estimation.js
@@ -26,7 +26,7 @@
async () => {
const output = await pipe(images[0]);
expect(output.predicted_depth.dims).toEqual([224, 224]);
- expect(output.predicted_depth.mean().item()).toBeCloseTo(0.000006106501587055391, 6);
+ expect(output.predicted_depth.mean().item()).toBeCloseTo(0.000006106501587055391, 4);
expect(output.depth.size).toEqual(images[0].size);
},
MAX_TEST_EXECUTION_TIME,
@@ -40,10 +40,10 @@
const output = await pipe(images);
expect(output).toHaveLength(images.length);
expect(output[0].predicted_depth.dims).toEqual([224, 224]);
- expect(output[0].predicted_depth.mean().item()).toBeCloseTo(0.000006106501587055391, 6);
+ expect(output[0].predicted_depth.mean().item()).toBeCloseTo(0.000006106501587055391, 4);
expect(output[0].depth.size).toEqual(images[0].size);
expect(output[1].predicted_depth.dims).toEqual([224, 224]);
- expect(output[1].predicted_depth.mean().item()).toBeCloseTo(0.0000014548650142387487, 6);
+ expect(output[1].predicted_depth.mean().item()).toBeCloseTo(0.0000014548650142387487, 4);
expect(output[1].depth.size).toEqual(images[1].size);
},
MAX_TEST_EXECUTION_TIME,