Unverified commit bbc8c878, authored by Thuc Pham, committed by GitHub

fix: prefer using embedding model from vector store (#1708)

parent 4b49428f
Showing changed files with 236 additions and 24 deletions
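In user terms, the fix: when a vector store carries its own embedding model, indexing now embeds with that model instead of falling back to the global `Settings.embedModel`. A minimal sketch of the fixed behavior, assuming only the public `llamaindex` entry point; the model names are illustrative, and the stores' `embedModel` is assigned directly, just as this commit's own test utilities do:

import {
  Document,
  OpenAIEmbedding,
  Settings,
  VectorStoreIndex,
  storageContextFromDefaults,
} from "llamaindex";

// Global default: what indexing used before this fix (and still uses as a fallback).
Settings.embedModel = new OpenAIEmbedding({ model: "text-embedding-3-small" });

const storageContext = await storageContextFromDefaults({ persistDir: "./storage" });

// Attach a different model directly to the vector stores; after this fix,
// VectorStoreIndex.fromDocuments embeds with this model, not Settings.embedModel.
for (const store of Object.values(storageContext.vectorStores)) {
  store.embedModel = new OpenAIEmbedding({ model: "text-embedding-3-large" });
}

await VectorStoreIndex.fromDocuments(
  [new Document({ text: "This needs to be embedded" })],
  { storageContext },
);

The tests below pin this down by spying on two mocked embedding models and asserting that only the store-attached one is called.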
export * from "@llamaindex/openai";
import { Document } from "@llamaindex/core/schema";
-import { Settings } from "llamaindex";
-import { OpenAIEmbedding } from "llamaindex/embeddings/index";
+import { OpenAI, OpenAIEmbedding, Settings } from "llamaindex";
import {
  KeywordExtractor,
  QuestionsAnsweredExtractor,
  SummaryExtractor,
  TitleExtractor,
} from "llamaindex/extractors/index";
-import { OpenAI } from "llamaindex/llm/openai";
import { SentenceSplitter } from "llamaindex/node-parser";
import { afterAll, beforeAll, describe, expect, test, vi } from "vitest";
import {
......
import { describe, expect, test } from "vitest";
// from unittest.mock import patch
import { OpenAI } from "llamaindex/llm/index";
import { OpenAI } from "llamaindex";
import { LLMSingleSelector } from "llamaindex/selectors/index";
import { mocStructuredkLlmGeneration } from "./utility/mockOpenAI.js";
......
@@ -20,13 +20,13 @@ describe("SummaryIndex", () => {
  let storageContext: StorageContext;
  beforeAll(async () => {
-   storageContext = await storageContextFromDefaults({
-     persistDir: testDir,
-   });
    const embedModel = new OpenAIEmbedding();
    mockEmbeddingModel(embedModel);
    Settings.embedModel = embedModel;
+   storageContext = await storageContextFromDefaults({
+     persistDir: testDir,
+   });
  });
  afterAll(() => {
......
@@ -9,7 +9,7 @@ import { DocStoreStrategy } from "llamaindex/ingestion/strategies/index";
import { mkdtemp, rm } from "node:fs/promises";
import { tmpdir } from "node:os";
import { join } from "node:path";
-import { afterAll, beforeAll, describe, expect, test, vi } from "vitest";
+import { afterAll, beforeAll, describe, expect, it, test, vi } from "vitest";
const testDir = await mkdtemp(join(tmpdir(), "test-"));
@@ -24,6 +24,10 @@ describe("VectorStoreIndex", () => {
  ) => Promise<Array<number>>;
  beforeAll(async () => {
+   const embedModel = new OpenAIEmbedding();
+   mockEmbeddingModel(embedModel);
+   Settings.embedModel = embedModel;
    storageContext = await mockStorageContext(testDir);
    testStrategy = async (
      strategy: DocStoreStrategy,
@@ -41,10 +45,6 @@ describe("VectorStoreIndex", () => {
      }
      return entries;
    };
-   const embedModel = new OpenAIEmbedding();
-   mockEmbeddingModel(embedModel);
-   Settings.embedModel = embedModel;
  });
  afterAll(() => {
@@ -65,3 +65,28 @@ describe("VectorStoreIndex", () => {
    await rm(testDir, { recursive: true });
  });
});
+
+describe("[VectorStoreIndex] use embedding model", () => {
+  it("should use embedding model passed in options instead of Settings", async () => {
+    const documents = [new Document({ text: "This needs to be embedded" })];
+    // Create mock embedding models
+    const settingsEmbedModel = new OpenAIEmbedding();
+    const customEmbedModel = new OpenAIEmbedding();
+    // Mock the embedding models using the utility function
+    mockEmbeddingModel(settingsEmbedModel);
+    mockEmbeddingModel(customEmbedModel);
+    // Add spies to track calls
+    const settingsSpy = vi.spyOn(settingsEmbedModel, "getTextEmbeddings");
+    const customSpy = vi.spyOn(customEmbedModel, "getTextEmbeddings");
+    Settings.embedModel = settingsEmbedModel;
+    // set up a custom embedding model on the storage context's vector stores
+    const storageContext = await mockStorageContext(testDir, customEmbedModel);
+    await VectorStoreIndex.fromDocuments(documents, { storageContext });
+    expect(customSpy).toHaveBeenCalled();
+    expect(settingsSpy).not.toHaveBeenCalled();
+  });
+});
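The precedence this test pins down, as a minimal sketch; this is the intended resolution order, not the literal implementation inside `VectorStoreIndex`:

import { BaseEmbedding, Settings } from "llamaindex";

// The embedding model attached to the vector store wins; the global
// Settings.embedModel is only a fallback when the store has none configured.
function resolveEmbedModel(storeEmbedModel?: BaseEmbedding): BaseEmbedding {
  return storeEmbedModel ?? Settings.embedModel;
}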
import type { CallbackManager } from "@llamaindex/core/global";
-import type { LLMChatParamsBase } from "llamaindex";
-import { Settings } from "llamaindex";
-import type { OpenAIEmbedding } from "llamaindex/embeddings/OpenAIEmbedding";
-import { OpenAI } from "llamaindex/llm/openai";
+import type { LLMChatParamsBase, OpenAIEmbedding } from "llamaindex";
+import { OpenAI, Settings } from "llamaindex";
import { vi } from "vitest";
export const DEFAULT_LLM_TEXT_OUTPUT = "MOCK_TOKEN_1-MOCK_TOKEN_2";
......
-import { OpenAIEmbedding, storageContextFromDefaults } from "llamaindex";
+import {
+  BaseEmbedding,
+  OpenAIEmbedding,
+  storageContextFromDefaults,
+} from "llamaindex";
import { mockEmbeddingModel } from "./mockOpenAI.js";
-export async function mockStorageContext(testDir: string) {
+export async function mockStorageContext(
+  testDir: string,
+  embeddingModel?: BaseEmbedding,
+) {
  const storageContext = await storageContextFromDefaults({
    persistDir: testDir,
  });
  for (const store of Object.values(storageContext.vectorStores)) {
-   store.embedModel = new OpenAIEmbedding();
-   mockEmbeddingModel(store.embedModel as OpenAIEmbedding);
+   if (embeddingModel) {
+     // use embeddingModel if it is passed in
+     store.embedModel = embeddingModel;
+   } else {
+     // mock an embedding model for testing
+     store.embedModel = new OpenAIEmbedding();
+     mockEmbeddingModel(store.embedModel as OpenAIEmbedding);
+   }
  }
  return storageContext;
}
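Both call shapes of the updated helper, for reference; the helper's own import path is assumed, everything else mirrors the code above:

import { OpenAIEmbedding } from "llamaindex";
import { mockEmbeddingModel } from "./mockOpenAI.js";
import { mockStorageContext } from "./mockStorageContext.js"; // path assumed

// Default shape: every vector store gets a fresh OpenAIEmbedding that the
// helper mocks itself.
const defaultCtx = await mockStorageContext("/tmp/test-default");

// Custom shape: the caller's model is attached as-is and is not mocked by the
// helper, so callers that need a mock (like the test above) mock it first.
const customModel = new OpenAIEmbedding();
mockEmbeddingModel(customModel);
const customCtx = await mockStorageContext("/tmp/test-custom", customModel);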
{
  "name": "@llamaindex/deepseek",
  "description": "DeepSeek Adapter for LlamaIndex",
  "version": "0.0.1",
  "type": "module",
  "main": "./dist/index.cjs",
  "module": "./dist/index.js",
  "exports": {
    ".": {
      "require": {
        "types": "./dist/index.d.cts",
        "default": "./dist/index.cjs"
      },
      "import": {
        "types": "./dist/index.d.ts",
        "default": "./dist/index.js"
      }
    }
  },
  "files": [
    "dist"
  ],
  "repository": {
    "type": "git",
    "url": "git+https://github.com/run-llama/LlamaIndexTS.git",
    "directory": "packages/providers/deepseek"
  },
  "scripts": {
    "build": "bunchee",
    "dev": "bunchee --watch"
  },
  "devDependencies": {
    "bunchee": "6.3.4"
  },
  "dependencies": {
    "@llamaindex/env": "workspace:*",
    "@llamaindex/openai": "workspace:*"
  }
}
export * from "./llm";
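For context, a hedged sketch of how the new package would be consumed. The `DeepSeekLLM` export name and its options are assumptions inferred from the `./llm` module name and the repo's other OpenAI-compatible adapters; the module's source is not shown in this diff:

import { DeepSeekLLM } from "@llamaindex/deepseek"; // export name assumed
import { Settings } from "llamaindex";

// DeepSeek serves an OpenAI-compatible API, which is why this package only
// depends on @llamaindex/openai and @llamaindex/env.
Settings.llm = new DeepSeekLLM({
  apiKey: process.env.DEEPSEEK_API_KEY, // env var name assumed
  model: "deepseek-chat",
});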
{
  "extends": "../../../tsconfig.json",
  "compilerOptions": {
    "target": "ESNext",
    "module": "ESNext",
    "moduleResolution": "bundler",
    "outDir": "./lib",
    "tsBuildInfoFile": "./lib/.tsbuildinfo"
  },
  "include": ["./src"],
  "references": [
    {
      "path": "../openai/tsconfig.json"
    },
    {
      "path": "../../env/tsconfig.json"
    }
  ]
}
{
  "name": "@llamaindex/fireworks",
  "description": "Fireworks Adapter for LlamaIndex",
  "version": "0.0.1",
  "type": "module",
  "main": "./dist/index.cjs",
  "module": "./dist/index.js",
  "exports": {
    ".": {
      "require": {
        "types": "./dist/index.d.cts",
        "default": "./dist/index.cjs"
      },
      "import": {
        "types": "./dist/index.d.ts",
        "default": "./dist/index.js"
      }
    }
  },
  "files": [
    "dist"
  ],
  "repository": {
    "type": "git",
    "url": "git+https://github.com/run-llama/LlamaIndexTS.git",
    "directory": "packages/providers/fireworks"
  },
  "scripts": {
    "build": "bunchee",
    "dev": "bunchee --watch"
  },
  "devDependencies": {
    "bunchee": "6.3.4"
  },
  "dependencies": {
    "@llamaindex/env": "workspace:*",
    "@llamaindex/openai": "workspace:*"
  }
}
export * from "./embedding";
export * from "./llm";
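Likewise for Fireworks, which ships both an LLM and an embedding module. A hedged usage sketch; the `FireworksLLM` and `FireworksEmbedding` names match the classes that previously lived in the main package, but their sources are not shown in this diff:

import { FireworksEmbedding, FireworksLLM } from "@llamaindex/fireworks";
import { Settings } from "llamaindex";

Settings.llm = new FireworksLLM({ apiKey: process.env.FIREWORKS_API_KEY }); // env var name assumed
Settings.embedModel = new FireworksEmbedding();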
{
  "extends": "../../../tsconfig.json",
  "compilerOptions": {
    "target": "ESNext",
    "module": "ESNext",
    "moduleResolution": "bundler",
    "outDir": "./lib",
    "tsBuildInfoFile": "./lib/.tsbuildinfo"
  },
  "include": ["./src"],
  "references": [
    {
      "path": "../openai/tsconfig.json"
    },
    {
      "path": "../../env/tsconfig.json"
    }
  ]
}
{
  "name": "@llamaindex/jinaai",
  "description": "JinaAI Adapter for LlamaIndex",
  "version": "0.0.1",
  "type": "module",
  "main": "./dist/index.cjs",
  "module": "./dist/index.js",
  "exports": {
    ".": {
      "require": {
        "types": "./dist/index.d.cts",
        "default": "./dist/index.cjs"
      },
      "import": {
        "types": "./dist/index.d.ts",
        "default": "./dist/index.js"
      }
    }
  },
  "files": [
    "dist"
  ],
  "repository": {
    "type": "git",
    "url": "git+https://github.com/run-llama/LlamaIndexTS.git",
    "directory": "packages/providers/jinaai"
  },
  "scripts": {
    "build": "bunchee",
    "dev": "bunchee --watch"
  },
  "devDependencies": {
    "bunchee": "6.3.4"
  },
  "dependencies": {
    "@llamaindex/core": "workspace:*",
    "@llamaindex/env": "workspace:*",
    "@llamaindex/openai": "workspace:*"
  }
}
import { MultiModalEmbedding } from "@llamaindex/core/embeddings";
+import type { ImageType } from "@llamaindex/core/schema";
+import { imageToDataUrl } from "@llamaindex/core/utils";
import { getEnv } from "@llamaindex/env";
-import { imageToDataUrl } from "../internal/utils.js";
-import type { ImageType } from "../Node.js";

function isLocal(url: ImageType): boolean {
  if (url instanceof Blob) return true;
......
export * from "./embedding";
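And for JinaAI: the source above shows the embedding class extends `MultiModalEmbedding` and reads its API key via `getEnv`, so it can embed images as well as text. A hedged sketch, with the `JinaAIEmbedding` export name assumed from the module layout:

import { JinaAIEmbedding } from "@llamaindex/jinaai"; // export name assumed
import { Settings } from "llamaindex";

Settings.embedModel = new JinaAIEmbedding();
// getTextEmbedding is the standard BaseEmbedding entry point.
const vector = await Settings.embedModel.getTextEmbedding("hello world");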
{
  "extends": "../../../tsconfig.json",
  "compilerOptions": {
    "target": "ESNext",
    "module": "ESNext",
    "moduleResolution": "bundler",
    "outDir": "./lib",
    "tsBuildInfoFile": "./lib/.tsbuildinfo"
  },
  "include": ["./src"],
  "references": [
    {
      "path": "../openai/tsconfig.json"
    },
    {
      "path": "../../env/tsconfig.json"
    }
  ]
}