Skip to content
Snippets Groups Projects
Unverified Commit 691c5bca authored by Alex Yang's avatar Alex Yang Committed by GitHub
Browse files

fix: export embeddings utils (#1387)

parent 9ab998c5
No related branches found
No related tags found
No related merge requests found
---
"@llamaindex/core": patch
"llamaindex": patch
---
fix: export embeddings utils
......@@ -2,4 +2,10 @@ export { BaseEmbedding, batchEmbeddings } from "./base";
export type { BaseEmbeddingOptions, EmbeddingInfo } from "./base";
export { MultiModalEmbedding } from "./muti-model";
export { truncateMaxTokens } from "./tokenizer";
export { DEFAULT_SIMILARITY_TOP_K, SimilarityType, similarity } from "./utils";
export {
DEFAULT_SIMILARITY_TOP_K,
SimilarityType,
getTopKEmbeddings,
getTopKMMREmbeddings,
similarity,
} from "./utils";
......@@ -62,3 +62,121 @@ export function similarity(
throw new Error("Not implemented yet");
}
}
/**
* Get the top K embeddings from a list of embeddings ordered by similarity to the query.
* @param queryEmbedding
* @param embeddings list of embeddings to consider
* @param similarityTopK max number of embeddings to return, default 2
* @param embeddingIds ids of embeddings in the embeddings list
* @param similarityCutoff minimum similarity score
* @returns
*/
// eslint-disable-next-line max-params
export function getTopKEmbeddings(
queryEmbedding: number[],
embeddings: number[][],
similarityTopK: number = 2,
embeddingIds: any[] | null = null,
similarityCutoff: number | null = null,
): [number[], any[]] {
if (embeddingIds == null) {
embeddingIds = Array(embeddings.length).map((_, i) => i);
}
if (embeddingIds.length !== embeddings.length) {
throw new Error(
"getTopKEmbeddings: embeddings and embeddingIds length mismatch",
);
}
const similarities: { similarity: number; id: number }[] = [];
for (let i = 0; i < embeddings.length; i++) {
const sim = similarity(queryEmbedding, embeddings[i]!);
if (similarityCutoff == null || sim > similarityCutoff) {
similarities.push({ similarity: sim, id: embeddingIds[i] });
}
}
similarities.sort((a, b) => b.similarity - a.similarity); // Reverse sort
const resultSimilarities: number[] = [];
const resultIds: any[] = [];
for (let i = 0; i < similarityTopK; i++) {
if (i >= similarities.length) {
break;
}
resultSimilarities.push(similarities[i]!.similarity);
resultIds.push(similarities[i]!.id);
}
return [resultSimilarities, resultIds];
}
// eslint-disable-next-line max-params
export function getTopKMMREmbeddings(
queryEmbedding: number[],
embeddings: number[][],
similarityFn: ((...args: any[]) => number) | null = null,
similarityTopK: number | null = null,
embeddingIds: any[] | null = null,
_similarityCutoff: number | null = null,
mmrThreshold: number | null = null,
): [number[], any[]] {
const threshold = mmrThreshold || 0.5;
similarityFn = similarityFn || similarity;
if (embeddingIds === null || embeddingIds.length === 0) {
embeddingIds = Array.from({ length: embeddings.length }, (_, i) => i);
}
const fullEmbedMap = new Map(embeddingIds.map((value, i) => [value, i]));
const embedMap = new Map(fullEmbedMap);
const embedSimilarity: Map<any, number> = new Map();
let score: number = Number.NEGATIVE_INFINITY;
let highScoreId: any | null = null;
for (let i = 0; i < embeddings.length; i++) {
const emb = embeddings[i];
const similarity = similarityFn(queryEmbedding, emb);
embedSimilarity.set(embeddingIds[i], similarity);
if (similarity * threshold > score) {
highScoreId = embeddingIds[i];
score = similarity * threshold;
}
}
const results: [number, any][] = [];
const embeddingLength = embeddings.length;
const similarityTopKCount = similarityTopK || embeddingLength;
while (results.length < Math.min(similarityTopKCount, embeddingLength)) {
results.push([score, highScoreId]);
embedMap.delete(highScoreId);
const recentEmbeddingId = highScoreId;
score = Number.NEGATIVE_INFINITY;
for (const embedId of Array.from(embedMap.keys())) {
const overlapWithRecent = similarityFn(
embeddings[embedMap.get(embedId)!],
embeddings[fullEmbedMap.get(recentEmbeddingId)!],
);
if (
threshold * embedSimilarity.get(embedId)! -
(1 - threshold) * overlapWithRecent >
score
) {
score =
threshold * embedSimilarity.get(embedId)! -
(1 - threshold) * overlapWithRecent;
highScoreId = embedId;
}
}
}
const resultSimilarities = results.map(([s, _]) => s);
const resultIds = results.map(([_, n]) => n);
return [resultSimilarities, resultIds];
}
import { similarity } from "@llamaindex/core/embeddings";
import type { ImageType } from "@llamaindex/core/schema";
import { fs } from "@llamaindex/env";
import { filetypemime } from "magic-bytes.js";
......@@ -17,124 +16,6 @@ export const isIterable = (obj: unknown): obj is Iterable<unknown> => {
return obj != null && typeof obj === "object" && Symbol.iterator in obj;
};
/**
* Get the top K embeddings from a list of embeddings ordered by similarity to the query.
* @param queryEmbedding
* @param embeddings list of embeddings to consider
* @param similarityTopK max number of embeddings to return, default 2
* @param embeddingIds ids of embeddings in the embeddings list
* @param similarityCutoff minimum similarity score
* @returns
*/
// eslint-disable-next-line max-params
export function getTopKEmbeddings(
queryEmbedding: number[],
embeddings: number[][],
similarityTopK: number = 2,
embeddingIds: any[] | null = null,
similarityCutoff: number | null = null,
): [number[], any[]] {
if (embeddingIds == null) {
embeddingIds = Array(embeddings.length).map((_, i) => i);
}
if (embeddingIds.length !== embeddings.length) {
throw new Error(
"getTopKEmbeddings: embeddings and embeddingIds length mismatch",
);
}
const similarities: { similarity: number; id: number }[] = [];
for (let i = 0; i < embeddings.length; i++) {
const sim = similarity(queryEmbedding, embeddings[i]!);
if (similarityCutoff == null || sim > similarityCutoff) {
similarities.push({ similarity: sim, id: embeddingIds[i] });
}
}
similarities.sort((a, b) => b.similarity - a.similarity); // Reverse sort
const resultSimilarities: number[] = [];
const resultIds: any[] = [];
for (let i = 0; i < similarityTopK; i++) {
if (i >= similarities.length) {
break;
}
resultSimilarities.push(similarities[i]!.similarity);
resultIds.push(similarities[i]!.id);
}
return [resultSimilarities, resultIds];
}
// eslint-disable-next-line max-params
export function getTopKMMREmbeddings(
queryEmbedding: number[],
embeddings: number[][],
similarityFn: ((...args: any[]) => number) | null = null,
similarityTopK: number | null = null,
embeddingIds: any[] | null = null,
_similarityCutoff: number | null = null,
mmrThreshold: number | null = null,
): [number[], any[]] {
const threshold = mmrThreshold || 0.5;
similarityFn = similarityFn || similarity;
if (embeddingIds === null || embeddingIds.length === 0) {
embeddingIds = Array.from({ length: embeddings.length }, (_, i) => i);
}
const fullEmbedMap = new Map(embeddingIds.map((value, i) => [value, i]));
const embedMap = new Map(fullEmbedMap);
const embedSimilarity: Map<any, number> = new Map();
let score: number = Number.NEGATIVE_INFINITY;
let highScoreId: any | null = null;
for (let i = 0; i < embeddings.length; i++) {
const emb = embeddings[i];
const similarity = similarityFn(queryEmbedding, emb);
embedSimilarity.set(embeddingIds[i], similarity);
if (similarity * threshold > score) {
highScoreId = embeddingIds[i];
score = similarity * threshold;
}
}
const results: [number, any][] = [];
const embeddingLength = embeddings.length;
const similarityTopKCount = similarityTopK || embeddingLength;
while (results.length < Math.min(similarityTopKCount, embeddingLength)) {
results.push([score, highScoreId]);
embedMap.delete(highScoreId);
const recentEmbeddingId = highScoreId;
score = Number.NEGATIVE_INFINITY;
for (const embedId of Array.from(embedMap.keys())) {
const overlapWithRecent = similarityFn(
embeddings[embedMap.get(embedId)!],
embeddings[fullEmbedMap.get(recentEmbeddingId)!],
);
if (
threshold * embedSimilarity.get(embedId)! -
(1 - threshold) * overlapWithRecent >
score
) {
score =
threshold * embedSimilarity.get(embedId)! -
(1 - threshold) * overlapWithRecent;
highScoreId = embedId;
}
}
}
const resultSimilarities = results.map(([s, _]) => s);
const resultIds = results.map(([_, n]) => n);
return [resultSimilarities, resultIds];
}
async function blobToDataUrl(input: Blob) {
const buffer = Buffer.from(await input.arrayBuffer());
const mimes = filetypemime(buffer);
......
import type { BaseEmbedding } from "@llamaindex/core/embeddings";
import {
getTopKEmbeddings,
getTopKMMREmbeddings,
} from "@llamaindex/core/embeddings";
import { DEFAULT_PERSIST_DIR } from "@llamaindex/core/global";
import type { BaseNode } from "@llamaindex/core/schema";
import { fs, path } from "@llamaindex/env";
import { getTopKEmbeddings, getTopKMMREmbeddings } from "../internal/utils.js";
import { exists } from "../storage/FileSystem.js";
import {
BaseVectorStore,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment