From 50b7d1b7bbb94c6ceb00f31a1d0cebd5f3d70d81 Mon Sep 17 00:00:00 2001 From: Alex Yang <himself65@outlook.com> Date: Tue, 16 Jul 2024 10:49:03 -0700 Subject: [PATCH] refactor: put embedding into core (#1041) --- packages/core/package.json | 14 + .../types.ts => core/src/embeddings/base.ts} | 38 +-- packages/core/src/embeddings/index.ts | 4 + .../src/embeddings/tokenizer.ts | 0 packages/core/src/embeddings/utils.ts | 64 +++++ packages/core/src/schema/index.ts | 1 + packages/core/src/schema/type.ts | 5 + .../tests/embeddings.test.ts} | 2 +- packages/llamaindex/src/ServiceContext.ts | 2 +- packages/llamaindex/src/Settings.ts | 2 +- .../llamaindex/src/cloud/LlamaCloudIndex.ts | 7 +- packages/llamaindex/src/cloud/config.ts | 7 +- packages/llamaindex/src/constants.ts | 1 - .../src/embeddings/DeepInfraEmbedding.ts | 8 +- .../src/embeddings/GeminiEmbedding.ts | 2 +- .../src/embeddings/HuggingFaceEmbedding.ts | 6 +- .../src/embeddings/MistralAIEmbedding.ts | 2 +- .../src/embeddings/MixedbreadAIEmbeddings.ts | 6 +- .../src/embeddings/MultiModalEmbedding.ts | 2 +- .../src/embeddings/OllamaEmbedding.ts | 2 +- .../src/embeddings/OpenAIEmbedding.ts | 8 +- packages/llamaindex/src/embeddings/index.ts | 3 +- packages/llamaindex/src/embeddings/utils.ts | 256 ------------------ packages/llamaindex/src/extractors/types.ts | 5 +- .../src/indices/vectorStore/index.ts | 6 +- .../src/ingestion/IngestionCache.ts | 7 +- .../src/ingestion/IngestionPipeline.ts | 15 +- packages/llamaindex/src/ingestion/index.ts | 1 - .../strategies/DuplicatesStrategy.ts | 5 +- .../strategies/UpsertsAndDeleteStrategy.ts | 5 +- .../ingestion/strategies/UpsertsStrategy.ts | 5 +- .../src/ingestion/strategies/index.ts | 6 +- packages/llamaindex/src/ingestion/types.ts | 5 - .../src/internal/settings/EmbedModel.ts | 2 +- packages/llamaindex/src/internal/utils.ts | 178 ++++++++++++ packages/llamaindex/src/llm/ollama.ts | 2 +- packages/llamaindex/src/nodeParsers/types.ts | 5 +- .../vectorStore/MongoDBAtlasVectorStore.ts | 2 +- .../storage/vectorStore/SimpleVectorStore.ts | 13 +- .../src/storage/vectorStore/types.ts | 2 +- packages/llamaindex/src/synthesizers/utils.ts | 2 +- .../tests/ingestion/IngestionCache.test.ts | 4 +- 42 files changed, 353 insertions(+), 359 deletions(-) rename packages/{llamaindex/src/embeddings/types.ts => core/src/embeddings/base.ts} (80%) create mode 100644 packages/core/src/embeddings/index.ts rename packages/{llamaindex => core}/src/embeddings/tokenizer.ts (100%) create mode 100644 packages/core/src/embeddings/utils.ts create mode 100644 packages/core/src/schema/type.ts rename packages/{llamaindex/tests/embeddings/tokenizer.test.ts => core/tests/embeddings.test.ts} (93%) delete mode 100644 packages/llamaindex/src/embeddings/utils.ts delete mode 100644 packages/llamaindex/src/ingestion/types.ts diff --git a/packages/core/package.json b/packages/core/package.json index a26918f3e..49bd07c92 100644 --- a/packages/core/package.json +++ b/packages/core/package.json @@ -32,6 +32,20 @@ "default": "./dist/decorator/index.js" } }, + "./embeddings": { + "require": { + "types": "./dist/embeddings/index.d.cts", + "default": "./dist/embeddings/index.cjs" + }, + "import": { + "types": "./dist/embeddings/index.d.ts", + "default": "./dist/embeddings/index.js" + }, + "default": { + "types": "./dist/embeddings/index.d.ts", + "default": "./dist/embeddings/index.js" + } + }, "./global": { "require": { "types": "./dist/global/index.d.cts", diff --git a/packages/llamaindex/src/embeddings/types.ts b/packages/core/src/embeddings/base.ts similarity index 80% rename from packages/llamaindex/src/embeddings/types.ts rename to packages/core/src/embeddings/base.ts index b14fdf2a5..5dd74ac66 100644 --- a/packages/llamaindex/src/embeddings/types.ts +++ b/packages/core/src/embeddings/base.ts @@ -1,9 +1,8 @@ -import type { MessageContentDetail } from "@llamaindex/core/llms"; -import type { BaseNode } from "@llamaindex/core/schema"; -import { MetadataMode } from "@llamaindex/core/schema"; -import { extractSingleText } from "@llamaindex/core/utils"; import { type Tokenizers } from "@llamaindex/env"; -import type { TransformComponent } from "../ingestion/types.js"; +import type { MessageContentDetail } from "../llms"; +import type { TransformComponent } from "../schema"; +import { BaseNode, MetadataMode } from "../schema"; +import { extractSingleText } from "../utils"; import { truncateMaxTokens } from "./tokenizer.js"; import { SimilarityType, similarity } from "./utils.js"; @@ -17,7 +16,13 @@ export type EmbeddingInfo = { tokenizer?: Tokenizers; }; -export abstract class BaseEmbedding implements TransformComponent { +export type BaseEmbeddingOptions = { + logProgress?: boolean; +}; + +export abstract class BaseEmbedding + implements TransformComponent<BaseEmbeddingOptions> +{ embedBatchSize = DEFAULT_EMBED_BATCH_SIZE; embedInfo?: EmbeddingInfo; @@ -45,7 +50,7 @@ export abstract class BaseEmbedding implements TransformComponent { * Optionally override this method to retrieve multiple embeddings in a single request * @param texts */ - async getTextEmbeddings(texts: string[]): Promise<Array<number[]>> { + getTextEmbeddings = async (texts: string[]): Promise<Array<number[]>> => { const embeddings: number[][] = []; for (const text of texts) { @@ -54,7 +59,7 @@ export abstract class BaseEmbedding implements TransformComponent { } return embeddings; - } + }; /** * Get embeddings for a batch of texts @@ -63,22 +68,23 @@ export abstract class BaseEmbedding implements TransformComponent { */ async getTextEmbeddingsBatch( texts: string[], - options?: { - logProgress?: boolean; - }, + options?: BaseEmbeddingOptions, ): Promise<Array<number[]>> { return await batchEmbeddings( texts, - this.getTextEmbeddings.bind(this), + this.getTextEmbeddings, this.embedBatchSize, options, ); } - async transform(nodes: BaseNode[], _options?: any): Promise<BaseNode[]> { + async transform( + nodes: BaseNode[], + options?: BaseEmbeddingOptions, + ): Promise<BaseNode[]> { const texts = nodes.map((node) => node.getContent(MetadataMode.EMBED)); - const embeddings = await this.getTextEmbeddingsBatch(texts, _options); + const embeddings = await this.getTextEmbeddingsBatch(texts, options); for (let i = 0; i < nodes.length; i++) { nodes[i].embedding = embeddings[i]; @@ -104,9 +110,7 @@ export async function batchEmbeddings<T>( values: T[], embedFunc: EmbedFunc<T>, chunkSize: number, - options?: { - logProgress?: boolean; - }, + options?: BaseEmbeddingOptions, ): Promise<Array<number[]>> { const resultEmbeddings: Array<number[]> = []; diff --git a/packages/core/src/embeddings/index.ts b/packages/core/src/embeddings/index.ts new file mode 100644 index 000000000..7d6c4aac1 --- /dev/null +++ b/packages/core/src/embeddings/index.ts @@ -0,0 +1,4 @@ +export { BaseEmbedding, batchEmbeddings } from "./base"; +export type { BaseEmbeddingOptions, EmbeddingInfo } from "./base"; +export { truncateMaxTokens } from "./tokenizer"; +export { DEFAULT_SIMILARITY_TOP_K, SimilarityType, similarity } from "./utils"; diff --git a/packages/llamaindex/src/embeddings/tokenizer.ts b/packages/core/src/embeddings/tokenizer.ts similarity index 100% rename from packages/llamaindex/src/embeddings/tokenizer.ts rename to packages/core/src/embeddings/tokenizer.ts diff --git a/packages/core/src/embeddings/utils.ts b/packages/core/src/embeddings/utils.ts new file mode 100644 index 000000000..c6c439815 --- /dev/null +++ b/packages/core/src/embeddings/utils.ts @@ -0,0 +1,64 @@ +export const DEFAULT_SIMILARITY_TOP_K = 2; + +/** + * Similarity type + * Default is cosine similarity. Dot product and negative Euclidean distance are also supported. + */ +export enum SimilarityType { + DEFAULT = "cosine", + DOT_PRODUCT = "dot_product", + EUCLIDEAN = "euclidean", +} + +/** + * The similarity between two embeddings. + * @param embedding1 + * @param embedding2 + * @param mode + * @returns similarity score with higher numbers meaning the two embeddings are more similar + */ + +export function similarity( + embedding1: number[], + embedding2: number[], + mode: SimilarityType = SimilarityType.DEFAULT, +): number { + if (embedding1.length !== embedding2.length) { + throw new Error("Embedding length mismatch"); + } + + // NOTE I've taken enough Kahan to know that we should probably leave the + // numeric programming to numeric programmers. The naive approach here + // will probably cause some avoidable loss of floating point precision + // ml-distance is worth watching although they currently also use the naive + // formulas + function norm(x: number[]): number { + let result = 0; + for (let i = 0; i < x.length; i++) { + result += x[i] * x[i]; + } + return Math.sqrt(result); + } + + switch (mode) { + case SimilarityType.EUCLIDEAN: { + const difference = embedding1.map((x, i) => x - embedding2[i]); + return -norm(difference); + } + case SimilarityType.DOT_PRODUCT: { + let result = 0; + for (let i = 0; i < embedding1.length; i++) { + result += embedding1[i] * embedding2[i]; + } + return result; + } + case SimilarityType.DEFAULT: { + return ( + similarity(embedding1, embedding2, SimilarityType.DOT_PRODUCT) / + (norm(embedding1) * norm(embedding2)) + ); + } + default: + throw new Error("Not implemented yet"); + } +} diff --git a/packages/core/src/schema/index.ts b/packages/core/src/schema/index.ts index 583b399fc..a5cad71e1 100644 --- a/packages/core/src/schema/index.ts +++ b/packages/core/src/schema/index.ts @@ -1,2 +1,3 @@ export * from "./node"; +export type { TransformComponent } from "./type"; export * from "./zod"; diff --git a/packages/core/src/schema/type.ts b/packages/core/src/schema/type.ts new file mode 100644 index 000000000..9a16e65ad --- /dev/null +++ b/packages/core/src/schema/type.ts @@ -0,0 +1,5 @@ +import type { BaseNode } from "./node"; + +export interface TransformComponent<Options extends Record<string, unknown>> { + transform(nodes: BaseNode[], options?: Options): Promise<BaseNode[]>; +} diff --git a/packages/llamaindex/tests/embeddings/tokenizer.test.ts b/packages/core/tests/embeddings.test.ts similarity index 93% rename from packages/llamaindex/tests/embeddings/tokenizer.test.ts rename to packages/core/tests/embeddings.test.ts index 1edf50faa..3f0a12f8c 100644 --- a/packages/llamaindex/tests/embeddings/tokenizer.test.ts +++ b/packages/core/tests/embeddings.test.ts @@ -1,6 +1,6 @@ +import { truncateMaxTokens } from "@llamaindex/core/embeddings"; import { Tokenizers, tokenizers } from "@llamaindex/env"; import { describe, expect, test } from "vitest"; -import { truncateMaxTokens } from "../../src/embeddings/tokenizer.js"; describe("truncateMaxTokens", () => { const tokenizer = tokenizers.tokenizer(Tokenizers.CL100K_BASE); diff --git a/packages/llamaindex/src/ServiceContext.ts b/packages/llamaindex/src/ServiceContext.ts index dfb3f8fa4..be0dcb394 100644 --- a/packages/llamaindex/src/ServiceContext.ts +++ b/packages/llamaindex/src/ServiceContext.ts @@ -1,7 +1,7 @@ +import type { BaseEmbedding } from "@llamaindex/core/embeddings"; import type { LLM } from "@llamaindex/core/llms"; import { PromptHelper } from "./PromptHelper.js"; import { OpenAIEmbedding } from "./embeddings/OpenAIEmbedding.js"; -import type { BaseEmbedding } from "./embeddings/types.js"; import { OpenAI } from "./llm/openai.js"; import { SimpleNodeParser } from "./nodeParsers/SimpleNodeParser.js"; import type { NodeParser } from "./nodeParsers/types.js"; diff --git a/packages/llamaindex/src/Settings.ts b/packages/llamaindex/src/Settings.ts index ab43fff54..4adc6dcd2 100644 --- a/packages/llamaindex/src/Settings.ts +++ b/packages/llamaindex/src/Settings.ts @@ -7,10 +7,10 @@ import { OpenAI } from "./llm/openai.js"; import { PromptHelper } from "./PromptHelper.js"; import { SimpleNodeParser } from "./nodeParsers/SimpleNodeParser.js"; +import type { BaseEmbedding } from "@llamaindex/core/embeddings"; import type { LLM } from "@llamaindex/core/llms"; import { AsyncLocalStorage, getEnv } from "@llamaindex/env"; import type { ServiceContext } from "./ServiceContext.js"; -import type { BaseEmbedding } from "./embeddings/types.js"; import { getEmbeddedModel, setEmbeddedModel, diff --git a/packages/llamaindex/src/cloud/LlamaCloudIndex.ts b/packages/llamaindex/src/cloud/LlamaCloudIndex.ts index 65c0412de..a4511baf2 100644 --- a/packages/llamaindex/src/cloud/LlamaCloudIndex.ts +++ b/packages/llamaindex/src/cloud/LlamaCloudIndex.ts @@ -1,7 +1,6 @@ -import type { Document } from "@llamaindex/core/schema"; +import type { Document, TransformComponent } from "@llamaindex/core/schema"; import type { BaseRetriever } from "../Retriever.js"; import { RetrieverQueryEngine } from "../engines/query/RetrieverQueryEngine.js"; -import type { TransformComponent } from "../ingestion/types.js"; import type { BaseNodePostprocessor } from "../postprocessors/types.js"; import type { BaseSynthesizer } from "../synthesizers/types.js"; import type { QueryEngine } from "../types.js"; @@ -148,11 +147,11 @@ export class LlamaCloudIndex { static async fromDocuments( params: { documents: Document[]; - transformations?: TransformComponent[]; + transformations?: TransformComponent<any>[]; verbose?: boolean; } & CloudConstructorParams, ): Promise<LlamaCloudIndex> { - const defaultTransformations: TransformComponent[] = [ + const defaultTransformations: TransformComponent<any>[] = [ new SimpleNodeParser(), new OpenAIEmbedding({ apiKey: getEnv("OPENAI_API_KEY"), diff --git a/packages/llamaindex/src/cloud/config.ts b/packages/llamaindex/src/cloud/config.ts index 49623bc2a..b61212292 100644 --- a/packages/llamaindex/src/cloud/config.ts +++ b/packages/llamaindex/src/cloud/config.ts @@ -3,20 +3,19 @@ import type { PipelineCreate, PipelineType, } from "@llamaindex/cloud/api"; -import { BaseNode } from "@llamaindex/core/schema"; +import { BaseNode, type TransformComponent } from "@llamaindex/core/schema"; import { OpenAIEmbedding } from "../embeddings/OpenAIEmbedding.js"; -import type { TransformComponent } from "../ingestion/types.js"; import { SimpleNodeParser } from "../nodeParsers/SimpleNodeParser.js"; export type GetPipelineCreateParams = { pipelineName: string; pipelineType: PipelineType; - transformations?: TransformComponent[]; + transformations?: TransformComponent<any>[]; inputNodes?: BaseNode[]; }; function getTransformationConfig( - transformation: TransformComponent, + transformation: TransformComponent<any>, ): ConfiguredTransformationItem { if (transformation instanceof SimpleNodeParser) { return { diff --git a/packages/llamaindex/src/constants.ts b/packages/llamaindex/src/constants.ts index bc46fed34..004d5d3e8 100644 --- a/packages/llamaindex/src/constants.ts +++ b/packages/llamaindex/src/constants.ts @@ -4,6 +4,5 @@ export const DEFAULT_NUM_OUTPUTS = 256; export const DEFAULT_CHUNK_SIZE = 1024; export const DEFAULT_CHUNK_OVERLAP = 20; export const DEFAULT_CHUNK_OVERLAP_RATIO = 0.1; -export const DEFAULT_SIMILARITY_TOP_K = 2; export const DEFAULT_PADDING = 5; diff --git a/packages/llamaindex/src/embeddings/DeepInfraEmbedding.ts b/packages/llamaindex/src/embeddings/DeepInfraEmbedding.ts index 27b033f65..cc38a422a 100644 --- a/packages/llamaindex/src/embeddings/DeepInfraEmbedding.ts +++ b/packages/llamaindex/src/embeddings/DeepInfraEmbedding.ts @@ -1,7 +1,7 @@ +import { BaseEmbedding } from "@llamaindex/core/embeddings"; import type { MessageContentDetail } from "@llamaindex/core/llms"; import { extractSingleText } from "@llamaindex/core/utils"; import { getEnv } from "@llamaindex/env"; -import { BaseEmbedding } from "./types.js"; const DEFAULT_MODEL = "sentence-transformers/clip-ViT-B-32"; @@ -103,10 +103,10 @@ export class DeepInfraEmbedding extends BaseEmbedding { } } - async getTextEmbeddings(texts: string[]): Promise<number[][]> { + getTextEmbeddings = async (texts: string[]): Promise<number[][]> => { const textsWithPrefix = mapPrefixWithInputs(this.textPrefix, texts); - return await this.getDeepInfraEmbedding(textsWithPrefix); - } + return this.getDeepInfraEmbedding(textsWithPrefix); + }; async getQueryEmbeddings(queries: string[]): Promise<number[][]> { const queriesWithPrefix = mapPrefixWithInputs(this.queryPrefix, queries); diff --git a/packages/llamaindex/src/embeddings/GeminiEmbedding.ts b/packages/llamaindex/src/embeddings/GeminiEmbedding.ts index f08fe0619..e493d36d4 100644 --- a/packages/llamaindex/src/embeddings/GeminiEmbedding.ts +++ b/packages/llamaindex/src/embeddings/GeminiEmbedding.ts @@ -1,6 +1,6 @@ +import { BaseEmbedding } from "@llamaindex/core/embeddings"; import { GeminiSession, GeminiSessionStore } from "../llm/gemini/base.js"; import { GEMINI_BACKENDS } from "../llm/gemini/types.js"; -import { BaseEmbedding } from "./types.js"; export enum GEMINI_EMBEDDING_MODEL { EMBEDDING_001 = "embedding-001", diff --git a/packages/llamaindex/src/embeddings/HuggingFaceEmbedding.ts b/packages/llamaindex/src/embeddings/HuggingFaceEmbedding.ts index 59771d3d3..4139dba14 100644 --- a/packages/llamaindex/src/embeddings/HuggingFaceEmbedding.ts +++ b/packages/llamaindex/src/embeddings/HuggingFaceEmbedding.ts @@ -1,6 +1,6 @@ import { HfInference } from "@huggingface/inference"; +import { BaseEmbedding } from "@llamaindex/core/embeddings"; import { lazyLoadTransformers } from "../internal/deps/transformers.js"; -import { BaseEmbedding } from "./types.js"; export enum HuggingFaceEmbeddingModelType { XENOVA_ALL_MINILM_L6_V2 = "Xenova/all-MiniLM-L6-v2", @@ -91,11 +91,11 @@ export class HuggingFaceInferenceAPIEmbedding extends BaseEmbedding { return res as number[]; } - async getTextEmbeddings(texts: string[]): Promise<Array<number[]>> { + getTextEmbeddings = async (texts: string[]): Promise<Array<number[]>> => { const res = await this.hf.featureExtraction({ model: this.model, inputs: texts, }); return res as number[][]; - } + }; } diff --git a/packages/llamaindex/src/embeddings/MistralAIEmbedding.ts b/packages/llamaindex/src/embeddings/MistralAIEmbedding.ts index e1fb3f5ca..49ac9a979 100644 --- a/packages/llamaindex/src/embeddings/MistralAIEmbedding.ts +++ b/packages/llamaindex/src/embeddings/MistralAIEmbedding.ts @@ -1,5 +1,5 @@ +import { BaseEmbedding } from "@llamaindex/core/embeddings"; import { MistralAISession } from "../llm/mistral.js"; -import { BaseEmbedding } from "./types.js"; export enum MistralAIEmbeddingModelType { MISTRAL_EMBED = "mistral-embed", diff --git a/packages/llamaindex/src/embeddings/MixedbreadAIEmbeddings.ts b/packages/llamaindex/src/embeddings/MixedbreadAIEmbeddings.ts index 0bd5d781c..45d5e1955 100644 --- a/packages/llamaindex/src/embeddings/MixedbreadAIEmbeddings.ts +++ b/packages/llamaindex/src/embeddings/MixedbreadAIEmbeddings.ts @@ -1,6 +1,6 @@ +import { BaseEmbedding, type EmbeddingInfo } from "@llamaindex/core/embeddings"; import { getEnv } from "@llamaindex/env"; import { MixedbreadAI, MixedbreadAIClient } from "@mixedbread-ai/sdk"; -import { BaseEmbedding, type EmbeddingInfo } from "./types.js"; type EmbeddingsRequestWithoutInput = Omit< MixedbreadAI.EmbeddingsRequest, @@ -153,7 +153,7 @@ export class MixedbreadAIEmbeddings extends BaseEmbedding { * const result = await mxbai.getTextEmbeddings(texts); * console.log(result); */ - async getTextEmbeddings(texts: string[]): Promise<Array<number[]>> { + getTextEmbeddings = async (texts: string[]): Promise<Array<number[]>> => { if (texts.length === 0) { return []; } @@ -166,5 +166,5 @@ export class MixedbreadAIEmbeddings extends BaseEmbedding { this.requestOptions, ); return response.data.map((d) => d.embedding as number[]); - } + }; } diff --git a/packages/llamaindex/src/embeddings/MultiModalEmbedding.ts b/packages/llamaindex/src/embeddings/MultiModalEmbedding.ts index bd01ccc4d..1792c7dbd 100644 --- a/packages/llamaindex/src/embeddings/MultiModalEmbedding.ts +++ b/packages/llamaindex/src/embeddings/MultiModalEmbedding.ts @@ -1,3 +1,4 @@ +import { BaseEmbedding, batchEmbeddings } from "@llamaindex/core/embeddings"; import type { MessageContentDetail } from "@llamaindex/core/llms"; import { ImageNode, @@ -8,7 +9,6 @@ import { type ImageType, } from "@llamaindex/core/schema"; import { extractImage, extractSingleText } from "@llamaindex/core/utils"; -import { BaseEmbedding, batchEmbeddings } from "./types.js"; /* * Base class for Multi Modal embeddings. diff --git a/packages/llamaindex/src/embeddings/OllamaEmbedding.ts b/packages/llamaindex/src/embeddings/OllamaEmbedding.ts index 5ce8f44c6..f6323c149 100644 --- a/packages/llamaindex/src/embeddings/OllamaEmbedding.ts +++ b/packages/llamaindex/src/embeddings/OllamaEmbedding.ts @@ -1,5 +1,5 @@ +import type { BaseEmbedding } from "@llamaindex/core/embeddings"; import { Ollama } from "../llm/ollama.js"; -import type { BaseEmbedding } from "./types.js"; /** * OllamaEmbedding is an alias for Ollama that implements the BaseEmbedding interface. diff --git a/packages/llamaindex/src/embeddings/OpenAIEmbedding.ts b/packages/llamaindex/src/embeddings/OpenAIEmbedding.ts index 03d4eadde..2fb3c3b30 100644 --- a/packages/llamaindex/src/embeddings/OpenAIEmbedding.ts +++ b/packages/llamaindex/src/embeddings/OpenAIEmbedding.ts @@ -1,3 +1,4 @@ +import { BaseEmbedding } from "@llamaindex/core/embeddings"; import { Tokenizers } from "@llamaindex/env"; import type { ClientOptions as OpenAIClientOptions } from "openai"; import type { AzureOpenAIConfig } from "../llm/azure.js"; @@ -8,7 +9,6 @@ import { } from "../llm/azure.js"; import type { OpenAISession } from "../llm/openai.js"; import { getOpenAISession } from "../llm/openai.js"; -import { BaseEmbedding } from "./types.js"; export const ALL_OPENAI_EMBEDDING_MODELS = { "text-embedding-ada-002": { @@ -132,9 +132,9 @@ export class OpenAIEmbedding extends BaseEmbedding { * Get embeddings for a batch of texts * @param texts */ - async getTextEmbeddings(texts: string[]): Promise<number[][]> { - return await this.getOpenAIEmbedding(texts); - } + getTextEmbeddings = async (texts: string[]): Promise<number[][]> => { + return this.getOpenAIEmbedding(texts); + }; /** * Get embeddings for a single text diff --git a/packages/llamaindex/src/embeddings/index.ts b/packages/llamaindex/src/embeddings/index.ts index 39da4044e..7d23905f1 100644 --- a/packages/llamaindex/src/embeddings/index.ts +++ b/packages/llamaindex/src/embeddings/index.ts @@ -1,3 +1,4 @@ +export * from "@llamaindex/core/embeddings"; export { DeepInfraEmbedding } from "./DeepInfraEmbedding.js"; export { FireworksEmbedding } from "./fireworks.js"; export * from "./GeminiEmbedding.js"; @@ -9,5 +10,3 @@ export * from "./MultiModalEmbedding.js"; export { OllamaEmbedding } from "./OllamaEmbedding.js"; export * from "./OpenAIEmbedding.js"; export { TogetherEmbedding } from "./together.js"; -export * from "./types.js"; -export * from "./utils.js"; diff --git a/packages/llamaindex/src/embeddings/utils.ts b/packages/llamaindex/src/embeddings/utils.ts deleted file mode 100644 index b3a87aa75..000000000 --- a/packages/llamaindex/src/embeddings/utils.ts +++ /dev/null @@ -1,256 +0,0 @@ -import type { ImageType } from "@llamaindex/core/schema"; -import { fs } from "@llamaindex/env"; -import _ from "lodash"; -import { filetypemime } from "magic-bytes.js"; -import { DEFAULT_SIMILARITY_TOP_K } from "../constants.js"; -import type { VectorStoreQueryMode } from "../storage/vectorStore/types.js"; - -/** - * Similarity type - * Default is cosine similarity. Dot product and negative Euclidean distance are also supported. - */ -export enum SimilarityType { - DEFAULT = "cosine", - DOT_PRODUCT = "dot_product", - EUCLIDEAN = "euclidean", -} - -/** - * The similarity between two embeddings. - * @param embedding1 - * @param embedding2 - * @param mode - * @returns similarity score with higher numbers meaning the two embeddings are more similar - */ - -export function similarity( - embedding1: number[], - embedding2: number[], - mode: SimilarityType = SimilarityType.DEFAULT, -): number { - if (embedding1.length !== embedding2.length) { - throw new Error("Embedding length mismatch"); - } - - // NOTE I've taken enough Kahan to know that we should probably leave the - // numeric programming to numeric programmers. The naive approach here - // will probably cause some avoidable loss of floating point precision - // ml-distance is worth watching although they currently also use the naive - // formulas - function norm(x: number[]): number { - let result = 0; - for (let i = 0; i < x.length; i++) { - result += x[i] * x[i]; - } - return Math.sqrt(result); - } - - switch (mode) { - case SimilarityType.EUCLIDEAN: { - const difference = embedding1.map((x, i) => x - embedding2[i]); - return -norm(difference); - } - case SimilarityType.DOT_PRODUCT: { - let result = 0; - for (let i = 0; i < embedding1.length; i++) { - result += embedding1[i] * embedding2[i]; - } - return result; - } - case SimilarityType.DEFAULT: { - return ( - similarity(embedding1, embedding2, SimilarityType.DOT_PRODUCT) / - (norm(embedding1) * norm(embedding2)) - ); - } - default: - throw new Error("Not implemented yet"); - } -} - -/** - * Get the top K embeddings from a list of embeddings ordered by similarity to the query. - * @param queryEmbedding - * @param embeddings list of embeddings to consider - * @param similarityTopK max number of embeddings to return, default 2 - * @param embeddingIds ids of embeddings in the embeddings list - * @param similarityCutoff minimum similarity score - * @returns - */ -// eslint-disable-next-line max-params -export function getTopKEmbeddings( - queryEmbedding: number[], - embeddings: number[][], - similarityTopK: number = DEFAULT_SIMILARITY_TOP_K, - embeddingIds: any[] | null = null, - similarityCutoff: number | null = null, -): [number[], any[]] { - if (embeddingIds == null) { - embeddingIds = Array(embeddings.length).map((_, i) => i); - } - - if (embeddingIds.length !== embeddings.length) { - throw new Error( - "getTopKEmbeddings: embeddings and embeddingIds length mismatch", - ); - } - - const similarities: { similarity: number; id: number }[] = []; - - for (let i = 0; i < embeddings.length; i++) { - const sim = similarity(queryEmbedding, embeddings[i]); - if (similarityCutoff == null || sim > similarityCutoff) { - similarities.push({ similarity: sim, id: embeddingIds[i] }); - } - } - - similarities.sort((a, b) => b.similarity - a.similarity); // Reverse sort - - const resultSimilarities: number[] = []; - const resultIds: any[] = []; - - for (let i = 0; i < similarityTopK; i++) { - if (i >= similarities.length) { - break; - } - resultSimilarities.push(similarities[i].similarity); - resultIds.push(similarities[i].id); - } - - return [resultSimilarities, resultIds]; -} - -// eslint-disable-next-line max-params -export function getTopKEmbeddingsLearner( - queryEmbedding: number[], - embeddings: number[][], - similarityTopK?: number, - embeddingsIds?: any[], - queryMode?: VectorStoreQueryMode, -): [number[], any[]] { - throw new Error("Not implemented yet"); -} - -// eslint-disable-next-line max-params -export function getTopKMMREmbeddings( - queryEmbedding: number[], - embeddings: number[][], - similarityFn: ((...args: any[]) => number) | null = null, - similarityTopK: number | null = null, - embeddingIds: any[] | null = null, - _similarityCutoff: number | null = null, - mmrThreshold: number | null = null, -): [number[], any[]] { - const threshold = mmrThreshold || 0.5; - similarityFn = similarityFn || similarity; - - if (embeddingIds === null || embeddingIds.length === 0) { - embeddingIds = Array.from({ length: embeddings.length }, (_, i) => i); - } - const fullEmbedMap = new Map(embeddingIds.map((value, i) => [value, i])); - const embedMap = new Map(fullEmbedMap); - const embedSimilarity: Map<any, number> = new Map(); - let score: number = Number.NEGATIVE_INFINITY; - let highScoreId: any | null = null; - - for (let i = 0; i < embeddings.length; i++) { - const emb = embeddings[i]; - const similarity = similarityFn(queryEmbedding, emb); - embedSimilarity.set(embeddingIds[i], similarity); - if (similarity * threshold > score) { - highScoreId = embeddingIds[i]; - score = similarity * threshold; - } - } - - const results: [number, any][] = []; - - const embeddingLength = embeddings.length; - const similarityTopKCount = similarityTopK || embeddingLength; - - while (results.length < Math.min(similarityTopKCount, embeddingLength)) { - results.push([score, highScoreId]); - embedMap.delete(highScoreId); - const recentEmbeddingId = highScoreId; - score = Number.NEGATIVE_INFINITY; - for (const embedId of Array.from(embedMap.keys())) { - const overlapWithRecent = similarityFn( - embeddings[embedMap.get(embedId)!], - embeddings[fullEmbedMap.get(recentEmbeddingId)!], - ); - if ( - threshold * embedSimilarity.get(embedId)! - - (1 - threshold) * overlapWithRecent > - score - ) { - score = - threshold * embedSimilarity.get(embedId)! - - (1 - threshold) * overlapWithRecent; - highScoreId = embedId; - } - } - } - - const resultSimilarities = results.map(([s, _]) => s); - const resultIds = results.map(([_, n]) => n); - - return [resultSimilarities, resultIds]; -} - -async function blobToDataUrl(input: Blob) { - const buffer = Buffer.from(await input.arrayBuffer()); - const mimes = filetypemime(buffer); - if (mimes.length < 1) { - throw new Error("Unsupported image type"); - } - return "data:" + mimes[0] + ";base64," + buffer.toString("base64"); -} - -export async function imageToString(input: ImageType): Promise<string> { - if (input instanceof Blob) { - // if the image is a Blob, convert it to a base64 data URL - return await blobToDataUrl(input); - } else if (_.isString(input)) { - return input; - } else if (input instanceof URL) { - return input.toString(); - } else { - throw new Error(`Unsupported input type: ${typeof input}`); - } -} - -export function stringToImage(input: string): ImageType { - if (input.startsWith("data:")) { - // if the input is a base64 data URL, convert it back to a Blob - const base64Data = input.split(",")[1]; - const byteArray = Buffer.from(base64Data, "base64"); - return new Blob([byteArray]); - } else if (input.startsWith("http://") || input.startsWith("https://")) { - return new URL(input); - } else if (_.isString(input)) { - return input; - } else { - throw new Error(`Unsupported input type: ${typeof input}`); - } -} - -export async function imageToDataUrl(input: ImageType): Promise<string> { - // first ensure, that the input is a Blob - if ( - (input instanceof URL && input.protocol === "file:") || - _.isString(input) - ) { - // string or file URL - const dataBuffer = await fs.readFile( - input instanceof URL ? input.pathname : input, - ); - input = new Blob([dataBuffer]); - } else if (!(input instanceof Blob)) { - if (input instanceof URL) { - throw new Error(`Unsupported URL with protocol: ${input.protocol}`); - } else { - throw new Error(`Unsupported input type: ${typeof input}`); - } - } - return await blobToDataUrl(input); -} diff --git a/packages/llamaindex/src/extractors/types.ts b/packages/llamaindex/src/extractors/types.ts index 3e8f169da..7b5063f23 100644 --- a/packages/llamaindex/src/extractors/types.ts +++ b/packages/llamaindex/src/extractors/types.ts @@ -1,12 +1,11 @@ -import type { BaseNode } from "@llamaindex/core/schema"; +import type { BaseNode, TransformComponent } from "@llamaindex/core/schema"; import { MetadataMode, TextNode } from "@llamaindex/core/schema"; -import type { TransformComponent } from "../ingestion/types.js"; import { defaultNodeTextTemplate } from "./prompts.js"; /* * Abstract class for all extractors. */ -export abstract class BaseExtractor implements TransformComponent { +export abstract class BaseExtractor implements TransformComponent<any> { isTextNodeOnly: boolean = true; showProgress: boolean = true; metadataMode: MetadataMode = MetadataMode.ALL; diff --git a/packages/llamaindex/src/indices/vectorStore/index.ts b/packages/llamaindex/src/indices/vectorStore/index.ts index aab886649..ff79ce8df 100644 --- a/packages/llamaindex/src/indices/vectorStore/index.ts +++ b/packages/llamaindex/src/indices/vectorStore/index.ts @@ -1,3 +1,7 @@ +import { + DEFAULT_SIMILARITY_TOP_K, + type BaseEmbedding, +} from "@llamaindex/core/embeddings"; import { Settings } from "@llamaindex/core/global"; import type { MessageContent } from "@llamaindex/core/llms"; import { @@ -13,8 +17,6 @@ import { wrapEventCaller } from "@llamaindex/core/utils"; import type { BaseRetriever, RetrieveParams } from "../../Retriever.js"; import type { ServiceContext } from "../../ServiceContext.js"; import { nodeParserFromSettingsOrContext } from "../../Settings.js"; -import { DEFAULT_SIMILARITY_TOP_K } from "../../constants.js"; -import type { BaseEmbedding } from "../../embeddings/index.js"; import { RetrieverQueryEngine } from "../../engines/query/RetrieverQueryEngine.js"; import { addNodesToVectorStores, diff --git a/packages/llamaindex/src/ingestion/IngestionCache.ts b/packages/llamaindex/src/ingestion/IngestionCache.ts index e64dee519..05adcd25c 100644 --- a/packages/llamaindex/src/ingestion/IngestionCache.ts +++ b/packages/llamaindex/src/ingestion/IngestionCache.ts @@ -1,12 +1,11 @@ -import type { BaseNode } from "@llamaindex/core/schema"; +import type { BaseNode, TransformComponent } from "@llamaindex/core/schema"; import { MetadataMode } from "@llamaindex/core/schema"; import { createSHA256 } from "@llamaindex/env"; import { docToJson, jsonToDoc } from "../storage/docStore/utils.js"; import { SimpleKVStore } from "../storage/kvStore/SimpleKVStore.js"; import type { BaseKVStore } from "../storage/kvStore/types.js"; -import type { TransformComponent } from "./types.js"; -const transformToJSON = (obj: TransformComponent) => { +const transformToJSON = (obj: TransformComponent<any>) => { const seen: any[] = []; const replacer = (key: string, value: any) => { @@ -27,7 +26,7 @@ const transformToJSON = (obj: TransformComponent) => { export function getTransformationHash( nodes: BaseNode[], - transform: TransformComponent, + transform: TransformComponent<any>, ) { const nodesStr: string = nodes .map((node) => node.getContent(MetadataMode.ALL)) diff --git a/packages/llamaindex/src/ingestion/IngestionPipeline.ts b/packages/llamaindex/src/ingestion/IngestionPipeline.ts index edf4075f7..e2191a687 100644 --- a/packages/llamaindex/src/ingestion/IngestionPipeline.ts +++ b/packages/llamaindex/src/ingestion/IngestionPipeline.ts @@ -1,3 +1,4 @@ +import type { TransformComponent } from "@llamaindex/core/schema"; import { ModalityType, splitNodesByType, @@ -16,7 +17,6 @@ import { DocStoreStrategy, createDocStoreStrategy, } from "./strategies/index.js"; -import type { TransformComponent } from "./types.js"; type IngestionRunArgs = { documents?: Document[]; @@ -26,12 +26,12 @@ type IngestionRunArgs = { type TransformRunArgs = { inPlace?: boolean; cache?: IngestionCache; - docStoreStrategy?: TransformComponent; + docStoreStrategy?: TransformComponent<any>; }; export async function runTransformations( nodesToRun: BaseNode[], - transformations: TransformComponent[], + transformations: TransformComponent<any>[], transformOptions: any = {}, { inPlace = true, cache, docStoreStrategy }: TransformRunArgs = {}, ): Promise<BaseNode[]> { @@ -60,7 +60,7 @@ export async function runTransformations( } export class IngestionPipeline { - transformations: TransformComponent[] = []; + transformations: TransformComponent<any>[] = []; documents?: Document[]; reader?: BaseReader; vectorStore?: VectorStore; @@ -70,7 +70,7 @@ export class IngestionPipeline { cache?: IngestionCache; disableCache: boolean = false; - private _docStoreStrategy?: TransformComponent; + private _docStoreStrategy?: TransformComponent<any>; constructor(init?: Partial<IngestionPipeline>) { Object.assign(this, init); @@ -112,10 +112,7 @@ export class IngestionPipeline { return inputNodes.flat(); } - async run( - args: IngestionRunArgs & TransformRunArgs = {}, - transformOptions?: any, - ): Promise<BaseNode[]> { + async run(args: any = {}, transformOptions?: any): Promise<BaseNode[]> { args.cache = args.cache ?? this.cache; args.docStoreStrategy = args.docStoreStrategy ?? this._docStoreStrategy; const inputNodes = await this.prepareInput(args.documents, args.nodes); diff --git a/packages/llamaindex/src/ingestion/index.ts b/packages/llamaindex/src/ingestion/index.ts index 4234d68d9..cda77ff66 100644 --- a/packages/llamaindex/src/ingestion/index.ts +++ b/packages/llamaindex/src/ingestion/index.ts @@ -1,2 +1 @@ export * from "./IngestionPipeline.js"; -export * from "./types.js"; diff --git a/packages/llamaindex/src/ingestion/strategies/DuplicatesStrategy.ts b/packages/llamaindex/src/ingestion/strategies/DuplicatesStrategy.ts index a6dc484b9..679256a53 100644 --- a/packages/llamaindex/src/ingestion/strategies/DuplicatesStrategy.ts +++ b/packages/llamaindex/src/ingestion/strategies/DuplicatesStrategy.ts @@ -1,11 +1,10 @@ -import type { BaseNode } from "@llamaindex/core/schema"; +import type { BaseNode, TransformComponent } from "@llamaindex/core/schema"; import type { BaseDocumentStore } from "../../storage/docStore/types.js"; -import type { TransformComponent } from "../types.js"; /** * Handle doc store duplicates by checking all hashes. */ -export class DuplicatesStrategy implements TransformComponent { +export class DuplicatesStrategy implements TransformComponent<any> { private docStore: BaseDocumentStore; constructor(docStore: BaseDocumentStore) { diff --git a/packages/llamaindex/src/ingestion/strategies/UpsertsAndDeleteStrategy.ts b/packages/llamaindex/src/ingestion/strategies/UpsertsAndDeleteStrategy.ts index 7225794a4..700c23f00 100644 --- a/packages/llamaindex/src/ingestion/strategies/UpsertsAndDeleteStrategy.ts +++ b/packages/llamaindex/src/ingestion/strategies/UpsertsAndDeleteStrategy.ts @@ -1,14 +1,13 @@ -import type { BaseNode } from "@llamaindex/core/schema"; +import type { BaseNode, TransformComponent } from "@llamaindex/core/schema"; import type { BaseDocumentStore } from "../../storage/docStore/types.js"; import type { VectorStore } from "../../storage/vectorStore/types.js"; -import type { TransformComponent } from "../types.js"; import { classify } from "./classify.js"; /** * Handle docstore upserts by checking hashes and ids. * Identify missing docs and delete them from docstore and vector store */ -export class UpsertsAndDeleteStrategy implements TransformComponent { +export class UpsertsAndDeleteStrategy implements TransformComponent<any> { protected docStore: BaseDocumentStore; protected vectorStores?: VectorStore[]; diff --git a/packages/llamaindex/src/ingestion/strategies/UpsertsStrategy.ts b/packages/llamaindex/src/ingestion/strategies/UpsertsStrategy.ts index 83f7584e9..cc30716a1 100644 --- a/packages/llamaindex/src/ingestion/strategies/UpsertsStrategy.ts +++ b/packages/llamaindex/src/ingestion/strategies/UpsertsStrategy.ts @@ -1,13 +1,12 @@ -import type { BaseNode } from "@llamaindex/core/schema"; +import type { BaseNode, TransformComponent } from "@llamaindex/core/schema"; import type { BaseDocumentStore } from "../../storage/docStore/types.js"; import type { VectorStore } from "../../storage/vectorStore/types.js"; -import type { TransformComponent } from "../types.js"; import { classify } from "./classify.js"; /** * Handles doc store upserts by checking hashes and ids. */ -export class UpsertsStrategy implements TransformComponent { +export class UpsertsStrategy implements TransformComponent<any> { protected docStore: BaseDocumentStore; protected vectorStores?: VectorStore[]; diff --git a/packages/llamaindex/src/ingestion/strategies/index.ts b/packages/llamaindex/src/ingestion/strategies/index.ts index 13d916c92..6e2c7ecbe 100644 --- a/packages/llamaindex/src/ingestion/strategies/index.ts +++ b/packages/llamaindex/src/ingestion/strategies/index.ts @@ -1,6 +1,6 @@ +import type { TransformComponent } from "@llamaindex/core/schema"; import type { BaseDocumentStore } from "../../storage/docStore/types.js"; import type { VectorStore } from "../../storage/vectorStore/types.js"; -import type { TransformComponent } from "../types.js"; import { DuplicatesStrategy } from "./DuplicatesStrategy.js"; import { UpsertsAndDeleteStrategy } from "./UpsertsAndDeleteStrategy.js"; import { UpsertsStrategy } from "./UpsertsStrategy.js"; @@ -19,7 +19,7 @@ export enum DocStoreStrategy { NONE = "none", // no-op strategy } -class NoOpStrategy implements TransformComponent { +class NoOpStrategy implements TransformComponent<any> { async transform(nodes: any[]): Promise<any[]> { return nodes; } @@ -29,7 +29,7 @@ export function createDocStoreStrategy( docStoreStrategy: DocStoreStrategy, docStore?: BaseDocumentStore, vectorStores: VectorStore[] = [], -): TransformComponent { +): TransformComponent<any> { if (docStoreStrategy === DocStoreStrategy.NONE) { return new NoOpStrategy(); } diff --git a/packages/llamaindex/src/ingestion/types.ts b/packages/llamaindex/src/ingestion/types.ts deleted file mode 100644 index d61d2c480..000000000 --- a/packages/llamaindex/src/ingestion/types.ts +++ /dev/null @@ -1,5 +0,0 @@ -import type { BaseNode } from "@llamaindex/core/schema"; - -export interface TransformComponent { - transform(nodes: BaseNode[], options?: any): Promise<BaseNode[]>; -} diff --git a/packages/llamaindex/src/internal/settings/EmbedModel.ts b/packages/llamaindex/src/internal/settings/EmbedModel.ts index 064ce1b7d..fab2331a8 100644 --- a/packages/llamaindex/src/internal/settings/EmbedModel.ts +++ b/packages/llamaindex/src/internal/settings/EmbedModel.ts @@ -1,6 +1,6 @@ +import type { BaseEmbedding } from "@llamaindex/core/embeddings"; import { AsyncLocalStorage } from "@llamaindex/env"; import { OpenAIEmbedding } from "../../embeddings/OpenAIEmbedding.js"; -import type { BaseEmbedding } from "../../embeddings/index.js"; const embeddedModelAsyncLocalStorage = new AsyncLocalStorage<BaseEmbedding>(); let globalEmbeddedModel: BaseEmbedding | null = null; diff --git a/packages/llamaindex/src/internal/utils.ts b/packages/llamaindex/src/internal/utils.ts index a587290e2..a301c2707 100644 --- a/packages/llamaindex/src/internal/utils.ts +++ b/packages/llamaindex/src/internal/utils.ts @@ -1,4 +1,8 @@ +import { similarity } from "@llamaindex/core/embeddings"; import type { JSONValue } from "@llamaindex/core/global"; +import type { ImageType } from "@llamaindex/core/schema"; +import { fs } from "@llamaindex/env"; +import { filetypemime } from "magic-bytes.js"; export const isAsyncIterable = ( obj: unknown, @@ -24,3 +28,177 @@ export function prettifyError(error: unknown): string { export function stringifyJSONToMessageContent(value: JSONValue): string { return JSON.stringify(value, null, 2).replace(/"([^"]*)"/g, "$1"); } + +/** + * Get the top K embeddings from a list of embeddings ordered by similarity to the query. + * @param queryEmbedding + * @param embeddings list of embeddings to consider + * @param similarityTopK max number of embeddings to return, default 2 + * @param embeddingIds ids of embeddings in the embeddings list + * @param similarityCutoff minimum similarity score + * @returns + */ +// eslint-disable-next-line max-params +export function getTopKEmbeddings( + queryEmbedding: number[], + embeddings: number[][], + similarityTopK: number = 2, + embeddingIds: any[] | null = null, + similarityCutoff: number | null = null, +): [number[], any[]] { + if (embeddingIds == null) { + embeddingIds = Array(embeddings.length).map((_, i) => i); + } + + if (embeddingIds.length !== embeddings.length) { + throw new Error( + "getTopKEmbeddings: embeddings and embeddingIds length mismatch", + ); + } + + const similarities: { similarity: number; id: number }[] = []; + + for (let i = 0; i < embeddings.length; i++) { + const sim = similarity(queryEmbedding, embeddings[i]); + if (similarityCutoff == null || sim > similarityCutoff) { + similarities.push({ similarity: sim, id: embeddingIds[i] }); + } + } + + similarities.sort((a, b) => b.similarity - a.similarity); // Reverse sort + + const resultSimilarities: number[] = []; + const resultIds: any[] = []; + + for (let i = 0; i < similarityTopK; i++) { + if (i >= similarities.length) { + break; + } + resultSimilarities.push(similarities[i].similarity); + resultIds.push(similarities[i].id); + } + + return [resultSimilarities, resultIds]; +} + +// eslint-disable-next-line max-params +export function getTopKMMREmbeddings( + queryEmbedding: number[], + embeddings: number[][], + similarityFn: ((...args: any[]) => number) | null = null, + similarityTopK: number | null = null, + embeddingIds: any[] | null = null, + _similarityCutoff: number | null = null, + mmrThreshold: number | null = null, +): [number[], any[]] { + const threshold = mmrThreshold || 0.5; + similarityFn = similarityFn || similarity; + + if (embeddingIds === null || embeddingIds.length === 0) { + embeddingIds = Array.from({ length: embeddings.length }, (_, i) => i); + } + const fullEmbedMap = new Map(embeddingIds.map((value, i) => [value, i])); + const embedMap = new Map(fullEmbedMap); + const embedSimilarity: Map<any, number> = new Map(); + let score: number = Number.NEGATIVE_INFINITY; + let highScoreId: any | null = null; + + for (let i = 0; i < embeddings.length; i++) { + const emb = embeddings[i]; + const similarity = similarityFn(queryEmbedding, emb); + embedSimilarity.set(embeddingIds[i], similarity); + if (similarity * threshold > score) { + highScoreId = embeddingIds[i]; + score = similarity * threshold; + } + } + + const results: [number, any][] = []; + + const embeddingLength = embeddings.length; + const similarityTopKCount = similarityTopK || embeddingLength; + + while (results.length < Math.min(similarityTopKCount, embeddingLength)) { + results.push([score, highScoreId]); + embedMap.delete(highScoreId); + const recentEmbeddingId = highScoreId; + score = Number.NEGATIVE_INFINITY; + for (const embedId of Array.from(embedMap.keys())) { + const overlapWithRecent = similarityFn( + embeddings[embedMap.get(embedId)!], + embeddings[fullEmbedMap.get(recentEmbeddingId)!], + ); + if ( + threshold * embedSimilarity.get(embedId)! - + (1 - threshold) * overlapWithRecent > + score + ) { + score = + threshold * embedSimilarity.get(embedId)! - + (1 - threshold) * overlapWithRecent; + highScoreId = embedId; + } + } + } + + const resultSimilarities = results.map(([s, _]) => s); + const resultIds = results.map(([_, n]) => n); + + return [resultSimilarities, resultIds]; +} + +async function blobToDataUrl(input: Blob) { + const buffer = Buffer.from(await input.arrayBuffer()); + const mimes = filetypemime(buffer); + if (mimes.length < 1) { + throw new Error("Unsupported image type"); + } + return "data:" + mimes[0] + ";base64," + buffer.toString("base64"); +} + +export async function imageToString(input: ImageType): Promise<string> { + if (input instanceof Blob) { + // if the image is a Blob, convert it to a base64 data URL + return await blobToDataUrl(input); + } else if (typeof input === "string") { + return input; + } else if (input instanceof URL) { + return input.toString(); + } else { + throw new Error(`Unsupported input type: ${typeof input}`); + } +} + +export function stringToImage(input: string): ImageType { + if (input.startsWith("data:")) { + // if the input is a base64 data URL, convert it back to a Blob + const base64Data = input.split(",")[1]; + const byteArray = Buffer.from(base64Data, "base64"); + return new Blob([byteArray]); + } else if (input.startsWith("http://") || input.startsWith("https://")) { + return new URL(input); + } else { + return input; + } +} + +export async function imageToDataUrl(input: ImageType): Promise<string> { + // first ensure, that the input is a Blob + if ( + (input instanceof URL && input.protocol === "file:") || + typeof input === "string" + ) { + // string or file URL + const dataBuffer = await fs.readFile( + input instanceof URL ? input.pathname : input, + ); + input = new Blob([dataBuffer]); + } else if (!(input instanceof Blob)) { + if (input instanceof URL) { + throw new Error(`Unsupported URL with protocol: ${input.protocol}`); + } else { + throw new Error(`Unsupported input type: ${typeof input}`); + } + } + return await blobToDataUrl(input); +} diff --git a/packages/llamaindex/src/llm/ollama.ts b/packages/llamaindex/src/llm/ollama.ts index 1683961d2..1fea1b7d0 100644 --- a/packages/llamaindex/src/llm/ollama.ts +++ b/packages/llamaindex/src/llm/ollama.ts @@ -1,3 +1,4 @@ +import { BaseEmbedding } from "@llamaindex/core/embeddings"; import type { ChatResponse, ChatResponseChunk, @@ -10,7 +11,6 @@ import type { LLMMetadata, } from "@llamaindex/core/llms"; import { extractText, streamConverter } from "@llamaindex/core/utils"; -import { BaseEmbedding } from "../embeddings/types.js"; import { Ollama as OllamaBase, type Config, diff --git a/packages/llamaindex/src/nodeParsers/types.ts b/packages/llamaindex/src/nodeParsers/types.ts index 472996bab..05828addd 100644 --- a/packages/llamaindex/src/nodeParsers/types.ts +++ b/packages/llamaindex/src/nodeParsers/types.ts @@ -1,10 +1,9 @@ -import type { BaseNode } from "@llamaindex/core/schema"; -import type { TransformComponent } from "../ingestion/types.js"; +import type { BaseNode, TransformComponent } from "@llamaindex/core/schema"; /** * A NodeParser generates Nodes from Documents */ -export interface NodeParser extends TransformComponent { +export interface NodeParser extends TransformComponent<any> { /** * Generates an array of nodes from an array of documents. * @param documents - The documents to generate nodes from. diff --git a/packages/llamaindex/src/storage/vectorStore/MongoDBAtlasVectorStore.ts b/packages/llamaindex/src/storage/vectorStore/MongoDBAtlasVectorStore.ts index c4a7ee421..241e202f4 100644 --- a/packages/llamaindex/src/storage/vectorStore/MongoDBAtlasVectorStore.ts +++ b/packages/llamaindex/src/storage/vectorStore/MongoDBAtlasVectorStore.ts @@ -1,9 +1,9 @@ +import type { BaseEmbedding } from "@llamaindex/core/embeddings"; import type { BaseNode } from "@llamaindex/core/schema"; import { MetadataMode } from "@llamaindex/core/schema"; import { getEnv } from "@llamaindex/env"; import type { BulkWriteOptions, Collection } from "mongodb"; import { MongoClient } from "mongodb"; -import { BaseEmbedding } from "../../embeddings/types.js"; import { VectorStoreBase, type MetadataFilters, diff --git a/packages/llamaindex/src/storage/vectorStore/SimpleVectorStore.ts b/packages/llamaindex/src/storage/vectorStore/SimpleVectorStore.ts index 5e8ff9377..4f927075d 100644 --- a/packages/llamaindex/src/storage/vectorStore/SimpleVectorStore.ts +++ b/packages/llamaindex/src/storage/vectorStore/SimpleVectorStore.ts @@ -1,11 +1,10 @@ +import type { BaseEmbedding } from "@llamaindex/core/embeddings"; import type { BaseNode } from "@llamaindex/core/schema"; import { fs, path } from "@llamaindex/env"; -import { BaseEmbedding } from "../../embeddings/index.js"; import { getTopKEmbeddings, - getTopKEmbeddingsLearner, getTopKMMREmbeddings, -} from "../../embeddings/utils.js"; +} from "../../internal/utils.js"; import { exists } from "../FileSystem.js"; import { DEFAULT_PERSIST_DIR } from "../constants.js"; import { @@ -116,11 +115,9 @@ export class SimpleVectorStore let topSimilarities: number[], topIds: string[]; if (LEARNER_MODES.has(query.mode)) { - [topSimilarities, topIds] = getTopKEmbeddingsLearner( - queryEmbedding, - embeddings, - query.similarityTopK, - nodeIds, + // fixme: unfinished + throw new Error( + "Learner modes not implemented for SimpleVectorStore yet.", ); } else if (query.mode === MMR_MODE) { const mmrThreshold = query.mmrThreshold; diff --git a/packages/llamaindex/src/storage/vectorStore/types.ts b/packages/llamaindex/src/storage/vectorStore/types.ts index 0249cb275..1862631f9 100644 --- a/packages/llamaindex/src/storage/vectorStore/types.ts +++ b/packages/llamaindex/src/storage/vectorStore/types.ts @@ -1,5 +1,5 @@ +import type { BaseEmbedding } from "@llamaindex/core/embeddings"; import type { BaseNode, ModalityType } from "@llamaindex/core/schema"; -import type { BaseEmbedding } from "../../embeddings/types.js"; import { getEmbeddedModel } from "../../internal/settings/EmbedModel.js"; export interface VectorStoreQueryResult { diff --git a/packages/llamaindex/src/synthesizers/utils.ts b/packages/llamaindex/src/synthesizers/utils.ts index 79d465ffe..489c33efd 100644 --- a/packages/llamaindex/src/synthesizers/utils.ts +++ b/packages/llamaindex/src/synthesizers/utils.ts @@ -7,7 +7,7 @@ import { type BaseNode, } from "@llamaindex/core/schema"; import type { SimplePrompt } from "../Prompt.js"; -import { imageToDataUrl } from "../embeddings/utils.js"; +import { imageToDataUrl } from "../internal/utils.js"; export async function createMessageContent( prompt: SimplePrompt, diff --git a/packages/llamaindex/tests/ingestion/IngestionCache.test.ts b/packages/llamaindex/tests/ingestion/IngestionCache.test.ts index 88fb9ff69..5af8b3e70 100644 --- a/packages/llamaindex/tests/ingestion/IngestionCache.test.ts +++ b/packages/llamaindex/tests/ingestion/IngestionCache.test.ts @@ -1,10 +1,10 @@ import type { BaseNode } from "@llamaindex/core/schema"; import { TextNode } from "@llamaindex/core/schema"; +import type { TransformComponent } from "llamaindex"; import { IngestionCache, getTransformationHash, } from "llamaindex/ingestion/IngestionCache"; -import type { TransformComponent } from "llamaindex/ingestion/index"; import { SimpleNodeParser } from "llamaindex/nodeParsers/index"; import { beforeAll, describe, expect, test } from "vitest"; @@ -28,7 +28,7 @@ describe("IngestionCache", () => { }); describe("getTransformationHash", () => { - let nodes: BaseNode[], transform: TransformComponent; + let nodes: BaseNode[], transform: TransformComponent<any>; beforeAll(() => { nodes = [new TextNode({ text: "some text", id_: "some id" })]; -- GitLab