diff --git a/.changeset/four-beers-kick.md b/.changeset/four-beers-kick.md new file mode 100644 index 0000000000000000000000000000000000000000..4e591e3c3512d5863249b94531dde56a557bb089 --- /dev/null +++ b/.changeset/four-beers-kick.md @@ -0,0 +1,12 @@ +--- +"@llamaindex/core": minor +"llamaindex": minor +--- + +refactor: move chat engine & retriever into core. + +This is a breaking change since `stream` option has moved to second parameter. + +- `chat` API in BaseChatEngine has changed, Move `stream` option into second parameter. +- `chatHistory` in BaseChatEngine now returns `ChatMessage[] | Promise<ChatMessage[]>`, instead of `BaseMemory` +- update `retrieve-end` type diff --git a/README.md b/README.md index 227d859122f9b3676a37a6b3d7789723315e66e0..48ba729b118ee049e0d7a223f8ff1430b8cd73f2 100644 --- a/README.md +++ b/README.md @@ -167,11 +167,13 @@ export async function chatWithAgent( // ... adding your tools here ], }); - const responseStream = await agent.chat({ - stream: true, - message: question, - chatHistory: prevMessages, - }); + const responseStream = await agent.chat( + { + message: question, + chatHistory: prevMessages, + }, + true, + ); const uiStream = createStreamableUI(<div>loading...</div>); responseStream .pipeTo( diff --git a/examples/agent/azure_dynamic_session.ts b/examples/agent/azure_dynamic_session.ts index 31d5375c7e223d53c85d529ceb8095a50938779a..72dbb77e639c96a696998e22d5ed0618eabf056b 100644 --- a/examples/agent/azure_dynamic_session.ts +++ b/examples/agent/azure_dynamic_session.ts @@ -42,7 +42,6 @@ async function main() { const response = await agent.chat({ message: "plot a chart of 5 random numbers and save it to /mnt/data/chart.png", - stream: false, }); // Print the response diff --git a/examples/agent/stream_openai_agent.ts b/examples/agent/stream_openai_agent.ts index 4d8d6e8fcbd4cf3bc52f6923b630e2b17cdece09..19dfcca177cf2339f9111d2e9e9a68b74e73f34c 100644 --- a/examples/agent/stream_openai_agent.ts +++ 
b/examples/agent/stream_openai_agent.ts @@ -61,10 +61,12 @@ async function main() { tools: [functionTool, functionTool2], }); - const stream = await agent.chat({ - message: "Divide 16 by 2 then add 20", - stream: true, - }); + const stream = await agent.chat( + { + message: "Divide 16 by 2 then add 20", + }, + true, + ); console.log("Response:"); diff --git a/examples/agent/wiki.ts b/examples/agent/wiki.ts index e4100e9909a97f0b79c5e4d73eef8b0d77379f87..1ec98652b15bc748b6465f6d9b0ed7d22fb0f6bc 100644 --- a/examples/agent/wiki.ts +++ b/examples/agent/wiki.ts @@ -11,10 +11,12 @@ async function main() { }); // Chat with the agent - const response = await agent.chat({ - message: "Who was Goethe?", - stream: true, - }); + const response = await agent.chat( + { + message: "Who was Goethe?", + }, + true, + ); for await (const { delta } of response) { process.stdout.write(delta); diff --git a/examples/anthropic/chat_interactive.ts b/examples/anthropic/chat_interactive.ts index 4565c70e41b70d81313c4f16cd9c16946ec86903..9579cd7a4aa2137626fb1f7227d30c743a90ca60 100644 --- a/examples/anthropic/chat_interactive.ts +++ b/examples/anthropic/chat_interactive.ts @@ -18,14 +18,14 @@ import readline from "node:readline/promises"; }); const chatEngine = new SimpleChatEngine({ llm, - chatHistory, + memory: chatHistory, }); const rl = readline.createInterface({ input, output }); while (true) { const query = await rl.question("User: "); process.stdout.write("Assistant: "); - const stream = await chatEngine.chat({ message: query, stream: true }); + const stream = await chatEngine.chat({ message: query }, true); for await (const chunk of stream) { process.stdout.write(chunk.response); } diff --git a/examples/chatEngine.ts b/examples/chatEngine.ts index addd025bb83cc23c870b99752b3c1d85711b8108..da24e225156e5974150d61b5bc8ffa5b4ce7ce47 100644 --- a/examples/chatEngine.ts +++ b/examples/chatEngine.ts @@ -24,7 +24,7 @@ async function main() { while (true) { const query = await 
rl.question("Query: "); - const stream = await chatEngine.chat({ message: query, stream: true }); + const stream = await chatEngine.chat({ message: query }, true); console.log(); for await (const chunk of stream) { process.stdout.write(chunk.response); diff --git a/examples/chatHistory.ts b/examples/chatHistory.ts index c55c618d69ac5eb64ee20812ef11bb10e67bc68f..14e12a31b02766e9a8d94121102b5cf42927162a 100644 --- a/examples/chatHistory.ts +++ b/examples/chatHistory.ts @@ -24,11 +24,13 @@ async function main() { while (true) { const query = await rl.question("Query: "); - const stream = await chatEngine.chat({ - message: query, - chatHistory, - stream: true, - }); + const stream = await chatEngine.chat( + { + message: query, + chatHistory, + }, + true, + ); if (chatHistory.getLastSummary()) { // Print the summary of the conversation so far that is produced by the SummaryChatHistory console.log(`Summary: ${chatHistory.getLastSummary()?.content}`); diff --git a/examples/cloud/chat.ts b/examples/cloud/chat.ts index d6fdce572140601cb757dc577c799f2d00dabf43..c39224ae08bc8feffdd24cbc9e084c1f8a83735f 100644 --- a/examples/cloud/chat.ts +++ b/examples/cloud/chat.ts @@ -18,7 +18,7 @@ async function main() { while (true) { const query = await rl.question("User: "); - const stream = await chatEngine.chat({ message: query, stream: true }); + const stream = await chatEngine.chat({ message: query }, true); for await (const chunk of stream) { process.stdout.write(chunk.response); } diff --git a/examples/multimodal/context.ts b/examples/multimodal/context.ts index c4a28646cd05ff2be4f73376d50c9d65054403b1..bfe0c7daa92ac4afe15c1ca056d81c25a6f699c0 100644 --- a/examples/multimodal/context.ts +++ b/examples/multimodal/context.ts @@ -1,4 +1,5 @@ // call pnpm tsx multimodal/load.ts first to init the storage +import { extractText } from "@llamaindex/core/utils"; import { ContextChatEngine, NodeWithScore, @@ -25,8 +26,9 @@ Settings.callbackManager.on("retrieve-end", (event) => { const 
textNodes = nodes.filter( (node: NodeWithScore) => node.node.type === ObjectType.TEXT, ); + const text = extractText(query); console.log( - `Retrieved ${textNodes.length} text nodes and ${imageNodes.length} image nodes for query: ${query}`, + `Retrieved ${textNodes.length} text nodes and ${imageNodes.length} image nodes for query: ${text}`, ); }); diff --git a/examples/multimodal/rag.ts b/examples/multimodal/rag.ts index 8ac66ffa8cd6efe83cd287f5b7a263e097d7700a..bbe8a6296768aae7e22af47ba17011e84686c7c9 100644 --- a/examples/multimodal/rag.ts +++ b/examples/multimodal/rag.ts @@ -1,3 +1,4 @@ +import { extractText } from "@llamaindex/core/utils"; import { getResponseSynthesizer, OpenAI, @@ -16,7 +17,8 @@ Settings.llm = new OpenAI({ model: "gpt-4-turbo", maxTokens: 512 }); // Update callbackManager Settings.callbackManager.on("retrieve-end", (event) => { const { nodes, query } = event.detail; - console.log(`Retrieved ${nodes.length} nodes for query: ${query}`); + const text = extractText(query); + console.log(`Retrieved ${nodes.length} nodes for query: ${text}`); }); async function main() { diff --git a/packages/autotool/examples/02_nextjs/actions.ts b/packages/autotool/examples/02_nextjs/actions.ts index 3c5d12a48f49f7363987b2f3d13c6a4fc4e86930..b6738a2300b917ab2a491d5e6110c815c421ce6f 100644 --- a/packages/autotool/examples/02_nextjs/actions.ts +++ b/packages/autotool/examples/02_nextjs/actions.ts @@ -14,10 +14,12 @@ export async function chatWithAI(message: string): Promise<ReactNode> { const uiStream = createStreamableUI(); runWithStreamableUI(uiStream, () => agent - .chat({ - stream: true, - message, - }) + .chat( + { + message, + }, + true, + ) .then(async (responseStream) => { return responseStream.pipeTo( new WritableStream({ diff --git a/packages/core/package.json b/packages/core/package.json index 403fd607363116417e0538c0a2381ab223ad0f51..b7f799d8acb46ec12767d63033cd0f3568e1e043 100644 --- a/packages/core/package.json +++ b/packages/core/package.json @@ 
-199,6 +199,34 @@ "types": "./dist/response-synthesizers/index.d.ts", "default": "./dist/response-synthesizers/index.js" } + }, + "./chat-engine": { + "require": { + "types": "./dist/chat-engine/index.d.cts", + "default": "./dist/chat-engine/index.cjs" + }, + "import": { + "types": "./dist/chat-engine/index.d.ts", + "default": "./dist/chat-engine/index.js" + }, + "default": { + "types": "./dist/chat-engine/index.d.ts", + "default": "./dist/chat-engine/index.js" + } + }, + "./retriever": { + "require": { + "types": "./dist/retriever/index.d.cts", + "default": "./dist/retriever/index.cjs" + }, + "import": { + "types": "./dist/retriever/index.d.ts", + "default": "./dist/retriever/index.js" + }, + "default": { + "types": "./dist/retriever/index.d.ts", + "default": "./dist/retriever/index.js" + } } }, "files": [ diff --git a/packages/core/src/chat-engine/index.ts b/packages/core/src/chat-engine/index.ts new file mode 100644 index 0000000000000000000000000000000000000000..07d7b1223eb6ee53bd0659eac9e79d8efd3e48d2 --- /dev/null +++ b/packages/core/src/chat-engine/index.ts @@ -0,0 +1,28 @@ +import type { ChatMessage, MessageContent } from "../llms"; +import type { BaseMemory } from "../memory"; +import { EngineResponse } from "../schema"; + +export interface ChatEngineParams< + AdditionalMessageOptions extends object = object, +> { + message: MessageContent; + /** + * Optional chat history if you want to customize the chat history. 
+ */ + chatHistory?: + | ChatMessage<AdditionalMessageOptions>[] + | BaseMemory<AdditionalMessageOptions>; +} + +export abstract class BaseChatEngine { + abstract chat( + params: ChatEngineParams, + stream?: false, + ): Promise<EngineResponse>; + abstract chat( + params: ChatEngineParams, + stream: true, + ): Promise<AsyncIterable<EngineResponse>>; + + abstract chatHistory: ChatMessage[] | Promise<ChatMessage[]>; +} diff --git a/packages/core/src/global/settings/callback-manager.ts b/packages/core/src/global/settings/callback-manager.ts index 52ceba77212e18cfde738b9b504669ed2e1944e9..bce2f7237742be561e328927168640bb3075d059 100644 --- a/packages/core/src/global/settings/callback-manager.ts +++ b/packages/core/src/global/settings/callback-manager.ts @@ -11,6 +11,7 @@ import type { SynthesizeEndEvent, SynthesizeStartEvent, } from "../../response-synthesizers"; +import type { RetrieveEndEvent, RetrieveStartEvent } from "../../retriever"; import { TextNode } from "../../schema"; import { EventCaller, getEventCaller } from "../../utils"; import type { UUID } from "../type"; @@ -69,6 +70,8 @@ export interface LlamaIndexEventMaps { "query-end": QueryEndEvent; "synthesize-start": SynthesizeStartEvent; "synthesize-end": SynthesizeEndEvent; + "retrieve-start": RetrieveStartEvent; + "retrieve-end": RetrieveEndEvent; } export class LlamaIndexCustomEvent<T = any> extends CustomEvent<T> { diff --git a/packages/core/src/query-engine/base.ts b/packages/core/src/query-engine/base.ts index 6b9d0b08e8bb1ecdfe4b4c7aef654e71b9f39800..db079b28429ecbe664fbaded13d0fb3c3c6c4e3e 100644 --- a/packages/core/src/query-engine/base.ts +++ b/packages/core/src/query-engine/base.ts @@ -2,7 +2,7 @@ import { randomUUID } from "@llamaindex/env"; import { Settings } from "../global"; import type { MessageContent } from "../llms"; import { PromptMixin } from "../prompts"; -import { EngineResponse } from "../schema"; +import { EngineResponse, type NodeWithScore } from "../schema"; import { 
wrapEventCaller } from "../utils"; /** @@ -28,6 +28,12 @@ export abstract class BaseQueryEngine extends PromptMixin { super(); } + async retrieve(params: QueryType): Promise<NodeWithScore[]> { + throw new Error( + "This query engine does not support retrieve, use query directly", + ); + } + query( strOrQueryBundle: QueryType, stream: true, diff --git a/packages/core/src/retriever/index.ts b/packages/core/src/retriever/index.ts new file mode 100644 index 0000000000000000000000000000000000000000..176a98c82df00e8608cbbb384cbad7dcba09a055 --- /dev/null +++ b/packages/core/src/retriever/index.ts @@ -0,0 +1,112 @@ +import { randomUUID } from "@llamaindex/env"; +import { Settings } from "../global"; +import type { MessageContent } from "../llms"; +import { PromptMixin } from "../prompts"; +import type { QueryBundle, QueryType } from "../query-engine"; +import { BaseNode, IndexNode, type NodeWithScore, ObjectType } from "../schema"; + +export type RetrieveParams = { + query: MessageContent; + preFilters?: unknown; +}; + +export type RetrieveStartEvent = { + id: string; + query: QueryBundle; +}; + +export type RetrieveEndEvent = { + id: string; + query: QueryBundle; + nodes: NodeWithScore[]; +}; + +export abstract class BaseRetriever extends PromptMixin { + objectMap: Map<string, unknown> = new Map(); + + protected _updatePrompts() {} + protected _getPrompts() { + return {}; + } + + protected _getPromptModules() { + return {}; + } + + protected constructor() { + super(); + } + + public async retrieve(params: QueryType): Promise<NodeWithScore[]> { + const cb = Settings.callbackManager; + const queryBundle = typeof params === "string" ? 
{ query: params } : params; + const id = randomUUID(); + cb.dispatchEvent("retrieve-start", { id, query: queryBundle }); + let response = await this._retrieve(queryBundle); + response = await this._handleRecursiveRetrieval(queryBundle, response); + cb.dispatchEvent("retrieve-end", { + id, + query: queryBundle, + nodes: response, + }); + return response; + } + + abstract _retrieve(params: QueryBundle): Promise<NodeWithScore[]>; + + async _handleRecursiveRetrieval( + params: QueryBundle, + nodes: NodeWithScore[], + ): Promise<NodeWithScore[]> { + const retrievedNodes = []; + for (const { node, score = 1.0 } of nodes) { + if (node.type === ObjectType.INDEX) { + const indexNode = node as IndexNode; + const object = this.objectMap.get(indexNode.indexId); + if (object !== undefined) { + retrievedNodes.push( + ...this._retrieveFromObject(object, params, score), + ); + } else { + retrievedNodes.push({ node, score }); + } + } else { + retrievedNodes.push({ node, score }); + } + } + return retrievedNodes; + } + + _retrieveFromObject( + object: unknown, + queryBundle: QueryBundle, + score: number, + ): NodeWithScore[] { + if (object == null) { + throw new TypeError("Object is not retrievable"); + } + if (typeof object !== "object") { + throw new TypeError("Object is not retrievable"); + } + if ("node" in object && object.node instanceof BaseNode) { + return [ + { + node: object.node, + score: + "score" in object && typeof object.score === "number" + ? 
object.score + : score, + }, + ]; + } + if (object instanceof BaseNode) { + return [{ node: object, score }]; + } else { + // todo: support other types + // BaseQueryEngine + // BaseRetriever + // QueryComponent + throw new TypeError("Object is not retrievable"); + } + } +} diff --git a/packages/llamaindex/e2e/examples/cloudflare-worker-agent/src/index.ts b/packages/llamaindex/e2e/examples/cloudflare-worker-agent/src/index.ts index 8a283cc6bb1d734470426ba84d9c39cdc79c618d..ce80b00b2b830c84ff0bcd96b7aec11e9cd7a88b 100644 --- a/packages/llamaindex/e2e/examples/cloudflare-worker-agent/src/index.ts +++ b/packages/llamaindex/e2e/examples/cloudflare-worker-agent/src/index.ts @@ -11,10 +11,12 @@ export default { tools: [], }); console.log(1); - const responseStream = await agent.chat({ - stream: true, - message: "Hello? What is the weather today?", - }); + const responseStream = await agent.chat( + { + message: "Hello? What is the weather today?", + }, + true, + ); console.log(2); const textEncoder = new TextEncoder(); const response = responseStream.pipeThrough<Uint8Array>( diff --git a/packages/llamaindex/e2e/examples/nextjs-agent/src/actions/index.tsx b/packages/llamaindex/e2e/examples/nextjs-agent/src/actions/index.tsx index b71e52a928f4241ee28280fa63bce6e9a612c7d1..f648a8d9d473cddbe0387795506509fb3d1e6c19 100644 --- a/packages/llamaindex/e2e/examples/nextjs-agent/src/actions/index.tsx +++ b/packages/llamaindex/e2e/examples/nextjs-agent/src/actions/index.tsx @@ -10,11 +10,13 @@ export async function chatWithAgent( const agent = new OpenAIAgent({ tools: [], }); - const responseStream = await agent.chat({ - stream: true, - message: question, - chatHistory: prevMessages, - }); + const responseStream = await agent.chat( + { + message: question, + chatHistory: prevMessages, + }, + true, + ); const uiStream = createStreamableUI(<div>loading...</div>); responseStream .pipeTo( diff --git a/packages/llamaindex/e2e/node/openai.e2e.ts b/packages/llamaindex/e2e/node/openai.e2e.ts 
index 4019390854c1059306639e61b99a542a695695c1..eb9918eef5bc946b60a45d8aaff263754ca411e5 100644 --- a/packages/llamaindex/e2e/node/openai.e2e.ts +++ b/packages/llamaindex/e2e/node/openai.e2e.ts @@ -322,10 +322,12 @@ await test("agent stream", async (t) => { tools: [sumNumbersTool, divideNumbersTool], }); - const stream = await agent.chat({ - message: "Divide 16 by 2 then add 20", - stream: true, - }); + const stream = await agent.chat( + { + message: "Divide 16 by 2 then add 20", + }, + true, + ); let message = ""; diff --git a/packages/llamaindex/e2e/node/react.e2e.ts b/packages/llamaindex/e2e/node/react.e2e.ts index c0bb8c46c9b399b1bf32749fefcda4a65a9fb24a..58170a57f4ecdc03069c7df9fd05bf101dea3155 100644 --- a/packages/llamaindex/e2e/node/react.e2e.ts +++ b/packages/llamaindex/e2e/node/react.e2e.ts @@ -20,7 +20,6 @@ await test("react agent", async (t) => { tools: [getWeatherTool], }); const response = await agent.chat({ - stream: false, message: "What is the weather like in San Francisco?", }); @@ -35,10 +34,12 @@ await test("react agent stream", async (t) => { tools: [getWeatherTool], }); - const stream = await agent.chat({ - stream: true, - message: "What is the weather like in San Francisco?", - }); + const stream = await agent.chat( + { + message: "What is the weather like in San Francisco?", + }, + true, + ); let content = ""; for await (const response of stream) { diff --git a/packages/llamaindex/src/Retriever.ts b/packages/llamaindex/src/Retriever.ts deleted file mode 100644 index b7ef4cf0462b23ec9674e865d2e1450780aa1a9d..0000000000000000000000000000000000000000 --- a/packages/llamaindex/src/Retriever.ts +++ /dev/null @@ -1,20 +0,0 @@ -import type { NodeWithScore } from "@llamaindex/core/schema"; -import type { ServiceContext } from "./ServiceContext.js"; -import type { MessageContent } from "./index.edge.js"; - -export type RetrieveParams = { - query: MessageContent; - preFilters?: unknown; -}; - -/** - * Retrievers retrieve the nodes that most closely 
match our query in similarity. - */ -export interface BaseRetriever { - retrieve(params: RetrieveParams): Promise<NodeWithScore[]>; - - /** - * @deprecated to be deprecated soon - */ - serviceContext?: ServiceContext | undefined; -} diff --git a/packages/llamaindex/src/agent/anthropic.ts b/packages/llamaindex/src/agent/anthropic.ts index 8f17b360d30e8196709edc2e6d2cbc51bfc1dad2..e8c827b685962977aca70164c1bd34995b885288 100644 --- a/packages/llamaindex/src/agent/anthropic.ts +++ b/packages/llamaindex/src/agent/anthropic.ts @@ -1,9 +1,6 @@ +import type { ChatEngineParams } from "@llamaindex/core/chat-engine"; +import type { EngineResponse } from "@llamaindex/core/schema"; import { Settings } from "../Settings.js"; -import type { - ChatEngineParamsNonStreaming, - ChatEngineParamsStreaming, - EngineResponse, -} from "../index.edge.js"; import { Anthropic } from "../llm/anthropic.js"; import { LLMAgent, LLMAgentWorker, type LLMAgentParams } from "./llm.js"; @@ -24,12 +21,10 @@ export class AnthropicAgent extends LLMAgent { }); } - async chat(params: ChatEngineParamsNonStreaming): Promise<EngineResponse>; - async chat(params: ChatEngineParamsStreaming): Promise<never>; - override async chat( - params: ChatEngineParamsNonStreaming | ChatEngineParamsStreaming, - ) { - if (params.stream) { + async chat(params: ChatEngineParams, stream?: false): Promise<EngineResponse>; + async chat(params: ChatEngineParams, stream: true): Promise<never>; + override async chat(params: ChatEngineParams, stream?: boolean) { + if (stream) { // Anthropic does support this, but looks like it's not supported in the LITS LLM throw new Error("Anthropic does not support streaming"); } diff --git a/packages/llamaindex/src/agent/base.ts b/packages/llamaindex/src/agent/base.ts index 2715b4ffebf357876ac3a421cf23b91e6fcfbed2..8141ceac42178414d2472d9334d02172a836664a 100644 --- a/packages/llamaindex/src/agent/base.ts +++ b/packages/llamaindex/src/agent/base.ts @@ -1,3 +1,7 @@ +import { + BaseChatEngine, + 
type ChatEngineParams, +} from "@llamaindex/core/chat-engine"; import type { BaseToolWithCall, ChatMessage, @@ -10,11 +14,6 @@ import { EngineResponse } from "@llamaindex/core/schema"; import { wrapEventCaller } from "@llamaindex/core/utils"; import { randomUUID } from "@llamaindex/env"; import { Settings } from "../Settings.js"; -import { - type ChatEngine, - type ChatEngineParamsNonStreaming, - type ChatEngineParamsStreaming, -} from "../engines/chat/index.js"; import { consoleLogger, emptyLogger } from "../internal/logger.js"; import { isReadableStream } from "../internal/utils.js"; import { ObjectRetriever } from "../objects/index.js"; @@ -207,8 +206,7 @@ export abstract class AgentRunner< > ? AdditionalMessageOptions : never, -> implements ChatEngine -{ +> extends BaseChatEngine { readonly #llm: AI; readonly #tools: | BaseToolWithCall[] @@ -259,6 +257,7 @@ export abstract class AgentRunner< protected constructor( params: AgentRunnerParams<AI, Store, AdditionalMessageOptions>, ) { + super(); const { llm, chatHistory, systemPrompt, runner, tools, verbose } = params; this.#llm = llm; this.#chatHistory = chatHistory; @@ -345,13 +344,15 @@ export abstract class AgentRunner< }); } - async chat(params: ChatEngineParamsNonStreaming): Promise<EngineResponse>; + async chat(params: ChatEngineParams, stream?: false): Promise<EngineResponse>; async chat( - params: ChatEngineParamsStreaming, + params: ChatEngineParams, + stream: true, ): Promise<ReadableStream<EngineResponse>>; @wrapEventCaller async chat( - params: ChatEngineParamsNonStreaming | ChatEngineParamsStreaming, + params: ChatEngineParams, + stream?: boolean, ): Promise<EngineResponse | ReadableStream<EngineResponse>> { let chatHistory: ChatMessage<AdditionalMessageOptions>[] = []; @@ -363,12 +364,7 @@ export abstract class AgentRunner< params.chatHistory as ChatMessage<AdditionalMessageOptions>[]; } - const task = this.createTask( - params.message, - !!params.stream, - false, - chatHistory, - ); + const task = 
this.createTask(params.message, !!stream, false, chatHistory); for await (const stepOutput of task) { // update chat history for each round this.#chatHistory = [...stepOutput.taskStep.context.store.messages]; diff --git a/packages/llamaindex/src/cloud/LlamaCloudIndex.ts b/packages/llamaindex/src/cloud/LlamaCloudIndex.ts index 58018917440d44b0df16c51dae7042386fdc1eb7..e3509b2968c26bdbc2a99a2c18c2401933a9c0ee 100644 --- a/packages/llamaindex/src/cloud/LlamaCloudIndex.ts +++ b/packages/llamaindex/src/cloud/LlamaCloudIndex.ts @@ -1,7 +1,6 @@ import type { BaseQueryEngine } from "@llamaindex/core/query-engine"; import type { BaseSynthesizer } from "@llamaindex/core/response-synthesizers"; import type { Document, TransformComponent } from "@llamaindex/core/schema"; -import type { BaseRetriever } from "../Retriever.js"; import { RetrieverQueryEngine } from "../engines/query/RetrieverQueryEngine.js"; import type { BaseNodePostprocessor } from "../postprocessors/types.js"; import type { CloudRetrieveParams } from "./LlamaCloudRetriever.js"; @@ -12,6 +11,7 @@ import { getAppBaseUrl, getProjectId, initService } from "./utils.js"; import { PipelinesService, ProjectsService } from "@llamaindex/cloud/api"; import { SentenceSplitter } from "@llamaindex/core/node-parser"; +import type { BaseRetriever } from "@llamaindex/core/retriever"; import { getEnv } from "@llamaindex/env"; import { OpenAIEmbedding } from "@llamaindex/openai"; import { Settings } from "../Settings.js"; diff --git a/packages/llamaindex/src/cloud/LlamaCloudRetriever.ts b/packages/llamaindex/src/cloud/LlamaCloudRetriever.ts index 3d41ae76c8758557601b3cfe5c8e52e5eeec592b..ebdf0b0bb8767d11df8fdd318da53e5b8019cdda 100644 --- a/packages/llamaindex/src/cloud/LlamaCloudRetriever.ts +++ b/packages/llamaindex/src/cloud/LlamaCloudRetriever.ts @@ -4,11 +4,12 @@ import { type RetrievalParams, type TextNodeWithScore, } from "@llamaindex/cloud/api"; -import { DEFAULT_PROJECT_NAME, Settings } from "@llamaindex/core/global"; 
+import { DEFAULT_PROJECT_NAME } from "@llamaindex/core/global"; +import type { QueryBundle } from "@llamaindex/core/query-engine"; +import { BaseRetriever } from "@llamaindex/core/retriever"; import type { NodeWithScore } from "@llamaindex/core/schema"; import { jsonToNode, ObjectType } from "@llamaindex/core/schema"; -import { extractText, wrapEventCaller } from "@llamaindex/core/utils"; -import type { BaseRetriever, RetrieveParams } from "../Retriever.js"; +import { extractText } from "@llamaindex/core/utils"; import type { ClientParams, CloudConstructorParams } from "./type.js"; import { getProjectId, initService } from "./utils.js"; @@ -17,7 +18,7 @@ export type CloudRetrieveParams = Omit< "query" | "search_filters" | "dense_similarity_top_k" > & { similarityTopK?: number; filters?: MetadataFilters }; -export class LlamaCloudRetriever implements BaseRetriever { +export class LlamaCloudRetriever extends BaseRetriever { clientParams: ClientParams; retrieveParams: CloudRetrieveParams; organizationId?: string; @@ -42,6 +43,7 @@ export class LlamaCloudRetriever implements BaseRetriever { } constructor(params: CloudConstructorParams & CloudRetrieveParams) { + super(); this.clientParams = { apiKey: params.apiKey, baseUrl: params.baseUrl }; initService(this.clientParams); this.retrieveParams = params; @@ -54,11 +56,7 @@ export class LlamaCloudRetriever implements BaseRetriever { } } - @wrapEventCaller - async retrieve({ - query, - preFilters, - }: RetrieveParams): Promise<NodeWithScore[]> { + async _retrieve(query: QueryBundle): Promise<NodeWithScore[]> { const { data: pipelines } = await PipelinesService.searchPipelinesApiV1PipelinesGet({ query: { @@ -97,19 +95,11 @@ export class LlamaCloudRetriever implements BaseRetriever { body: { ...this.retrieveParams, query: extractText(query), - search_filters: - this.retrieveParams.filters ?? 
(preFilters as MetadataFilters), + search_filters: this.retrieveParams.filters as MetadataFilters, dense_similarity_top_k: this.retrieveParams.similarityTopK!, }, }); - const nodesWithScores = this.resultNodesToNodeWithScore( - results.retrieval_nodes, - ); - Settings.callbackManager.dispatchEvent("retrieve-end", { - query, - nodes: nodesWithScores, - }); - return nodesWithScores; + return this.resultNodesToNodeWithScore(results.retrieval_nodes); } } diff --git a/packages/llamaindex/src/engines/chat/CondenseQuestionChatEngine.ts b/packages/llamaindex/src/engines/chat/CondenseQuestionChatEngine.ts index 2be3eb548a5595fe61772628dfffc94612a959a1..4ebff2f40e16d3160ce7818ad0ade71ef0772baf 100644 --- a/packages/llamaindex/src/engines/chat/CondenseQuestionChatEngine.ts +++ b/packages/llamaindex/src/engines/chat/CondenseQuestionChatEngine.ts @@ -1,10 +1,13 @@ +import { + BaseChatEngine, + type ChatEngineParams, +} from "@llamaindex/core/chat-engine"; import type { ChatMessage, LLM } from "@llamaindex/core/llms"; import { BaseMemory, ChatMemoryBuffer } from "@llamaindex/core/memory"; import { type CondenseQuestionPrompt, defaultCondenseQuestionPrompt, type ModuleRecord, - PromptMixin, } from "@llamaindex/core/prompts"; import type { BaseQueryEngine } from "@llamaindex/core/query-engine"; import type { EngineResponse } from "@llamaindex/core/schema"; @@ -16,11 +19,6 @@ import { } from "@llamaindex/core/utils"; import type { ServiceContext } from "../../ServiceContext.js"; import { llmFromSettingsOrContext } from "../../Settings.js"; -import type { - ChatEngine, - ChatEngineParamsNonStreaming, - ChatEngineParamsStreaming, -} from "./types.js"; /** * CondenseQuestionChatEngine is used in conjunction with a Index (for example VectorStoreIndex). @@ -32,16 +30,16 @@ import type { * underlying data. It performs less well when the chat messages are not questions about the * data, or are very referential to previous context. 
*/ - -export class CondenseQuestionChatEngine - extends PromptMixin - implements ChatEngine -{ +export class CondenseQuestionChatEngine extends BaseChatEngine { queryEngine: BaseQueryEngine; - chatHistory: BaseMemory; + memory: BaseMemory; llm: LLM; condenseMessagePrompt: CondenseQuestionPrompt; + get chatHistory() { + return this.memory.getMessages(); + } + constructor(init: { queryEngine: BaseQueryEngine; chatHistory: ChatMessage[]; @@ -51,7 +49,7 @@ export class CondenseQuestionChatEngine super(); this.queryEngine = init.queryEngine; - this.chatHistory = new ChatMemoryBuffer({ + this.memory = new ChatMemoryBuffer({ chatHistory: init?.chatHistory, }); this.llm = llmFromSettingsOrContext(init?.serviceContext); @@ -88,15 +86,17 @@ export class CondenseQuestionChatEngine }); } + chat(params: ChatEngineParams, stream?: false): Promise<EngineResponse>; chat( - params: ChatEngineParamsStreaming, + params: ChatEngineParams, + stream: true, ): Promise<AsyncIterable<EngineResponse>>; - chat(params: ChatEngineParamsNonStreaming): Promise<EngineResponse>; @wrapEventCaller async chat( - params: ChatEngineParamsStreaming | ChatEngineParamsNonStreaming, + params: ChatEngineParams, + stream = false, ): Promise<EngineResponse | AsyncIterable<EngineResponse>> { - const { message, stream } = params; + const { message } = params; const chatHistory = params.chatHistory ? new ChatMemoryBuffer({ chatHistory: @@ -104,7 +104,7 @@ export class CondenseQuestionChatEngine ? 
await params.chatHistory.getMessages() : params.chatHistory, }) - : this.chatHistory; + : this.memory; const condensedQuestion = ( await this.condenseQuestion(chatHistory, extractText(message)) @@ -140,6 +140,6 @@ export class CondenseQuestionChatEngine } reset() { - this.chatHistory.reset(); + this.memory.reset(); } } diff --git a/packages/llamaindex/src/engines/chat/ContextChatEngine.ts b/packages/llamaindex/src/engines/chat/ContextChatEngine.ts index 23d56a00a324fb11d3853d500e8701a9b81b5229..26baddf9c9b6ae26633e7a92bdce90a3d252784a 100644 --- a/packages/llamaindex/src/engines/chat/ContextChatEngine.ts +++ b/packages/llamaindex/src/engines/chat/ContextChatEngine.ts @@ -1,3 +1,7 @@ +import type { + BaseChatEngine, + ChatEngineParams, +} from "@llamaindex/core/chat-engine"; import type { ChatMessage, LLM, @@ -11,6 +15,7 @@ import { PromptMixin, type PromptsRecord, } from "@llamaindex/core/prompts"; +import type { BaseRetriever } from "@llamaindex/core/retriever"; import { EngineResponse, MetadataMode } from "@llamaindex/core/schema"; import { extractText, @@ -18,27 +23,25 @@ import { streamReducer, wrapEventCaller, } from "@llamaindex/core/utils"; -import type { BaseRetriever } from "../../Retriever.js"; import { Settings } from "../../Settings.js"; import type { BaseNodePostprocessor } from "../../postprocessors/index.js"; import { DefaultContextGenerator } from "./DefaultContextGenerator.js"; -import type { - ChatEngine, - ChatEngineParamsNonStreaming, - ChatEngineParamsStreaming, - ContextGenerator, -} from "./types.js"; +import type { ContextGenerator } from "./types.js"; /** * ContextChatEngine uses the Index to get the appropriate context for each query. 
* The context is stored in the system prompt, and the chat history is preserved, ideally allowing the appropriate context to be surfaced for each query. */ -export class ContextChatEngine extends PromptMixin implements ChatEngine { +export class ContextChatEngine extends PromptMixin implements BaseChatEngine { chatModel: LLM; - chatHistory: BaseMemory; + memory: BaseMemory; contextGenerator: ContextGenerator & PromptMixin; systemPrompt?: string | undefined; + get chatHistory() { + return this.memory.getMessages(); + } + constructor(init: { retriever: BaseRetriever; chatModel?: LLM | undefined; @@ -50,7 +53,7 @@ export class ContextChatEngine extends PromptMixin implements ChatEngine { }) { super(); this.chatModel = init.chatModel ?? Settings.llm; - this.chatHistory = new ChatMemoryBuffer({ chatHistory: init?.chatHistory }); + this.memory = new ChatMemoryBuffer({ chatHistory: init?.chatHistory }); this.contextGenerator = new DefaultContextGenerator({ retriever: init.retriever, contextSystemPrompt: init?.contextSystemPrompt, @@ -79,15 +82,17 @@ export class ContextChatEngine extends PromptMixin implements ChatEngine { }; } + chat(params: ChatEngineParams, stream?: false): Promise<EngineResponse>; chat( - params: ChatEngineParamsStreaming, + params: ChatEngineParams, + stream: true, ): Promise<AsyncIterable<EngineResponse>>; - chat(params: ChatEngineParamsNonStreaming): Promise<EngineResponse>; @wrapEventCaller async chat( - params: ChatEngineParamsStreaming | ChatEngineParamsNonStreaming, + params: ChatEngineParams, + stream = false, ): Promise<EngineResponse | AsyncIterable<EngineResponse>> { - const { message, stream } = params; + const { message } = params; const chatHistory = params.chatHistory ? 
new ChatMemoryBuffer({ chatHistory: @@ -95,7 +100,7 @@ export class ContextChatEngine extends PromptMixin implements ChatEngine { ? await params.chatHistory.getMessages() : params.chatHistory, }) - : this.chatHistory; + : this.memory; const requestMessages = await this.prepareRequestMessages( message, chatHistory, @@ -125,7 +130,7 @@ export class ContextChatEngine extends PromptMixin implements ChatEngine { } reset() { - this.chatHistory.reset(); + this.memory.reset(); } private async prepareRequestMessages( diff --git a/packages/llamaindex/src/engines/chat/DefaultContextGenerator.ts b/packages/llamaindex/src/engines/chat/DefaultContextGenerator.ts index 976400556d26b5bd5a953dccf5132b9a6be0c6e3..c54e7d03e4c94d275e37af0c02d8821eb451eabc 100644 --- a/packages/llamaindex/src/engines/chat/DefaultContextGenerator.ts +++ b/packages/llamaindex/src/engines/chat/DefaultContextGenerator.ts @@ -5,10 +5,10 @@ import { type ModuleRecord, PromptMixin, } from "@llamaindex/core/prompts"; +import type { BaseRetriever } from "@llamaindex/core/retriever"; import { MetadataMode, type NodeWithScore } from "@llamaindex/core/schema"; import { createMessageContent } from "@llamaindex/core/utils"; import type { BaseNodePostprocessor } from "../../postprocessors/index.js"; -import type { BaseRetriever } from "../../Retriever.js"; import type { Context, ContextGenerator } from "./types.js"; export class DefaultContextGenerator diff --git a/packages/llamaindex/src/engines/chat/SimpleChatEngine.ts b/packages/llamaindex/src/engines/chat/SimpleChatEngine.ts index 5fba0250144edac7e32a1a5daa9baf8efde7049b..d3dc9f5d0166041cbd2a6cb5e20691411c82ccad 100644 --- a/packages/llamaindex/src/engines/chat/SimpleChatEngine.ts +++ b/packages/llamaindex/src/engines/chat/SimpleChatEngine.ts @@ -1,3 +1,7 @@ +import type { + BaseChatEngine, + ChatEngineParams, +} from "@llamaindex/core/chat-engine"; import type { LLM } from "@llamaindex/core/llms"; import { BaseMemory, ChatMemoryBuffer } from 
"@llamaindex/core/memory"; import { EngineResponse } from "@llamaindex/core/schema"; @@ -7,34 +11,35 @@ import { wrapEventCaller, } from "@llamaindex/core/utils"; import { Settings } from "../../Settings.js"; -import type { - ChatEngine, - ChatEngineParamsNonStreaming, - ChatEngineParamsStreaming, -} from "./types.js"; /** * SimpleChatEngine is the simplest possible chat engine. Useful for using your own custom prompts. */ -export class SimpleChatEngine implements ChatEngine { - chatHistory: BaseMemory; +export class SimpleChatEngine implements BaseChatEngine { + memory: BaseMemory; llm: LLM; + get chatHistory() { + return this.memory.getMessages(); + } + constructor(init?: Partial<SimpleChatEngine>) { - this.chatHistory = init?.chatHistory ?? new ChatMemoryBuffer(); + this.memory = init?.memory ?? new ChatMemoryBuffer(); this.llm = init?.llm ?? Settings.llm; } + chat(params: ChatEngineParams, stream?: false): Promise<EngineResponse>; chat( - params: ChatEngineParamsStreaming, + params: ChatEngineParams, + stream: true, ): Promise<AsyncIterable<EngineResponse>>; - chat(params: ChatEngineParamsNonStreaming): Promise<EngineResponse>; @wrapEventCaller async chat( - params: ChatEngineParamsStreaming | ChatEngineParamsNonStreaming, + params: ChatEngineParams, + stream = false, ): Promise<EngineResponse | AsyncIterable<EngineResponse>> { - const { message, stream } = params; + const { message } = params; const chatHistory = params.chatHistory ? new ChatMemoryBuffer({ @@ -43,7 +48,7 @@ export class SimpleChatEngine implements ChatEngine { ? 
await params.chatHistory.getMessages() : params.chatHistory, }) - : this.chatHistory; + : this.memory; chatHistory.put({ content: message, role: "user" }); if (stream) { @@ -73,6 +78,6 @@ export class SimpleChatEngine implements ChatEngine { } reset() { - this.chatHistory.reset(); + this.memory.reset(); } } diff --git a/packages/llamaindex/src/engines/chat/types.ts b/packages/llamaindex/src/engines/chat/types.ts index 9b3b18c0bf8bc804b6a6a53a0ddfd09e1a67f3eb..2c5a811d1cfb619fb9a54b58223858bf4f2d4e8e 100644 --- a/packages/llamaindex/src/engines/chat/types.ts +++ b/packages/llamaindex/src/engines/chat/types.ts @@ -1,58 +1,10 @@ -import type { ChatMessage, MessageContent } from "@llamaindex/core/llms"; -import type { BaseMemory } from "@llamaindex/core/memory"; -import { EngineResponse, type NodeWithScore } from "@llamaindex/core/schema"; - -/** - * Represents the base parameters for ChatEngine. - */ -export interface ChatEngineParamsBase { - message: MessageContent; - /** - * Optional chat history if you want to customize the chat history. - */ - chatHistory?: ChatMessage[] | BaseMemory; - /** - * Optional flag to enable verbose mode. - * @default false - */ - verbose?: boolean; -} - -export interface ChatEngineParamsStreaming extends ChatEngineParamsBase { - stream: true; -} - -export interface ChatEngineParamsNonStreaming extends ChatEngineParamsBase { - stream?: false | null; -} - -/** - * A ChatEngine is used to handle back and forth chats between the application and the LLM. - */ -export interface ChatEngine< - // synchronous response - R = EngineResponse, - // asynchronous response - AR extends AsyncIterable<unknown> = AsyncIterable<R>, -> { - /** - * Send message along with the class's current chat history to the LLM. - * @param params - */ - chat(params: ChatEngineParamsStreaming): Promise<AR>; - chat(params: ChatEngineParamsNonStreaming): Promise<R>; - - /** - * Resets the chat history so that it's empty. 
- */ - reset(): void; -} +import type { ChatMessage } from "@llamaindex/core/llms"; +import type { NodeWithScore } from "@llamaindex/core/schema"; export interface Context { message: ChatMessage; nodes: NodeWithScore[]; } - /** * A ContextGenerator is used to generate a context based on a message's text content */ diff --git a/packages/llamaindex/src/engines/query/RetrieverQueryEngine.ts b/packages/llamaindex/src/engines/query/RetrieverQueryEngine.ts index aebe09dd6b0fc12ab1b0ea112ac9a90aac855723..ab1906e0751e1834c7e026cbfe335db1637671fa 100644 --- a/packages/llamaindex/src/engines/query/RetrieverQueryEngine.ts +++ b/packages/llamaindex/src/engines/query/RetrieverQueryEngine.ts @@ -1,10 +1,11 @@ -import { BaseQueryEngine } from "@llamaindex/core/query-engine"; +import type { MessageContent } from "@llamaindex/core/llms"; +import { BaseQueryEngine, type QueryType } from "@llamaindex/core/query-engine"; import type { BaseSynthesizer } from "@llamaindex/core/response-synthesizers"; import { getResponseSynthesizer } from "@llamaindex/core/response-synthesizers"; +import { BaseRetriever } from "@llamaindex/core/retriever"; import { type NodeWithScore } from "@llamaindex/core/schema"; import { extractText } from "@llamaindex/core/utils"; import type { BaseNodePostprocessor } from "../../postprocessors/index.js"; -import type { BaseRetriever } from "../../Retriever.js"; /** * A query engine that uses a retriever to query an index and then synthesizes the response. 
@@ -67,7 +68,10 @@ export class RetrieverQueryEngine extends BaseQueryEngine { }; } - private async applyNodePostprocessors(nodes: NodeWithScore[], query: string) { + private async applyNodePostprocessors( + nodes: NodeWithScore[], + query: MessageContent, + ) { let nodesWithScore = nodes; for (const postprocessor of this.nodePostprocessors) { @@ -80,12 +84,10 @@ export class RetrieverQueryEngine extends BaseQueryEngine { return nodesWithScore; } - private async retrieve(query: string) { - const nodes = await this.retriever.retrieve({ - query, - preFilters: this.preFilters, - }); + override async retrieve(query: QueryType) { + const nodes = await this.retriever.retrieve(query); - return await this.applyNodePostprocessors(nodes, query); + const messageContent = typeof query === "string" ? query : query.query; + return await this.applyNodePostprocessors(nodes, messageContent); } } diff --git a/packages/llamaindex/src/index.edge.ts b/packages/llamaindex/src/index.edge.ts index f37f249a2e0729e304360adf49c7fd4a8dc91381..a2b81943026a6bfa13d7cd741708ddacd8e99ae7 100644 --- a/packages/llamaindex/src/index.edge.ts +++ b/packages/llamaindex/src/index.edge.ts @@ -1,6 +1,6 @@ import type { AgentEndEvent, AgentStartEvent } from "./agent/types.js"; -import type { RetrievalEndEvent, RetrievalStartEvent } from "./llm/types.js"; +export * from "@llamaindex/core/chat-engine"; export { CallbackManager, DEFAULT_BASE_URL, @@ -35,12 +35,11 @@ export * from "@llamaindex/core/llms"; export * from "@llamaindex/core/prompts"; export * from "@llamaindex/core/query-engine"; export * from "@llamaindex/core/response-synthesizers"; +export * from "@llamaindex/core/retriever"; export * from "@llamaindex/core/schema"; declare module "@llamaindex/core/global" { export interface LlamaIndexEventMaps { - "retrieve-start": RetrievalStartEvent; - "retrieve-end": RetrievalEndEvent; // agent events "agent-start": AgentStartEvent; "agent-end": AgentEndEvent; @@ -66,7 +65,6 @@ export * from 
"./objects/index.js"; export * from "./OutputParser.js"; export * from "./postprocessors/index.js"; export * from "./QuestionGenerator.js"; -export * from "./Retriever.js"; export * from "./selectors/index.js"; export * from "./ServiceContext.js"; export { Settings } from "./Settings.js"; diff --git a/packages/llamaindex/src/indices/BaseIndex.ts b/packages/llamaindex/src/indices/BaseIndex.ts index 3d5d55c2cf7c650121114ea55eff3cf4780744f2..c5beb5d47383810e666a84de58bb769941ed00b8 100644 --- a/packages/llamaindex/src/indices/BaseIndex.ts +++ b/packages/llamaindex/src/indices/BaseIndex.ts @@ -1,7 +1,7 @@ import type { BaseQueryEngine } from "@llamaindex/core/query-engine"; import type { BaseSynthesizer } from "@llamaindex/core/response-synthesizers"; +import type { BaseRetriever } from "@llamaindex/core/retriever"; import type { BaseNode, Document } from "@llamaindex/core/schema"; -import type { BaseRetriever } from "../Retriever.js"; import type { ServiceContext } from "../ServiceContext.js"; import { nodeParserFromSettingsOrContext } from "../Settings.js"; import { runTransformations } from "../ingestion/IngestionPipeline.js"; diff --git a/packages/llamaindex/src/indices/keyword/index.ts b/packages/llamaindex/src/indices/keyword/index.ts index 6b326317acf223bb97a6a994df6fe49663573175..911850616040b7bbdaf65814007606a4dee73a8c 100644 --- a/packages/llamaindex/src/indices/keyword/index.ts +++ b/packages/llamaindex/src/indices/keyword/index.ts @@ -5,7 +5,6 @@ import type { NodeWithScore, } from "@llamaindex/core/schema"; import { MetadataMode } from "@llamaindex/core/schema"; -import type { BaseRetriever, RetrieveParams } from "../../Retriever.js"; import type { ServiceContext } from "../../ServiceContext.js"; import { serviceContextFromDefaults } from "../../ServiceContext.js"; import { RetrieverQueryEngine } from "../../engines/query/index.js"; @@ -29,7 +28,11 @@ import { type KeywordExtractPrompt, type QueryKeywordExtractPrompt, } from "@llamaindex/core/prompts"; 
-import type { BaseQueryEngine } from "@llamaindex/core/query-engine"; +import type { + BaseQueryEngine, + QueryBundle, +} from "@llamaindex/core/query-engine"; +import { BaseRetriever } from "@llamaindex/core/retriever"; import { extractText } from "@llamaindex/core/utils"; import { llmFromSettingsOrContext } from "../../Settings.js"; @@ -48,7 +51,7 @@ export enum KeywordTableRetrieverMode { } // Base Keyword Table Retriever -abstract class BaseKeywordTableRetriever implements BaseRetriever { +abstract class BaseKeywordTableRetriever extends BaseRetriever { protected index: KeywordTableIndex; protected indexStruct: KeywordTable; protected docstore: BaseDocumentStore; @@ -72,6 +75,7 @@ abstract class BaseKeywordTableRetriever implements BaseRetriever { maxKeywordsPerQuery: number; numChunksPerQuery: number; }) { + super(); this.index = index; this.indexStruct = index.indexStruct; this.docstore = index.docStore; @@ -87,7 +91,7 @@ abstract class BaseKeywordTableRetriever implements BaseRetriever { abstract getKeywords(query: string): Promise<string[]>; - async retrieve({ query }: RetrieveParams): Promise<NodeWithScore[]> { + async _retrieve(query: QueryBundle): Promise<NodeWithScore[]> { const keywords = await this.getKeywords(extractText(query)); const chunkIndicesCount: { [key: string]: number } = {}; const filteredKeywords = keywords.filter((keyword) => diff --git a/packages/llamaindex/src/indices/summary/index.ts b/packages/llamaindex/src/indices/summary/index.ts index 375af4a6a2ca875fa632d96c78a9ec1b9497ed85..c449a129790835132175dfd92708b58a60049e63 100644 --- a/packages/llamaindex/src/indices/summary/index.ts +++ b/packages/llamaindex/src/indices/summary/index.ts @@ -2,16 +2,17 @@ import { type ChoiceSelectPrompt, defaultChoiceSelectPrompt, } from "@llamaindex/core/prompts"; +import type { QueryBundle } from "@llamaindex/core/query-engine"; import type { BaseSynthesizer } from "@llamaindex/core/response-synthesizers"; import { getResponseSynthesizer } from 
"@llamaindex/core/response-synthesizers"; +import { BaseRetriever } from "@llamaindex/core/retriever"; import type { BaseNode, Document, NodeWithScore, } from "@llamaindex/core/schema"; -import { extractText, wrapEventCaller } from "@llamaindex/core/utils"; +import { extractText } from "@llamaindex/core/utils"; import _ from "lodash"; -import type { BaseRetriever, RetrieveParams } from "../../Retriever.js"; import type { ServiceContext } from "../../ServiceContext.js"; import { llmFromSettingsOrContext, @@ -279,15 +280,15 @@ export type ListRetrieverMode = SummaryRetrieverMode; /** * Simple retriever for SummaryIndex that returns all nodes */ -export class SummaryIndexRetriever implements BaseRetriever { +export class SummaryIndexRetriever extends BaseRetriever { index: SummaryIndex; constructor(index: SummaryIndex) { + super(); this.index = index; } - @wrapEventCaller - async retrieve({ query }: RetrieveParams): Promise<NodeWithScore[]> { + async _retrieve(queryBundle: QueryBundle): Promise<NodeWithScore[]> { const nodeIds = this.index.indexStruct.nodes; const nodes = await this.index.docStore.getNodes(nodeIds); return nodes.map((node) => ({ @@ -300,7 +301,7 @@ export class SummaryIndexRetriever implements BaseRetriever { /** * LLM retriever for SummaryIndex which lets you select the most relevant chunks. 
*/ -export class SummaryIndexLLMRetriever implements BaseRetriever { +export class SummaryIndexLLMRetriever extends BaseRetriever { index: SummaryIndex; choiceSelectPrompt: ChoiceSelectPrompt; choiceBatchSize: number; @@ -317,6 +318,7 @@ export class SummaryIndexLLMRetriever implements BaseRetriever { parseChoiceSelectAnswerFn?: ChoiceSelectParserFunction, serviceContext?: ServiceContext, ) { + super(); this.index = index; this.choiceSelectPrompt = choiceSelectPrompt || defaultChoiceSelectPrompt; this.choiceBatchSize = choiceBatchSize; @@ -326,7 +328,7 @@ export class SummaryIndexLLMRetriever implements BaseRetriever { this.serviceContext = serviceContext || index.serviceContext; } - async retrieve({ query }: RetrieveParams): Promise<NodeWithScore[]> { + async _retrieve(query: QueryBundle): Promise<NodeWithScore[]> { const nodeIds = this.index.indexStruct.nodes; const results: NodeWithScore[] = []; diff --git a/packages/llamaindex/src/indices/vectorStore/index.ts b/packages/llamaindex/src/indices/vectorStore/index.ts index c8a5bbd9ac1068b7d52b88ca4639e955b73a6527..8b88bbe4115e56cce918aaf1be46b627ca682ad3 100644 --- a/packages/llamaindex/src/indices/vectorStore/index.ts +++ b/packages/llamaindex/src/indices/vectorStore/index.ts @@ -2,9 +2,10 @@ import { DEFAULT_SIMILARITY_TOP_K, type BaseEmbedding, } from "@llamaindex/core/embeddings"; -import { Settings } from "@llamaindex/core/global"; import type { MessageContent } from "@llamaindex/core/llms"; +import type { QueryBundle } from "@llamaindex/core/query-engine"; import type { BaseSynthesizer } from "@llamaindex/core/response-synthesizers"; +import { BaseRetriever } from "@llamaindex/core/retriever"; import { ImageNode, ModalityType, @@ -14,8 +15,6 @@ import { type Document, type NodeWithScore, } from "@llamaindex/core/schema"; -import { wrapEventCaller } from "@llamaindex/core/utils"; -import type { BaseRetriever, RetrieveParams } from "../../Retriever.js"; import type { ServiceContext } from 
"../../ServiceContext.js"; import { nodeParserFromSettingsOrContext } from "../../Settings.js"; import { RetrieverQueryEngine } from "../../engines/query/RetrieverQueryEngine.js"; @@ -388,7 +387,7 @@ export type VectorIndexRetrieverOptions = { filters?: MetadataFilters; }; -export class VectorIndexRetriever implements BaseRetriever { +export class VectorIndexRetriever extends BaseRetriever { index: VectorStoreIndex; topK: TopKMap; @@ -401,6 +400,7 @@ export class VectorIndexRetriever implements BaseRetriever { topK, filters, }: VectorIndexRetrieverOptions) { + super(); this.index = index; this.serviceContext = this.index.serviceContext; this.topK = topK ?? { @@ -417,32 +417,17 @@ export class VectorIndexRetriever implements BaseRetriever { this.topK[ModalityType.TEXT] = similarityTopK; } - @wrapEventCaller - async retrieve({ - query, - preFilters, - }: RetrieveParams): Promise<NodeWithScore[]> { - Settings.callbackManager.dispatchEvent("retrieve-start", { - query, - }); + async _retrieve(params: QueryBundle): Promise<NodeWithScore[]> { + const { query } = params; const vectorStores = this.index.vectorStores; let nodesWithScores: NodeWithScore[] = []; for (const type in vectorStores) { const vectorStore: VectorStore = vectorStores[type as ModalityType]!; nodesWithScores = nodesWithScores.concat( - await this.retrieveQuery( - query, - type as ModalityType, - vectorStore, - preFilters as MetadataFilters, - ), + await this.retrieveQuery(query, type as ModalityType, vectorStore), ); } - Settings.callbackManager.dispatchEvent("retrieve-end", { - query, - nodes: nodesWithScores, - }); return nodesWithScores; } diff --git a/packages/llamaindex/src/llm/types.ts b/packages/llamaindex/src/llm/types.ts index c947d80c488ab55ee9dd17f985acd48a3358e722..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 100644 --- a/packages/llamaindex/src/llm/types.ts +++ b/packages/llamaindex/src/llm/types.ts @@ -1,10 +0,0 @@ -import type { MessageContent } from "@llamaindex/core/llms"; -import type { 
NodeWithScore } from "@llamaindex/core/schema"; - -export type RetrievalStartEvent = { - query: MessageContent; -}; -export type RetrievalEndEvent = { - query: MessageContent; - nodes: NodeWithScore[]; -}; diff --git a/packages/llamaindex/src/objects/base.ts b/packages/llamaindex/src/objects/base.ts index 5737a351fa1022fe48dfcab16f1663d790a0c029..eb7cb0083d392b2006341625a741e567ba2a0dfa 100644 --- a/packages/llamaindex/src/objects/base.ts +++ b/packages/llamaindex/src/objects/base.ts @@ -1,8 +1,8 @@ import type { BaseTool, MessageContent } from "@llamaindex/core/llms"; +import { BaseRetriever } from "@llamaindex/core/retriever"; import type { BaseNode, Metadata } from "@llamaindex/core/schema"; import { TextNode } from "@llamaindex/core/schema"; import { extractText } from "@llamaindex/core/utils"; -import type { BaseRetriever } from "../Retriever.js"; import type { VectorStoreIndex } from "../indices/vectorStore/index.js"; // Assuming that necessary interfaces and classes (like OT, TextNode, BaseNode, etc.) are defined elsewhere @@ -49,9 +49,6 @@ export abstract class BaseObjectNodeMapping { // You will need to implement specific subclasses of BaseObjectNodeMapping as per your project requirements. -// todo: multimodal support -type QueryType = MessageContent; - export class ObjectRetriever<T = unknown> { _retriever: BaseRetriever; _objectNodeMapping: BaseObjectNodeMapping; @@ -70,7 +67,7 @@ export class ObjectRetriever<T = unknown> { } // Translating the retrieve method - async retrieve(strOrQueryBundle: QueryType): Promise<T[]> { + async retrieve(strOrQueryBundle: MessageContent): Promise<T[]> { const nodes = await this.retriever.retrieve({ query: extractText(strOrQueryBundle), });