3 changes: 1 addition & 2 deletions examples/README.md
@@ -46,8 +46,7 @@ These examples demonstrate various capabilities via WebLLM's OpenAI-like API.
#### Others

- [logit-processor](logit-processor): while `logit_bias` is supported, we additionally support stateful logit processing where users can specify their own rules. We also expose low-level API `forwardTokensAndSample()`.
- [cache-usage](cache-usage): demonstrates how WebLLM supports both the [Cache API](https://developer.mozilla.org/en-US/docs/Web/API/Cache) and [IndexedDB cache](https://developer.mozilla.org/en-US/docs/Web/API/IndexedDB_API), and
users can pick with `appConfig.useIndexedDBCache`. Also demonstrates various cache utils such as checking
- [cache-usage](cache-usage): demonstrates how WebLLM supports multiple cache backends. Choose between the [Cache API](https://developer.mozilla.org/en-US/docs/Web/API/Cache), [IndexedDB cache](https://developer.mozilla.org/en-US/docs/Web/API/IndexedDB_API), or the experimental Chrome [Cross-Origin Storage](https://github.com/explainers-by-googlers/cross-origin-storage) extension via `appConfig.cacheBackend`. Also demonstrates various cache utils such as checking
whether a model is cached, deleting a model's weights from cache, deleting a model library wasm from cache, etc.
- [simple-chat-upload](simple-chat-upload): demonstrates how to upload local models to WebLLM instead of downloading via a URL link

7 changes: 5 additions & 2 deletions examples/cache-usage/README.md
@@ -1,9 +1,12 @@
# WebLLM Cache Usage

WebLLM supports both the Cache API and IndexedDB, which you can specify via `AppConfig.useIndexedDBCache`.
This folder provides an example on how Cache and IndexedDB Cache are used in WebLLM. We also
WebLLM supports multiple persistent cache backends. You can pick the classic Cache API, IndexedDB, or the experimental Chrome [Cross-Origin Storage](https://github.com/explainers-by-googlers/cross-origin-storage) extension by
setting `AppConfig.cacheBackend` to `"cache"`, `"indexeddb"`, or `"cross-origin"`.
This folder provides an example of how the different cache backends are used in WebLLM. We also
demonstrate the utility cache functions such as deleting models, checking if models are in cache, etc.

> **Note:** The cross-origin backend requires Chrome's cross-origin storage experiment or the community browser extension to be installed and granted access to the domains that host your model artifacts (e.g. huggingface.co).

For more information about the two caches, see: https://developer.mozilla.org/en-US/docs/Web/API/Storage_API/Storage_quotas_and_eviction_criteria#what_technologies_store_data_in_the_browser.
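
Below is a minimal sketch of how the backend selection might be wired up end to end. It assumes the prebuilt model id `Llama-3.2-1B-Instruct-q4f16_1-MLC` from `prebuiltAppConfig.model_list`, uses the existing `CreateMLCEngine`, `hasModelInCache`, and `deleteModelAllInfoInCache` exports, and sets the new `cacheBackend` option introduced here; the example in this folder walks through the full flow.

```typescript
import {
  CreateMLCEngine,
  hasModelInCache,
  deleteModelAllInfoInCache,
  prebuiltAppConfig,
  type AppConfig,
} from "@mlc-ai/web-llm";

async function main() {
  // Reuse the prebuilt model list, but keep artifacts in IndexedDB.
  // Swap in "cache" or "cross-origin" (extension required) to try the other backends.
  const appConfig: AppConfig = {
    ...prebuiltAppConfig,
    cacheBackend: "indexeddb",
  };
  const modelId = "Llama-3.2-1B-Instruct-q4f16_1-MLC"; // any id from appConfig.model_list

  const engine = await CreateMLCEngine(modelId, { appConfig });

  // The cache utilities take the same appConfig, so they talk to the same backend.
  console.log(await hasModelInCache(modelId, appConfig)); // true after the load above
  await deleteModelAllInfoInCache(modelId, appConfig); // removes weights, chat config, and wasm
  console.log(await hasModelInCache(modelId, appConfig)); // false

  await engine.unload();
}

main().catch(console.error);
```

The same pattern applies to the per-artifact helpers `deleteModelInCache`, `deleteChatConfigInCache`, and `deleteModelWasmInCache` when you only want to evict part of a model.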

To inspect the downloaded artifacts in your browser, open up developer console, go to application,
6 changes: 3 additions & 3 deletions examples/cache-usage/package.json
@@ -3,18 +3,18 @@
"version": "0.1.0",
"private": true,
"scripts": {
"start": "parcel src/cache_usage.html --port 8888",
"start": "parcel src/cache_usage.html --port 8889",
"build": "parcel build src/cache_usage.html --dist-dir lib"
},
"devDependencies": {
"buffer": "^5.7.1",
"parcel": "^2.8.3",
"parcel": "2.8.3",
"process": "^0.11.10",
"tslib": "^2.3.1",
"typescript": "^4.9.5",
"url": "^0.11.3"
},
"dependencies": {
"@mlc-ai/web-llm": "^0.2.79"
"@mlc-ai/web-llm": "file:../.."
}
}
185 changes: 159 additions & 26 deletions src/cache_util.ts
@@ -4,10 +4,153 @@ import {
ChatConfig,
ModelRecord,
prebuiltAppConfig,
getCacheBackend,
} from "./config";
import { cleanModelUrl } from "./support";
import { ModelNotFoundError, UnsupportedTokenizerFilesError } from "./error";
import { Tokenizer } from "@mlc-ai/web-tokenizers";
import CrossOriginStorage from "./cross_origin_storage";
import CrossOriginStorageCache from "./cross_origin_storage_cache";

type CacheScope = "webllm/model" | "webllm/config" | "webllm/wasm";

let crossOriginUnavailableLogged = false;
let crossOriginAvailabilityWait: Promise<void> | null = null;

function scheduleCrossOriginFallbackWarning(
logger: (msg: string) => void,
): void {
if (crossOriginUnavailableLogged || crossOriginAvailabilityWait) {
return;
}
crossOriginAvailabilityWait = (async () => {
const availableSoon = await CrossOriginStorage.waitForAvailability();
crossOriginAvailabilityWait = null;
if (availableSoon || crossOriginUnavailableLogged) {
return;
}
logger(
"Cross-origin storage backend is not yet available; temporarily falling back to the Cache API.",
);
crossOriginUnavailableLogged = true;
})();
}

function shouldUseCrossOrigin(appConfig: AppConfig): boolean {
return (
getCacheBackend(appConfig) === "cross-origin" &&
CrossOriginStorage.isAvailable()
);
}

export function getArtifactCache(
scope: CacheScope,
appConfig: AppConfig,
logger: (msg: string) => void = console.warn,
): tvmjs.ArtifactCacheTemplate {
const backend = getCacheBackend(appConfig);
if (backend === "cross-origin") {
if (CrossOriginStorage.isAvailable()) {
return new CrossOriginStorageCache(scope);
}
scheduleCrossOriginFallbackWarning(logger);
}
if (backend === "indexeddb") {
return new tvmjs.ArtifactIndexedDBCache(scope);
}
return new tvmjs.ArtifactCache(scope);
}

async function hasTensorCache(
cache: tvmjs.ArtifactCacheTemplate,
tensorCacheUrl: string,
): Promise<boolean> {
const jsonUrl = new URL("tensor-cache.json", tensorCacheUrl).href;
const hasManifest = await cache.hasAllKeys([jsonUrl]);
if (!hasManifest) {
return false;
}
const manifest = await cache.fetchWithCache(jsonUrl, "json");
const records = manifest?.records ?? [];
if (!Array.isArray(records) || records.length === 0) {
return false;
}
const shardUrls = records.map(
(entry: { dataPath: string }) =>
new URL(entry.dataPath, tensorCacheUrl).href,
);
return cache.hasAllKeys(shardUrls);
}

async function deleteTensorCacheEntries(
cache: tvmjs.ArtifactCacheTemplate,
tensorCacheUrl: string,
): Promise<void> {
const jsonUrl = new URL("tensor-cache.json", tensorCacheUrl).href;
const hasManifest = await cache.hasAllKeys([jsonUrl]);
if (!hasManifest) {
return;
}
let manifest: { records?: Array<{ dataPath: string }> };
try {
manifest = await cache.fetchWithCache(jsonUrl, "json");
} catch (err) {
return;
}
const records = manifest?.records ?? [];
await Promise.all(
records.map(async (entry) => {
if (!entry?.dataPath) {
return;
}
const dataUrl = new URL(entry.dataPath, tensorCacheUrl).href;
await cache.deleteInCache(dataUrl);
}),
);
await cache.deleteInCache(jsonUrl);
}

export async function fetchModelArtifacts(
tvm: tvmjs.Instance,
tensorCacheUrl: string,
device: tvmjs.DLDevice,
appConfig: AppConfig,
signal?: AbortSignal,
): Promise<any> {
if (!shouldUseCrossOrigin(appConfig)) {
const backend = getCacheBackend(appConfig);
const cacheType = backend === "indexeddb" ? "indexeddb" : "cache";
return tvm.fetchTensorCache(
tensorCacheUrl,
device,
"webllm/model",
cacheType,
signal,
);
}

const artifactCache = getArtifactCache("webllm/model", appConfig);
const jsonUrl = new URL("tensor-cache.json", tensorCacheUrl).href;
const manifest = await artifactCache.fetchWithCache(jsonUrl, "json", signal);
const records = (
Array.isArray(manifest?.records) ? manifest.records : []
) as Array<any>;
await (tvm as any).fetchTensorCacheInternal(
tensorCacheUrl,
records,
device,
artifactCache,
signal,
);
if (manifest?.metadata !== undefined) {
const runtime = tvm as any;
runtime.cacheMetadata = {
...runtime.cacheMetadata,
...(manifest.metadata as Record<string, unknown>),
};
}
return manifest;
}

function findModelRecord(modelId: string, appConfig?: AppConfig): ModelRecord {
const matchedItem = appConfig?.model_list.find(
@@ -28,8 +171,13 @@ export async function hasModelInCache(
}
const modelRecord = findModelRecord(modelId, appConfig);
const modelUrl = cleanModelUrl(modelRecord.model);
const cacheType = appConfig.useIndexedDBCache ? "indexeddb" : "cache";
return tvmjs.hasNDArrayInCache(modelUrl, "webllm/model", cacheType);
if (shouldUseCrossOrigin(appConfig)) {
const cache = getArtifactCache("webllm/model", appConfig);
return hasTensorCache(cache, modelUrl);
}
const backend = getCacheBackend(appConfig);
const cacheType = backend === "indexeddb" ? "indexeddb" : "cache";
return tvmjs.hasTensorInCache(modelUrl, "webllm/model", cacheType);
}

export async function deleteModelAllInfoInCache(
@@ -58,13 +206,13 @@ export async function deleteModelInCache(
}
const modelRecord = findModelRecord(modelId, appConfig);
const modelUrl = cleanModelUrl(modelRecord.model);
let modelCache: tvmjs.ArtifactCacheTemplate;
if (appConfig.useIndexedDBCache) {
tvmjs.deleteNDArrayCache(modelUrl, "webllm/model", "indexeddb");
modelCache = new tvmjs.ArtifactIndexedDBCache("webllm/model");
const modelCache = getArtifactCache("webllm/model", appConfig);
if (shouldUseCrossOrigin(appConfig)) {
await deleteTensorCacheEntries(modelCache, modelUrl);
} else {
tvmjs.deleteNDArrayCache(modelUrl, "webllm/model", "cache");
modelCache = new tvmjs.ArtifactCache("webllm/model");
const backend = getCacheBackend(appConfig);
const cacheType = backend === "indexeddb" ? "indexeddb" : "cache";
await tvmjs.deleteTensorCache(modelUrl, "webllm/model", cacheType);
}
await modelCache.deleteInCache(new URL("tokenizer.model", modelUrl).href);
await modelCache.deleteInCache(new URL("tokenizer.json", modelUrl).href);
@@ -79,12 +227,7 @@ export async function deleteChatConfigInCache(
appConfig = prebuiltAppConfig;
}
const modelRecord = findModelRecord(modelId, appConfig);
let configCache: tvmjs.ArtifactCacheTemplate;
if (appConfig.useIndexedDBCache) {
configCache = new tvmjs.ArtifactIndexedDBCache("webllm/config");
} else {
configCache = new tvmjs.ArtifactCache("webllm/config");
}
const configCache = getArtifactCache("webllm/config", appConfig);
const modelUrl = cleanModelUrl(modelRecord.model);
const configUrl = new URL("mlc-chat-config.json", modelUrl).href;
await configCache.deleteInCache(configUrl);
@@ -99,12 +242,7 @@ export async function deleteModelWasmInCache(
appConfig = prebuiltAppConfig;
}
const modelRecord = findModelRecord(modelId, appConfig);
let wasmCache: tvmjs.ArtifactCacheTemplate;
if (appConfig.useIndexedDBCache) {
wasmCache = new tvmjs.ArtifactIndexedDBCache("webllm/wasm");
} else {
wasmCache = new tvmjs.ArtifactCache("webllm/wasm");
}
const wasmCache = getArtifactCache("webllm/wasm", appConfig);
await wasmCache.deleteInCache(modelRecord.model_lib);
}

@@ -122,12 +260,7 @@ export async function asyncLoadTokenizer(
appConfig: AppConfig,
logger: (msg: string) => void = console.log,
): Promise<Tokenizer> {
let modelCache: tvmjs.ArtifactCacheTemplate;
if (appConfig.useIndexedDBCache) {
modelCache = new tvmjs.ArtifactIndexedDBCache("webllm/model");
} else {
modelCache = new tvmjs.ArtifactCache("webllm/model");
}
const modelCache = getArtifactCache("webllm/model", appConfig, logger);

if (config.tokenizer_files.includes("tokenizer.json")) {
const url = new URL("tokenizer.json", baseUrl).href;
17 changes: 13 additions & 4 deletions src/config.ts
@@ -270,15 +270,24 @@ export interface ModelRecord {
* passed to the load.
*
* @param model_list: models to be used.
* @param useIndexedDBCache: if true, will use IndexedDBCache to cache models and other artifacts.
* If false or unspecified, will use the Cache API. For more information of the two, see:
* @param cacheBackend: the backend to use for caching models and other artifacts.
* If unspecified, will use the Cache API. For more information, see:
* https://developer.mozilla.org/en-US/docs/Web/API/Storage_API/Storage_quotas_and_eviction_criteria#what_technologies_store_data_in_the_browser
*
* @note Note that the Cache API is more well-tested in WebLLM as of now.
*/
export type CacheBackend = "cache" | "indexeddb" | "cross-origin";

export interface AppConfig {
model_list: Array<ModelRecord>;
useIndexedDBCache?: boolean;
cacheBackend?: CacheBackend;
}

export function getCacheBackend(appConfig: AppConfig): CacheBackend {
if (appConfig.cacheBackend !== undefined) {
return appConfig.cacheBackend;
}
return "cache";
}

/**
@@ -310,7 +319,7 @@ export const functionCallingModelIds = [
* current WebLLM npm version.
*/
export const prebuiltAppConfig: AppConfig = {
useIndexedDBCache: false,
cacheBackend: "cache",
model_list: [
// Llama-3.2
{
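
To round out the `src/config.ts` changes, here is a minimal sketch of how the new `cacheBackend` field resolves. It assumes `getCacheBackend` and the `CacheBackend` type are re-exported from the package entry point; in this diff they are only defined in `src/config.ts`.

```typescript
import {
  getCacheBackend, // assumption: re-exported from the package entry point
  prebuiltAppConfig,
  type AppConfig,
  type CacheBackend,
} from "@mlc-ai/web-llm";

// Leaving cacheBackend unset resolves to the Cache API default.
const implicit: AppConfig = { model_list: prebuiltAppConfig.model_list };
console.log(getCacheBackend(implicit)); // "cache"

// Explicit choices pass through unchanged.
const backends: CacheBackend[] = ["cache", "indexeddb", "cross-origin"];
for (const cacheBackend of backends) {
  console.log(getCacheBackend({ ...implicit, cacheBackend })); // "cache", "indexeddb", "cross-origin"
}
```

Note that, per `getArtifactCache` in `src/cache_util.ts` above, selecting `"cross-origin"` when the extension is not available falls back to the Cache API and logs a one-time warning rather than failing.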