diff --git a/.changeset/spicy-dingos-perform.md b/.changeset/spicy-dingos-perform.md new file mode 100644 index 0000000000000000000000000000000000000000..a434999d9aaa53e0c2214011cde34931078c0351 --- /dev/null +++ b/.changeset/spicy-dingos-perform.md @@ -0,0 +1,6 @@ +--- +"@llamaindex/core": patch +"llamaindex": patch +--- + +fix: export embeddings utils diff --git a/packages/core/src/embeddings/index.ts b/packages/core/src/embeddings/index.ts index 5e96ef5b11fdda5aceb133d75288069d06f0a73b..b0880b8f2b9f09089b4732a85373c89bb5fb397d 100644 --- a/packages/core/src/embeddings/index.ts +++ b/packages/core/src/embeddings/index.ts @@ -2,4 +2,10 @@ export { BaseEmbedding, batchEmbeddings } from "./base"; export type { BaseEmbeddingOptions, EmbeddingInfo } from "./base"; export { MultiModalEmbedding } from "./muti-model"; export { truncateMaxTokens } from "./tokenizer"; -export { DEFAULT_SIMILARITY_TOP_K, SimilarityType, similarity } from "./utils"; +export { + DEFAULT_SIMILARITY_TOP_K, + SimilarityType, + getTopKEmbeddings, + getTopKMMREmbeddings, + similarity, +} from "./utils"; diff --git a/packages/core/src/embeddings/utils.ts b/packages/core/src/embeddings/utils.ts index 0e292d2556de76ef0ec1eafa657fac0d3932c791..dfe456758dd79b813e0f9d27bdd06a9a9126a5fd 100644 --- a/packages/core/src/embeddings/utils.ts +++ b/packages/core/src/embeddings/utils.ts @@ -62,3 +62,121 @@ export function similarity( throw new Error("Not implemented yet"); } } + +/** + * Get the top K embeddings from a list of embeddings ordered by similarity to the query. + * @param queryEmbedding + * @param embeddings list of embeddings to consider + * @param similarityTopK max number of embeddings to return, default 2 + * @param embeddingIds ids of embeddings in the embeddings list + * @param similarityCutoff minimum similarity score + * @returns + */ +// eslint-disable-next-line max-params +export function getTopKEmbeddings( + queryEmbedding: number[], + embeddings: number[][], + similarityTopK: number = 2, + embeddingIds: any[] | null = null, + similarityCutoff: number | null = null, +): [number[], any[]] { + if (embeddingIds == null) { + embeddingIds = Array(embeddings.length).map((_, i) => i); + } + + if (embeddingIds.length !== embeddings.length) { + throw new Error( + "getTopKEmbeddings: embeddings and embeddingIds length mismatch", + ); + } + + const similarities: { similarity: number; id: number }[] = []; + + for (let i = 0; i < embeddings.length; i++) { + const sim = similarity(queryEmbedding, embeddings[i]!); + if (similarityCutoff == null || sim > similarityCutoff) { + similarities.push({ similarity: sim, id: embeddingIds[i] }); + } + } + + similarities.sort((a, b) => b.similarity - a.similarity); // Reverse sort + + const resultSimilarities: number[] = []; + const resultIds: any[] = []; + + for (let i = 0; i < similarityTopK; i++) { + if (i >= similarities.length) { + break; + } + resultSimilarities.push(similarities[i]!.similarity); + resultIds.push(similarities[i]!.id); + } + + return [resultSimilarities, resultIds]; +} + +// eslint-disable-next-line max-params +export function getTopKMMREmbeddings( + queryEmbedding: number[], + embeddings: number[][], + similarityFn: ((...args: any[]) => number) | null = null, + similarityTopK: number | null = null, + embeddingIds: any[] | null = null, + _similarityCutoff: number | null = null, + mmrThreshold: number | null = null, +): [number[], any[]] { + const threshold = mmrThreshold || 0.5; + similarityFn = similarityFn || similarity; + + if (embeddingIds === null || embeddingIds.length === 0) { + embeddingIds = Array.from({ length: embeddings.length }, (_, i) => i); + } + const fullEmbedMap = new Map(embeddingIds.map((value, i) => [value, i])); + const embedMap = new Map(fullEmbedMap); + const embedSimilarity: Map<any, number> = new Map(); + let score: number = Number.NEGATIVE_INFINITY; + let highScoreId: any | null = null; + + for (let i = 0; i < embeddings.length; i++) { + const emb = embeddings[i]; + const similarity = similarityFn(queryEmbedding, emb); + embedSimilarity.set(embeddingIds[i], similarity); + if (similarity * threshold > score) { + highScoreId = embeddingIds[i]; + score = similarity * threshold; + } + } + + const results: [number, any][] = []; + + const embeddingLength = embeddings.length; + const similarityTopKCount = similarityTopK || embeddingLength; + + while (results.length < Math.min(similarityTopKCount, embeddingLength)) { + results.push([score, highScoreId]); + embedMap.delete(highScoreId); + const recentEmbeddingId = highScoreId; + score = Number.NEGATIVE_INFINITY; + for (const embedId of Array.from(embedMap.keys())) { + const overlapWithRecent = similarityFn( + embeddings[embedMap.get(embedId)!], + embeddings[fullEmbedMap.get(recentEmbeddingId)!], + ); + if ( + threshold * embedSimilarity.get(embedId)! - + (1 - threshold) * overlapWithRecent > + score + ) { + score = + threshold * embedSimilarity.get(embedId)! - + (1 - threshold) * overlapWithRecent; + highScoreId = embedId; + } + } + } + + const resultSimilarities = results.map(([s, _]) => s); + const resultIds = results.map(([_, n]) => n); + + return [resultSimilarities, resultIds]; +} diff --git a/packages/llamaindex/src/internal/utils.ts b/packages/llamaindex/src/internal/utils.ts index 53e5b975526b21ddcaf1281374b37ae7b3f9ea56..137141838f46b50d5d08c27aa16dcb362e0e7abd 100644 --- a/packages/llamaindex/src/internal/utils.ts +++ b/packages/llamaindex/src/internal/utils.ts @@ -1,4 +1,3 @@ -import { similarity } from "@llamaindex/core/embeddings"; import type { ImageType } from "@llamaindex/core/schema"; import { fs } from "@llamaindex/env"; import { filetypemime } from "magic-bytes.js"; @@ -17,124 +16,6 @@ export const isIterable = (obj: unknown): obj is Iterable<unknown> => { return obj != null && typeof obj === "object" && Symbol.iterator in obj; }; -/** - * Get the top K embeddings from a list of embeddings ordered by similarity to the query. - * @param queryEmbedding - * @param embeddings list of embeddings to consider - * @param similarityTopK max number of embeddings to return, default 2 - * @param embeddingIds ids of embeddings in the embeddings list - * @param similarityCutoff minimum similarity score - * @returns - */ -// eslint-disable-next-line max-params -export function getTopKEmbeddings( - queryEmbedding: number[], - embeddings: number[][], - similarityTopK: number = 2, - embeddingIds: any[] | null = null, - similarityCutoff: number | null = null, -): [number[], any[]] { - if (embeddingIds == null) { - embeddingIds = Array(embeddings.length).map((_, i) => i); - } - - if (embeddingIds.length !== embeddings.length) { - throw new Error( - "getTopKEmbeddings: embeddings and embeddingIds length mismatch", - ); - } - - const similarities: { similarity: number; id: number }[] = []; - - for (let i = 0; i < embeddings.length; i++) { - const sim = similarity(queryEmbedding, embeddings[i]!); - if (similarityCutoff == null || sim > similarityCutoff) { - similarities.push({ similarity: sim, id: embeddingIds[i] }); - } - } - - similarities.sort((a, b) => b.similarity - a.similarity); // Reverse sort - - const resultSimilarities: number[] = []; - const resultIds: any[] = []; - - for (let i = 0; i < similarityTopK; i++) { - if (i >= similarities.length) { - break; - } - resultSimilarities.push(similarities[i]!.similarity); - resultIds.push(similarities[i]!.id); - } - - return [resultSimilarities, resultIds]; -} - -// eslint-disable-next-line max-params -export function getTopKMMREmbeddings( - queryEmbedding: number[], - embeddings: number[][], - similarityFn: ((...args: any[]) => number) | null = null, - similarityTopK: number | null = null, - embeddingIds: any[] | null = null, - _similarityCutoff: number | null = null, - mmrThreshold: number | null = null, -): [number[], any[]] { - const threshold = mmrThreshold || 0.5; - similarityFn = similarityFn || similarity; - - if (embeddingIds === null || embeddingIds.length === 0) { - embeddingIds = Array.from({ length: embeddings.length }, (_, i) => i); - } - const fullEmbedMap = new Map(embeddingIds.map((value, i) => [value, i])); - const embedMap = new Map(fullEmbedMap); - const embedSimilarity: Map<any, number> = new Map(); - let score: number = Number.NEGATIVE_INFINITY; - let highScoreId: any | null = null; - - for (let i = 0; i < embeddings.length; i++) { - const emb = embeddings[i]; - const similarity = similarityFn(queryEmbedding, emb); - embedSimilarity.set(embeddingIds[i], similarity); - if (similarity * threshold > score) { - highScoreId = embeddingIds[i]; - score = similarity * threshold; - } - } - - const results: [number, any][] = []; - - const embeddingLength = embeddings.length; - const similarityTopKCount = similarityTopK || embeddingLength; - - while (results.length < Math.min(similarityTopKCount, embeddingLength)) { - results.push([score, highScoreId]); - embedMap.delete(highScoreId); - const recentEmbeddingId = highScoreId; - score = Number.NEGATIVE_INFINITY; - for (const embedId of Array.from(embedMap.keys())) { - const overlapWithRecent = similarityFn( - embeddings[embedMap.get(embedId)!], - embeddings[fullEmbedMap.get(recentEmbeddingId)!], - ); - if ( - threshold * embedSimilarity.get(embedId)! - - (1 - threshold) * overlapWithRecent > - score - ) { - score = - threshold * embedSimilarity.get(embedId)! - - (1 - threshold) * overlapWithRecent; - highScoreId = embedId; - } - } - } - - const resultSimilarities = results.map(([s, _]) => s); - const resultIds = results.map(([_, n]) => n); - - return [resultSimilarities, resultIds]; -} - async function blobToDataUrl(input: Blob) { const buffer = Buffer.from(await input.arrayBuffer()); const mimes = filetypemime(buffer); diff --git a/packages/llamaindex/src/vector-store/SimpleVectorStore.ts b/packages/llamaindex/src/vector-store/SimpleVectorStore.ts index 4bd97bd99d2e942f3de05f7f4db630e426722810..a0c7f32676312ead5d87be84f85acfeddb10dc57 100644 --- a/packages/llamaindex/src/vector-store/SimpleVectorStore.ts +++ b/packages/llamaindex/src/vector-store/SimpleVectorStore.ts @@ -1,8 +1,11 @@ import type { BaseEmbedding } from "@llamaindex/core/embeddings"; +import { + getTopKEmbeddings, + getTopKMMREmbeddings, +} from "@llamaindex/core/embeddings"; import { DEFAULT_PERSIST_DIR } from "@llamaindex/core/global"; import type { BaseNode } from "@llamaindex/core/schema"; import { fs, path } from "@llamaindex/env"; -import { getTopKEmbeddings, getTopKMMREmbeddings } from "../internal/utils.js"; import { exists } from "../storage/FileSystem.js"; import { BaseVectorStore,