diff --git a/apps/docs/docs/modules/embeddings/available_embeddings/deepinfra.md b/apps/docs/docs/modules/embeddings/available_embeddings/deepinfra.md
new file mode 100644
index 0000000000000000000000000000000000000000..6f049ea864f319d2ef1ca5485db2bc2d4cc2e554
--- /dev/null
+++ b/apps/docs/docs/modules/embeddings/available_embeddings/deepinfra.md
@@ -0,0 +1,79 @@
+# DeepInfra
+
+To use DeepInfra embeddings, you need to import `DeepInfraEmbedding` from llamaindex.
+Check out the available embedding models [here](https://deepinfra.com/models/embeddings).
+
+```ts
+import {
+  DeepInfraEmbedding,
+  Settings,
+  Document,
+  VectorStoreIndex,
+} from "llamaindex";
+
+// Update Embed Model
+Settings.embedModel = new DeepInfraEmbedding();
+
+const document = new Document({ text: essay, id_: "essay" });
+
+const index = await VectorStoreIndex.fromDocuments([document]);
+
+const queryEngine = index.asQueryEngine();
+
+const query = "What is the meaning of life?";
+
+const results = await queryEngine.query({
+  query,
+});
+```
+
+By default, `DeepInfraEmbedding` uses the `sentence-transformers/clip-ViT-B-32` model. You can change the model by passing the `model` parameter to the constructor.
+For example:
+
+```ts
+import { DeepInfraEmbedding, Settings } from "llamaindex";
+
+const model = "intfloat/e5-large-v2";
+Settings.embedModel = new DeepInfraEmbedding({
+  model,
+});
+```
+
+You can also set the `maxRetries` and `timeout` parameters when initializing `DeepInfraEmbedding` for better control over the request behavior.
+
+For example:
+
+```ts
+import { DeepInfraEmbedding, Settings } from "llamaindex";
+
+const model = "intfloat/e5-large-v2";
+const maxRetries = 5;
+const timeout = 5000; // 5 seconds
+
+Settings.embedModel = new DeepInfraEmbedding({
+  model,
+  maxRetries,
+  timeout,
+});
+```
+
+Standalone usage:
+
+```ts
+import { DeepInfraEmbedding } from "llamaindex";
+import { config } from "dotenv";
+// For standalone usage, you need to configure DEEPINFRA_API_TOKEN in .env file
+config();
+
+const main = async () => {
+  const model = "intfloat/e5-large-v2";
+  const embeddings = new DeepInfraEmbedding({ model });
+  const text = "What is the meaning of life?";
+  const response = await embeddings.getTextEmbeddings([text]);
+  console.log(response);
+};
+
+main();
+```
+
+For questions or feedback, please contact us at [feedback@deepinfra.com](mailto:feedback@deepinfra.com)
diff --git a/examples/deepinfra/embedding.ts b/examples/deepinfra/embedding.ts
new file mode 100644
index 0000000000000000000000000000000000000000..809fd7fe694d7d753263df8ec6cbabc62369cc71
--- /dev/null
+++ b/examples/deepinfra/embedding.ts
@@ -0,0 +1,17 @@
+import { DeepInfraEmbedding } from "llamaindex";
+
+async function main() {
+  // Prefer the DEEPINFRA_API_TOKEN environment variable; the literal is only
+  // a placeholder fallback. (A non-null literal on the left of ?? would win.)
+  const apiToken = process.env.DEEPINFRA_API_TOKEN ?? "YOUR_API_TOKEN";
+  const model = "BAAI/bge-large-en-v1.5";
+  const embedModel = new DeepInfraEmbedding({
+    model,
+    apiToken,
+  });
+  const texts = ["hello", "world"];
+  const embeddings = await embedModel.getTextEmbeddingsBatch(texts);
+  console.log(`\nWe have ${embeddings.length} embeddings`);
+}
+
+main().catch(console.error);
diff --git a/packages/core/src/embeddings/DeepInfraEmbedding.ts b/packages/core/src/embeddings/DeepInfraEmbedding.ts
new file mode 100644
index 0000000000000000000000000000000000000000..59f1c827e4ff166526b94fae3ba9e6d22c8546c1
--- /dev/null
+++ b/packages/core/src/embeddings/DeepInfraEmbedding.ts
@@ -0,0 +1,152 @@
+import { getEnv } from "@llamaindex/env";
+import type { MessageContentDetail } from "../llm/index.js";
+import { extractSingleText } from "../llm/utils.js";
+import { BaseEmbedding } from "./types.js";
+
+const DEFAULT_MODEL = "sentence-transformers/clip-ViT-B-32";
+
+const API_TOKEN_ENV_VARIABLE_NAME = "DEEPINFRA_API_TOKEN";
+
+const API_ROOT = "https://api.deepinfra.com/v1/inference";
+
+const DEFAULT_TIMEOUT = 60 * 1000;
+
+const DEFAULT_MAX_RETRIES = 5;
+
+export interface DeepInfraEmbeddingResponse {
+  embeddings: number[][];
+  request_id: string;
+  inference_status: InferenceStatus;
+}
+
+export interface InferenceStatus {
+  status: string;
+  runtime_ms: number;
+  cost: number;
+  tokens_input: number;
+}
+
+// Prepend `prefix` to each input; an empty prefix leaves inputs unchanged
+// (unconditional interpolation would add a stray leading space).
+const mapPrefixWithInputs = (prefix: string, inputs: string[]): string[] => {
+  return inputs.map((input) => (prefix ? `${prefix} ${input}` : input));
+};
+
+/**
+ * DeepInfraEmbedding implements the BaseEmbedding interface, backed by the DeepInfra inference API.
+ */
+export class DeepInfraEmbedding extends BaseEmbedding {
+  /**
+   * DeepInfra model to use
+   * @default "sentence-transformers/clip-ViT-B-32"
+   * @see https://deepinfra.com/models/embeddings
+   */
+  model: string;
+
+  /**
+   * DeepInfra API token
+   * @see https://deepinfra.com/dash/api_keys
+   * If not provided, it will try to get the token from the environment variable `DEEPINFRA_API_TOKEN`
+   *
+   */
+  apiToken: string;
+
+  /**
+   * Prefix to add to the query
+   * @default ""
+   */
+  queryPrefix: string;
+
+  /**
+   * Prefix to add to the text
+   * @default ""
+   */
+  textPrefix: string;
+
+  /**
+   * Maximum number of request attempts before giving up.
+   * @default 5
+   */
+  maxRetries: number;
+
+  /**
+   * Per-attempt request timeout, in milliseconds.
+   * @default 60 * 1000
+   */
+  timeout: number;
+
+  constructor(init?: Partial<DeepInfraEmbedding>) {
+    super();
+
+    this.model = init?.model ?? DEFAULT_MODEL;
+    this.apiToken = init?.apiToken ?? getEnv(API_TOKEN_ENV_VARIABLE_NAME) ?? "";
+    this.queryPrefix = init?.queryPrefix ?? "";
+    this.textPrefix = init?.textPrefix ?? "";
+    this.maxRetries = init?.maxRetries ?? DEFAULT_MAX_RETRIES;
+    this.timeout = init?.timeout ?? DEFAULT_TIMEOUT;
+  }
+
+  /** Embed a single text, applying `textPrefix`. */
+  async getTextEmbedding(text: string): Promise<number[]> {
+    const texts = mapPrefixWithInputs(this.textPrefix, [text]);
+    const embeddings = await this.getDeepInfraEmbedding(texts);
+    return embeddings[0];
+  }
+
+  /**
+   * Embed a query, applying `queryPrefix`.
+   * Returns null when no text can be extracted from `query`.
+   */
+  async getQueryEmbedding(
+    query: MessageContentDetail,
+  ): Promise<number[] | null> {
+    const text = extractSingleText(query);
+    if (text) {
+      const queries = mapPrefixWithInputs(this.queryPrefix, [text]);
+      const embeddings = await this.getDeepInfraEmbedding(queries);
+      return embeddings[0];
+    } else {
+      return null;
+    }
+  }
+
+  /** Embed a batch of texts, applying `textPrefix` to each. */
+  async getTextEmbeddings(texts: string[]): Promise<number[][]> {
+    const textsWithPrefix = mapPrefixWithInputs(this.textPrefix, texts);
+    return await this.getDeepInfraEmbedding(textsWithPrefix);
+  }
+
+  /** Embed a batch of queries, applying `queryPrefix` to each. */
+  async getQueryEmbeddings(queries: string[]): Promise<number[][]> {
+    const queriesWithPrefix = mapPrefixWithInputs(this.queryPrefix, queries);
+    return await this.getDeepInfraEmbedding(queriesWithPrefix);
+  }
+
+  /**
+   * POST `inputs` to the DeepInfra inference endpoint with bearer auth,
+   * aborting each attempt after `timeout` ms and retrying transient failures
+   * up to `maxRetries` times.
+   * @throws immediately on non-retryable 4xx statuses (except 429), or once
+   *         all retries are exhausted (the last error is included).
+   */
+  private async getDeepInfraEmbedding(inputs: string[]): Promise<number[][]> {
+    const url = this.getUrl(this.model);
+    let lastError: unknown;
+
+    for (let attempt = 0; attempt < this.maxRetries; attempt++) {
+      const controller = new AbortController();
+      const id = setTimeout(() => controller.abort(), this.timeout);
+      // Assume a failure is transient unless the status proves otherwise.
+      let retryable = true;
+
+      try {
+        const response = await fetch(url, {
+          method: "POST",
+          headers: {
+            "Content-Type": "application/json",
+            Authorization: `Bearer ${this.apiToken}`,
+          },
+          body: JSON.stringify({ inputs }),
+          signal: controller.signal,
+        });
+        if (!response.ok) {
+          // Client errors (e.g. 401 bad token) will not succeed on retry, so
+          // fail fast instead of burning the retry budget; 429 may recover.
+          retryable = response.status === 429 || response.status >= 500;
+          throw new Error(`Request failed with status ${response.status}`);
+        }
+
+        const responseJson: DeepInfraEmbeddingResponse = await response.json();
+        return responseJson.embeddings;
+      } catch (error) {
+        lastError = error;
+        console.error(`Attempt ${attempt + 1} failed: ${error}`);
+        if (!retryable) {
+          break;
+        }
+      } finally {
+        clearTimeout(id);
+      }
+    }
+
+    // Surface the root cause instead of a bare "exceeded retries" message.
+    throw new Error(`Exceeded maximum retries: ${lastError}`);
+  }
+
+  private getUrl(model: string): string {
+    return `${API_ROOT}/${model}`;
+  }
+}
diff --git a/packages/core/src/embeddings/index.ts b/packages/core/src/embeddings/index.ts
index e2cc200689319aad904e5da68a854d1c5eb4179f..e4f7c5bf17f2d1890b1f7894d81d923738eceba4 100644
--- a/packages/core/src/embeddings/index.ts
+++ b/packages/core/src/embeddings/index.ts
@@ -1,3 +1,4 @@
+export * from "./DeepInfraEmbedding.js";
 export * from "./GeminiEmbedding.js";
 export * from "./JinaAIEmbedding.js";
 export * from "./MistralAIEmbedding.js";