Skip to content
Snippets Groups Projects
Unverified Commit b48bcc3a authored by Alex Yang's avatar Alex Yang Committed by GitHub
Browse files

feat: support custom `@xenova/transformers` (#1232)

parent fa01fa20
No related branches found
No related tags found
No related merge requests found
Showing
with 282 additions and 127 deletions
---
"@llamaindex/core": patch
"@llamaindex/env": patch
"llamaindex": patch
---
feat: add `load-transformers` event type when loading `@xenova/transformers` module
This would benefit users who want to customize the transformers env.
......@@ -128,16 +128,29 @@ export class CallbackManager {
/**
 * Dispatch `event` to every handler registered for it, passing `detail`
 * (spread into a fresh object so handlers cannot mutate the caller's copy).
 *
 * Handlers run asynchronously via `queueMicrotask` by default; pass
 * `sync = true` when the caller must have handlers invoked inline.
 * If `queueMicrotask` is missing in the current runtime, dispatch falls
 * back to synchronous mode with a console warning.
 *
 * @param event - event name, keyed into `LlamaIndexEventMaps`
 * @param detail - event payload matching the event's map entry
 * @param sync - invoke handlers synchronously instead of on a microtask
 */
dispatchEvent<K extends keyof LlamaIndexEventMaps>(
  event: K,
  detail: LlamaIndexEventMaps[K],
  sync = false,
) {
  const cbs = this.#handlers.get(event);
  if (!cbs) {
    // No listeners for this event: nothing to do.
    return;
  }
  if (typeof queueMicrotask === "undefined") {
    // Some minimal runtimes lack queueMicrotask; degrade to sync dispatch.
    console.warn(
      "queueMicrotask is not available, dispatching synchronously",
    );
    sync = true;
  }
  if (sync) {
    cbs.forEach((handler) =>
      handler(LlamaIndexCustomEvent.fromEvent(event, { ...detail })),
    );
  } else {
    queueMicrotask(() => {
      cbs.forEach((handler) =>
        handler(LlamaIndexCustomEvent.fromEvent(event, { ...detail })),
      );
    });
  }
}
}
......
......@@ -74,16 +74,18 @@
"@aws-crypto/sha256-js": "^5.2.0",
"@swc/cli": "^0.4.0",
"@swc/core": "^1.7.22",
"@xenova/transformers": "^2.17.2",
"concurrently": "^8.2.2",
"pathe": "^1.1.2",
"tiktoken": "^1.0.16",
"vitest": "^2.0.5"
},
"dependencies": {
"@types/lodash": "^4.17.7",
"@types/node": "^22.5.1"
},
"peerDependencies": {
"@aws-crypto/sha256-js": "^5.2.0",
"@xenova/transformers": "^2.17.2",
"js-tiktoken": "^1.0.12",
"pathe": "^1.1.2",
"tiktoken": "^1.0.15"
......@@ -92,8 +94,17 @@
"@aws-crypto/sha256-js": {
"optional": true
},
"@xenova/transformers": {
"optional": true
},
"pathe": {
"optional": true
},
"tiktoken": {
"optional": true
},
"js-tiktoken": {
"optional": true
}
}
}
......@@ -6,6 +6,12 @@
import "./global-check.js";
export * from "./web-polyfill.js";
export {
loadTransformers,
setTransformers,
type LoadTransformerEvent,
type OnLoad,
} from "./multi-model/index.browser.js";
export { Tokenizers, tokenizers, type Tokenizer } from "./tokenizers/js.js";
// @ts-expect-error
......
......@@ -6,4 +6,10 @@
import "./global-check.js";
export * from "./node-polyfill.js";
export {
loadTransformers,
setTransformers,
type LoadTransformerEvent,
type OnLoad,
} from "./multi-model/index.non-nodejs.js";
export { Tokenizers, tokenizers, type Tokenizer } from "./tokenizers/js.js";
......@@ -33,6 +33,12 @@ export function createSHA256(): SHA256 {
};
}
export {
loadTransformers,
setTransformers,
type LoadTransformerEvent,
type OnLoad,
} from "./multi-model/index.js";
export { Tokenizers, tokenizers, type Tokenizer } from "./tokenizers/node.js";
export {
AsyncLocalStorage,
......
......@@ -13,4 +13,10 @@ export function getEnv(name: string): string | undefined {
return INTERNAL_ENV[name];
}
export {
loadTransformers,
setTransformers,
type LoadTransformerEvent,
type OnLoad,
} from "./multi-model/index.non-nodejs.js";
export { Tokenizers, tokenizers, type Tokenizer } from "./tokenizers/js.js";
import { getTransformers, setTransformers, type OnLoad } from "./shared.js";
export {
setTransformers,
type LoadTransformerEvent,
type OnLoad,
} from "./shared.js";
export async function loadTransformers(onLoad: OnLoad) {
if (getTransformers() === null) {
setTransformers(
// @ts-expect-error
await import("https://cdn.jsdelivr.net/npm/@xenova/transformers@2.17.2"),
);
} else {
return getTransformers()!;
}
const transformer = getTransformers()!;
onLoad(transformer);
return transformer;
}
import { getTransformers, setTransformers, type OnLoad } from "./shared.js";
export {
setTransformers,
type LoadTransformerEvent,
type OnLoad,
} from "./shared.js";
/**
 * Load `@xenova/transformers` in a non-Node.js runtime (e.g. Edge Runtime,
 * Cloudflare Workers), caching the module after the first import.
 *
 * `onLoad` is invoked exactly once, right after the initial import; later
 * calls return the cached module without calling `onLoad` again.
 */
export async function loadTransformers(onLoad: OnLoad) {
  if (getTransformers() === null) {
    /**
     * If you see this warning, it means that the current environment does not
     * fully support the transformer, because "@xenova/transformers" depends
     * heavily on Node.js APIs.
     *
     * One possible solution is to fix their implementation to make it work in
     * non-Node.js environments, but it's likely not worth the effort because
     * Edge Runtime and Cloudflare Workers are not meant for heavy Machine
     * Learning tasks.
     *
     * Alternatively, provide an RPC server that runs the transformer in a
     * Node.js environment, or simply run this code in Node.js.
     *
     * Refs: https://github.com/xenova/transformers.js/issues/309
     */
    console.warn(
      '"@xenova/transformers" is not officially supported in this environment, some features may not work as expected.',
    );
    setTransformers(
      // @ts-expect-error
      await import("@xenova/transformers/dist/transformers"),
    );
  } else {
    // Already loaded (or user-provided): return the cached module;
    // onLoad is intentionally not re-invoked.
    return getTransformers()!;
  }
  const transformer = getTransformers()!;
  onLoad(transformer);
  return transformer;
}
import { getTransformers, setTransformers, type OnLoad } from "./shared.js";
export {
setTransformers,
type LoadTransformerEvent,
type OnLoad,
} from "./shared.js";
export async function loadTransformers(onLoad: OnLoad) {
if (getTransformers() === null) {
setTransformers(await import("@xenova/transformers"));
} else {
return getTransformers()!;
}
const transformer = getTransformers()!;
onLoad(transformer);
return transformer;
}
// Module-level singleton holding the loaded `@xenova/transformers` module
// (or a user-supplied replacement); `null` until the first load.
let transformer: typeof import("@xenova/transformers") | null = null;

/** Returns the currently registered transformers module, or `null` if none. */
export function getTransformers() {
  return transformer;
}

/**
 * Register a (possibly custom) `@xenova/transformers` module implementation.
 * Subsequent `loadTransformers` calls reuse it instead of importing.
 */
export function setTransformers(t: typeof import("@xenova/transformers")) {
  transformer = t;
}

/** Callback invoked once, right after the transformers module is first loaded. */
export type OnLoad = (
  transformer: typeof import("@xenova/transformers"),
) => void;

/** Payload of the `load-transformers` event dispatched on first module load. */
export type LoadTransformerEvent = {
  transformer: typeof import("@xenova/transformers");
};
/*
 * Global stylesheet: theme design tokens plus base element resets.
 * NOTE(review): appears to be a framework starter stylesheet — confirm
 * before customizing values here.
 */

/* Light-theme design tokens. */
:root {
  --max-width: 1100px;
  --border-radius: 12px;
  --font-mono: ui-monospace, Menlo, Monaco, "Cascadia Mono", "Segoe UI Mono",
    "Roboto Mono", "Oxygen Mono", "Ubuntu Monospace", "Source Code Pro",
    "Fira Mono", "Droid Sans Mono", "Courier New", monospace;
  --foreground-rgb: 0, 0, 0;
  --background-start-rgb: 214, 219, 220;
  --background-end-rgb: 255, 255, 255;
  --primary-glow: conic-gradient(
    from 180deg at 50% 50%,
    #16abff33 0deg,
    #0885ff33 55deg,
    #54d6ff33 120deg,
    #0071ff33 160deg,
    transparent 360deg
  );
  --secondary-glow: radial-gradient(
    rgba(255, 255, 255, 1),
    rgba(255, 255, 255, 0)
  );
  --tile-start-rgb: 239, 245, 249;
  --tile-end-rgb: 228, 232, 233;
  --tile-border: conic-gradient(
    #00000080,
    #00000040,
    #00000030,
    #00000020,
    #00000010,
    #00000010,
    #00000080
  );
  --callout-rgb: 238, 240, 241;
  --callout-border-rgb: 172, 175, 176;
  --card-rgb: 180, 185, 188;
  --card-border-rgb: 131, 134, 135;
}

/* Dark-theme overrides of the tokens above. */
@media (prefers-color-scheme: dark) {
  :root {
    --foreground-rgb: 255, 255, 255;
    --background-start-rgb: 0, 0, 0;
    --background-end-rgb: 0, 0, 0;
    --primary-glow: radial-gradient(rgba(1, 65, 255, 0.4), rgba(1, 65, 255, 0));
    --secondary-glow: linear-gradient(
      to bottom right,
      rgba(1, 65, 255, 0),
      rgba(1, 65, 255, 0),
      rgba(1, 65, 255, 0.3)
    );
    --tile-start-rgb: 2, 13, 46;
    --tile-end-rgb: 2, 5, 19;
    --tile-border: conic-gradient(
      #ffffff80,
      #ffffff40,
      #ffffff30,
      #ffffff20,
      #ffffff10,
      #ffffff10,
      #ffffff80
    );
    --callout-rgb: 20, 20, 20;
    --callout-border-rgb: 108, 108, 108;
    --card-rgb: 100, 100, 100;
    --card-border-rgb: 200, 200, 200;
  }
}

/* Box-model reset: border-box sizing, no default padding/margins. */
* {
  box-sizing: border-box;
  padding: 0;
  margin: 0;
}

/* Prevent horizontal scrolling caused by overflowing content. */
html,
body {
  max-width: 100vw;
  overflow-x: hidden;
}

/* Page background: vertical gradient over the start color token. */
body {
  color: rgb(var(--foreground-rgb));
  background: linear-gradient(
      to bottom,
      transparent,
      rgb(var(--background-end-rgb))
    )
    rgb(var(--background-start-rgb));
}

/* Links inherit text color and carry no underline by default. */
a {
  color: inherit;
  text-decoration: none;
}

/* Let the browser render native UI (scrollbars, form controls) in dark mode. */
@media (prefers-color-scheme: dark) {
  html {
    color-scheme: dark;
  }
}
// test runtime
import "llamaindex";
import { ClipEmbedding } from "llamaindex/embeddings/ClipEmbedding";
import { ClipEmbedding } from "llamaindex";
import "llamaindex/readers/SimpleDirectoryReader";
// @ts-expect-error
......
import { ClipEmbedding, ImageNode } from "llamaindex";
import type { LoadTransformerEvent } from "@llamaindex/env";
import { setTransformers } from "@llamaindex/env";
import { ClipEmbedding, ImageNode, Settings } from "llamaindex";
import assert from "node:assert";
import { test } from "node:test";
import { type Mock, test } from "node:test";
// Mock handler shared across tests; records every `load-transformers`
// dispatch and sanity-checks the event payload.
let callback: Mock<(event: any) => void>;
test.before(() => {
  callback = test.mock.fn((event: any) => {
    // The event detail must carry a usable transformers module.
    const { transformer } = event.detail as LoadTransformerEvent;
    assert.ok(transformer);
    assert.ok(transformer.env);
  });
  // Registered once for the whole file; call counts are reset per test below.
  Settings.callbackManager.on("load-transformers", callback);
});
test.beforeEach(() => {
  // Reset invocation history so each test asserts its own dispatch count.
  callback.mock.resetCalls();
});
await test("clip embedding", async (t) => {
await t.test("should trigger load transformer event", async () => {
const nodes = [
new ImageNode({
image: new URL(
"../../fixtures/img/llamaindex-white.png",
import.meta.url,
),
}),
];
assert.equal(callback.mock.callCount(), 0);
const clipEmbedding = new ClipEmbedding();
assert.equal(callback.mock.callCount(), 0);
const result = await clipEmbedding(nodes);
assert.strictEqual(result.length, 1);
assert.equal(callback.mock.callCount(), 1);
});
await t.test("init & get image embedding", async () => {
const clipEmbedding = new ClipEmbedding();
const imgUrl = new URL(
......@@ -27,4 +60,25 @@ await test("clip embedding", async (t) => {
assert.strictEqual(result.length, 1);
assert.ok(result[0]!.embedding);
});
await t.test("custom transformer", async () => {
const transformers = await import("@xenova/transformers");
const getter = test.mock.fn((t, k, r) => {
return Reflect.get(t, k, r);
});
setTransformers(
new Proxy(transformers, {
get: getter,
}),
);
const clipEmbedding = new ClipEmbedding();
const imgUrl = new URL(
"../../fixtures/img/llamaindex-white.png",
import.meta.url,
);
assert.equal(getter.mock.callCount(), 0);
const vec = await clipEmbedding.getImageEmbedding(imgUrl);
assert.ok(vec);
assert.ok(getter.mock.callCount() > 0);
});
});
......@@ -33,8 +33,8 @@
"@llamaindex/cloud": "workspace:*",
"@llamaindex/core": "workspace:*",
"@llamaindex/env": "workspace:*",
"@llamaindex/openai": "workspace:*",
"@llamaindex/groq": "workspace:*",
"@llamaindex/openai": "workspace:*",
"@mistralai/mistralai": "^1.0.4",
"@mixedbread-ai/sdk": "^2.2.11",
"@pinecone-database/pinecone": "^3.0.2",
......@@ -43,7 +43,6 @@
"@types/node": "^22.5.1",
"@types/papaparse": "^5.3.14",
"@types/pg": "^8.11.8",
"@xenova/transformers": "^2.17.2",
"@zilliz/milvus2-sdk-node": "^2.4.6",
"ajv": "^8.17.1",
"assemblyai": "^4.7.0",
......@@ -91,6 +90,7 @@
"@notionhq/client": "^2.2.15",
"@swc/cli": "^0.4.0",
"@swc/core": "^1.7.22",
"@xenova/transformers": "^2.17.2",
"concurrently": "^8.2.2",
"glob": "^11.0.0",
"pg": "^8.12.0",
......
......@@ -12,6 +12,7 @@ import {
type NodeParser,
SentenceSplitter,
} from "@llamaindex/core/node-parser";
import type { LoadTransformerEvent } from "@llamaindex/env";
import { AsyncLocalStorage, getEnv } from "@llamaindex/env";
import type { ServiceContext } from "./ServiceContext.js";
import {
......@@ -20,6 +21,12 @@ import {
withEmbeddedModel,
} from "./internal/settings/EmbedModel.js";
declare module "@llamaindex/core/global" {
interface LlamaIndexEventMaps {
"load-transformers": LoadTransformerEvent;
}
}
export type PromptConfig = {
llm?: string;
lang?: string;
......
import { MultiModalEmbedding } from "@llamaindex/core/embeddings";
import type { ImageType } from "@llamaindex/core/schema";
import _ from "lodash";
import { lazyLoadTransformers } from "../internal/deps/transformers.js";
// only import type, to avoid bundling error
import { loadTransformers } from "@llamaindex/env";
import type {
CLIPTextModelWithProjection,
CLIPVisionModelWithProjection,
PreTrainedTokenizer,
Processor,
} from "@xenova/transformers";
import { Settings } from "../Settings.js";
async function readImage(input: ImageType) {
const { RawImage } = await lazyLoadTransformers();
const { RawImage } = await loadTransformers((transformer) => {
Settings.callbackManager.dispatchEvent(
"load-transformers",
{
transformer,
},
true,
);
});
if (input instanceof Blob) {
return await RawImage.fromBlob(input);
} else if (_.isString(input) || input instanceof URL) {
......@@ -40,7 +49,15 @@ export class ClipEmbedding extends MultiModalEmbedding {
}
async getTokenizer() {
const { AutoTokenizer } = await lazyLoadTransformers();
const { AutoTokenizer } = await loadTransformers((transformer) => {
Settings.callbackManager.dispatchEvent(
"load-transformers",
{
transformer,
},
true,
);
});
if (!this.tokenizer) {
this.tokenizer = await AutoTokenizer.from_pretrained(this.modelType);
}
......@@ -48,7 +65,15 @@ export class ClipEmbedding extends MultiModalEmbedding {
}
async getProcessor() {
const { AutoProcessor } = await lazyLoadTransformers();
const { AutoProcessor } = await loadTransformers((transformer) => {
Settings.callbackManager.dispatchEvent(
"load-transformers",
{
transformer,
},
true,
);
});
if (!this.processor) {
this.processor = await AutoProcessor.from_pretrained(this.modelType);
}
......@@ -56,7 +81,17 @@ export class ClipEmbedding extends MultiModalEmbedding {
}
async getVisionModel() {
const { CLIPVisionModelWithProjection } = await lazyLoadTransformers();
const { CLIPVisionModelWithProjection } = await loadTransformers(
(transformer) => {
Settings.callbackManager.dispatchEvent(
"load-transformers",
{
transformer,
},
true,
);
},
);
if (!this.visionModel) {
this.visionModel = await CLIPVisionModelWithProjection.from_pretrained(
this.modelType,
......@@ -67,7 +102,17 @@ export class ClipEmbedding extends MultiModalEmbedding {
}
async getTextModel() {
const { CLIPTextModelWithProjection } = await lazyLoadTransformers();
const { CLIPTextModelWithProjection } = await loadTransformers(
(transformer) => {
Settings.callbackManager.dispatchEvent(
"load-transformers",
{
transformer,
},
true,
);
},
);
if (!this.textModel) {
this.textModel = await CLIPTextModelWithProjection.from_pretrained(
this.modelType,
......
import { HfInference } from "@huggingface/inference";
import { BaseEmbedding } from "@llamaindex/core/embeddings";
import { lazyLoadTransformers } from "../internal/deps/transformers.js";
import { loadTransformers } from "@llamaindex/env";
import { Settings } from "../Settings.js";
export enum HuggingFaceEmbeddingModelType {
XENOVA_ALL_MINILM_L6_V2 = "Xenova/all-MiniLM-L6-v2",
......@@ -33,7 +34,15 @@ export class HuggingFaceEmbedding extends BaseEmbedding {
async getExtractor() {
if (!this.extractor) {
const { pipeline } = await lazyLoadTransformers();
const { pipeline } = await loadTransformers((transformer) => {
Settings.callbackManager.dispatchEvent(
"load-transformers",
{
transformer,
},
true,
);
});
this.extractor = await pipeline("feature-extraction", this.modelType, {
quantized: this.quantized,
});
......
......@@ -9,3 +9,5 @@ export * from "./MixedbreadAIEmbeddings.js";
export { OllamaEmbedding } from "./OllamaEmbedding.js";
export * from "./OpenAIEmbedding.js";
export { TogetherEmbedding } from "./together.js";
// ClipEmbedding might not work in non-Node.js runtimes, but importing it has no side effects
export { ClipEmbedding, ClipEmbeddingModelType } from "./ClipEmbedding.js";
......@@ -2,10 +2,6 @@ export * from "./index.edge.js";
export * from "./readers/index.js";
export * from "./storage/index.js";
// Exports modules that doesn't support non-node.js runtime
export {
ClipEmbedding,
ClipEmbeddingModelType,
} from "./embeddings/ClipEmbedding.js";
export {
HuggingFaceEmbedding,
HuggingFaceEmbeddingModelType,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment