From c856c5becb3635399fe3eeadf28488454160fde0 Mon Sep 17 00:00:00 2001
From: Alex Yang <himself65@outlook.com>
Date: Mon, 23 Sep 2024 18:35:36 -0700
Subject: [PATCH] revert: stream back to first parameter (#1247)

---
 .changeset/four-beers-kick.md                 |  7 ++---
 examples/agent/azure_dynamic_session.ts       |  1 +
 examples/agent/stream_openai_agent.ts         | 10 +++----
 examples/agent/wiki.ts                        | 10 +++----
 examples/anthropic/chat_interactive.ts        |  2 +-
 examples/chatEngine.ts                        |  2 +-
 examples/chatHistory.ts                       | 12 ++++-----
 examples/cloud/chat.ts                        |  2 +-
 examples/huggingface/embedding.ts             | 10 +++----
 examples/huggingface/embeddingApi.ts          | 10 +++----
 examples/multimodal/rag.ts                    | 10 +++----
 .../autotool/examples/02_nextjs/actions.ts    | 10 +++----
 packages/core/src/chat-engine/index.ts        | 22 ++++++++++-----
 packages/core/src/query-engine/base.ts        | 27 ++++++++++++-------
 .../cloudflare-worker-agent/src/index.ts      | 10 +++----
 .../nextjs-agent/src/actions/index.tsx        | 12 ++++-----
 packages/llamaindex/e2e/node/openai.e2e.ts    | 10 +++----
 packages/llamaindex/e2e/node/react.e2e.ts     | 11 ++++----
 packages/llamaindex/src/agent/anthropic.ts    | 14 +++++++---
 packages/llamaindex/src/agent/base.ts         | 18 ++++++++-----
 .../chat/CondenseQuestionChatEngine.ts        | 23 +++++++---------
 .../src/engines/chat/ContextChatEngine.ts     | 13 +++++----
 .../src/engines/chat/SimpleChatEngine.ts      | 13 +++++----
 .../src/engines/query/RouterQueryEngine.ts    |  4 ++-
 .../llamaindex/src/evaluation/Faithfulness.ts |  3 ++-
 25 files changed, 134 insertions(+), 132 deletions(-)

diff --git a/.changeset/four-beers-kick.md b/.changeset/four-beers-kick.md
index 4e591e3c3..97250234b 100644
--- a/.changeset/four-beers-kick.md
+++ b/.changeset/four-beers-kick.md
@@ -1,12 +1,9 @@
 ---
-"@llamaindex/core": minor
-"llamaindex": minor
+"@llamaindex/core": patch
+"llamaindex": patch
 ---
 
 refactor: move chat engine & retriever into core.
 
-This is a breaking change since `stream` option has moved to second parameter.
-
-- `chat` API in BaseChatEngine has changed, Move `stream` option into second parameter.
 - `chatHistory` in BaseChatEngine now returns `ChatMessage[] | Promise<ChatMessage[]>`, instead of `BaseMemory`
 - update `retrieve-end` type
diff --git a/examples/agent/azure_dynamic_session.ts b/examples/agent/azure_dynamic_session.ts
index 72dbb77e6..31d5375c7 100644
--- a/examples/agent/azure_dynamic_session.ts
+++ b/examples/agent/azure_dynamic_session.ts
@@ -42,6 +42,7 @@ async function main() {
   const response = await agent.chat({
     message:
       "plot a chart of 5 random numbers and save it to /mnt/data/chart.png",
+    stream: false,
   });
 
   // Print the response
diff --git a/examples/agent/stream_openai_agent.ts b/examples/agent/stream_openai_agent.ts
index 19dfcca17..4d8d6e8fc 100644
--- a/examples/agent/stream_openai_agent.ts
+++ b/examples/agent/stream_openai_agent.ts
@@ -61,12 +61,10 @@ async function main() {
     tools: [functionTool, functionTool2],
   });
 
-  const stream = await agent.chat(
-    {
-      message: "Divide 16 by 2 then add 20",
-    },
-    true,
-  );
+  const stream = await agent.chat({
+    message: "Divide 16 by 2 then add 20",
+    stream: true,
+  });
 
   console.log("Response:");
 
diff --git a/examples/agent/wiki.ts b/examples/agent/wiki.ts
index 1ec98652b..e4100e990 100644
--- a/examples/agent/wiki.ts
+++ b/examples/agent/wiki.ts
@@ -11,12 +11,10 @@ async function main() {
   });
 
   // Chat with the agent
-  const response = await agent.chat(
-    {
-      message: "Who was Goethe?",
-    },
-    true,
-  );
+  const response = await agent.chat({
+    message: "Who was Goethe?",
+    stream: true,
+  });
 
   for await (const { delta } of response) {
     process.stdout.write(delta);
diff --git a/examples/anthropic/chat_interactive.ts b/examples/anthropic/chat_interactive.ts
index 9579cd7a4..3f33b268f 100644
--- a/examples/anthropic/chat_interactive.ts
+++ b/examples/anthropic/chat_interactive.ts
@@ -25,7 +25,7 @@ import readline from "node:readline/promises";
   while (true) {
     const query = await rl.question("User: ");
     process.stdout.write("Assistant: ");
-    const stream = await chatEngine.chat({ message: query }, true);
+    const stream = await chatEngine.chat({ message: query, stream: true });
     for await (const chunk of stream) {
       process.stdout.write(chunk.response);
     }
diff --git a/examples/chatEngine.ts b/examples/chatEngine.ts
index da24e2251..addd025bb 100644
--- a/examples/chatEngine.ts
+++ b/examples/chatEngine.ts
@@ -24,7 +24,7 @@ async function main() {
 
   while (true) {
     const query = await rl.question("Query: ");
-    const stream = await chatEngine.chat({ message: query }, true);
+    const stream = await chatEngine.chat({ message: query, stream: true });
     console.log();
     for await (const chunk of stream) {
       process.stdout.write(chunk.response);
diff --git a/examples/chatHistory.ts b/examples/chatHistory.ts
index 14e12a31b..c55c618d6 100644
--- a/examples/chatHistory.ts
+++ b/examples/chatHistory.ts
@@ -24,13 +24,11 @@ async function main() {
 
   while (true) {
     const query = await rl.question("Query: ");
-    const stream = await chatEngine.chat(
-      {
-        message: query,
-        chatHistory,
-      },
-      true,
-    );
+    const stream = await chatEngine.chat({
+      message: query,
+      chatHistory,
+      stream: true,
+    });
     if (chatHistory.getLastSummary()) {
       // Print the summary of the conversation so far that is produced by the SummaryChatHistory
       console.log(`Summary: ${chatHistory.getLastSummary()?.content}`);
diff --git a/examples/cloud/chat.ts b/examples/cloud/chat.ts
index c39224ae0..d6fdce572 100644
--- a/examples/cloud/chat.ts
+++ b/examples/cloud/chat.ts
@@ -18,7 +18,7 @@ async function main() {
 
   while (true) {
     const query = await rl.question("User: ");
-    const stream = await chatEngine.chat({ message: query }, true);
+    const stream = await chatEngine.chat({ message: query, stream: true });
     for await (const chunk of stream) {
       process.stdout.write(chunk.response);
     }
diff --git a/examples/huggingface/embedding.ts b/examples/huggingface/embedding.ts
index 01aca316a..8297b7536 100644
--- a/examples/huggingface/embedding.ts
+++ b/examples/huggingface/embedding.ts
@@ -27,12 +27,10 @@ async function main() {
 
   // Query the index
   const queryEngine = index.asQueryEngine();
-  const stream = await queryEngine.query(
-    {
-      query: "What did the author do in college?",
-    },
-    true,
-  );
+  const stream = await queryEngine.query({
+    query: "What did the author do in college?",
+    stream: true,
+  });
 
   // Output response
   for await (const chunk of stream) {
diff --git a/examples/huggingface/embeddingApi.ts b/examples/huggingface/embeddingApi.ts
index a0bc861cb..a89df2703 100644
--- a/examples/huggingface/embeddingApi.ts
+++ b/examples/huggingface/embeddingApi.ts
@@ -37,12 +37,10 @@ async function main() {
 
   // Query the index
   const queryEngine = index.asQueryEngine();
-  const stream = await queryEngine.query(
-    {
-      query: "What did the author do in college?",
-    },
-    true,
-  );
+  const stream = await queryEngine.query({
+    query: "What did the author do in college?",
+    stream: true,
+  });
 
   // Output response
   for await (const chunk of stream) {
diff --git a/examples/multimodal/rag.ts b/examples/multimodal/rag.ts
index bbe8a6296..14d3a1c74 100644
--- a/examples/multimodal/rag.ts
+++ b/examples/multimodal/rag.ts
@@ -32,12 +32,10 @@ async function main() {
     responseSynthesizer: getResponseSynthesizer("multi_modal"),
     retriever: index.asRetriever({ topK: { TEXT: 3, IMAGE: 1 } }),
   });
-  const stream = await queryEngine.query(
-    {
-      query: "Tell me more about Vincent van Gogh's famous paintings",
-    },
-    true,
-  );
+  const stream = await queryEngine.query({
+    query: "Tell me more about Vincent van Gogh's famous paintings",
+    stream: true,
+  });
   for await (const chunk of stream) {
     process.stdout.write(chunk.response);
   }
diff --git a/packages/autotool/examples/02_nextjs/actions.ts b/packages/autotool/examples/02_nextjs/actions.ts
index b6738a230..3c5d12a48 100644
--- a/packages/autotool/examples/02_nextjs/actions.ts
+++ b/packages/autotool/examples/02_nextjs/actions.ts
@@ -14,12 +14,10 @@ export async function chatWithAI(message: string): Promise<ReactNode> {
   const uiStream = createStreamableUI();
   runWithStreamableUI(uiStream, () =>
     agent
-      .chat(
-        {
-          message,
-        },
-        true,
-      )
+      .chat({
+        stream: true,
+        message,
+      })
       .then(async (responseStream) => {
         return responseStream.pipeTo(
           new WritableStream({
diff --git a/packages/core/src/chat-engine/index.ts b/packages/core/src/chat-engine/index.ts
index 07d7b1223..b4bd4cf3b 100644
--- a/packages/core/src/chat-engine/index.ts
+++ b/packages/core/src/chat-engine/index.ts
@@ -2,7 +2,7 @@ import type { ChatMessage, MessageContent } from "../llms";
 import type { BaseMemory } from "../memory";
 import { EngineResponse } from "../schema";
 
-export interface ChatEngineParams<
+export interface BaseChatEngineParams<
   AdditionalMessageOptions extends object = object,
 > {
   message: MessageContent;
@@ -14,14 +14,22 @@
     | BaseMemory<AdditionalMessageOptions>;
 }
 
+export interface StreamingChatEngineParams<
+  AdditionalMessageOptions extends object = object,
+> extends BaseChatEngineParams<AdditionalMessageOptions> {
+  stream: true;
+}
+
+export interface NonStreamingChatEngineParams<
+  AdditionalMessageOptions extends object = object,
+> extends BaseChatEngineParams<AdditionalMessageOptions> {
+  stream?: false;
+}
+
 export abstract class BaseChatEngine {
+  abstract chat(params: NonStreamingChatEngineParams): Promise<EngineResponse>;
   abstract chat(
-    params: ChatEngineParams,
-    stream?: false,
-  ): Promise<EngineResponse>;
-  abstract chat(
-    params: ChatEngineParams,
-    stream: true,
+    params: StreamingChatEngineParams,
   ): Promise<AsyncIterable<EngineResponse>>;
 
   abstract chatHistory: ChatMessage[] | Promise<ChatMessage[]>;
diff --git a/packages/core/src/query-engine/base.ts b/packages/core/src/query-engine/base.ts
index db079b284..464c4ff2b 100644
--- a/packages/core/src/query-engine/base.ts
+++ b/packages/core/src/query-engine/base.ts
@@ -18,6 +18,18 @@ export type QueryBundle = {
 
 export type QueryType = string | QueryBundle;
 
+export type BaseQueryParams = {
+  query: QueryType;
+};
+
+export interface StreamingQueryParams extends BaseQueryParams {
+  stream: true;
+}
+
+export interface NonStreamingQueryParams extends BaseQueryParams {
+  stream?: false;
+}
+
 export type QueryFn = (
   strOrQueryBundle: QueryType,
   stream?: boolean,
@@ -34,23 +46,20 @@ export abstract class BaseQueryEngine extends PromptMixin {
     );
   }
 
-  query(
-    strOrQueryBundle: QueryType,
-    stream: true,
-  ): Promise<AsyncIterable<EngineResponse>>;
-  query(strOrQueryBundle: QueryType, stream?: false): Promise<EngineResponse>;
+  query(params: StreamingQueryParams): Promise<AsyncIterable<EngineResponse>>;
+  query(params: NonStreamingQueryParams): Promise<EngineResponse>;
   @wrapEventCaller
   async query(
-    strOrQueryBundle: QueryType,
-    stream = false,
+    params: StreamingQueryParams | NonStreamingQueryParams,
   ): Promise<EngineResponse | AsyncIterable<EngineResponse>> {
+    const { stream, query } = params;
     const id = randomUUID();
     const callbackManager = Settings.callbackManager;
     callbackManager.dispatchEvent("query-start", {
       id,
-      query: strOrQueryBundle,
+      query,
     });
-    const response = await this._query(strOrQueryBundle, stream);
+    const response = await this._query(query, stream);
     callbackManager.dispatchEvent("query-end", {
       id,
       response,
diff --git a/packages/llamaindex/e2e/examples/cloudflare-worker-agent/src/index.ts b/packages/llamaindex/e2e/examples/cloudflare-worker-agent/src/index.ts
index ce80b00b2..8a283cc6b 100644
--- a/packages/llamaindex/e2e/examples/cloudflare-worker-agent/src/index.ts
+++ b/packages/llamaindex/e2e/examples/cloudflare-worker-agent/src/index.ts
@@ -11,12 +11,10 @@ export default {
       tools: [],
     });
     console.log(1);
-    const responseStream = await agent.chat(
-      {
-        message: "Hello? What is the weather today?",
-      },
-      true,
-    );
+    const responseStream = await agent.chat({
+      stream: true,
+      message: "Hello? What is the weather today?",
+    });
     console.log(2);
     const textEncoder = new TextEncoder();
     const response = responseStream.pipeThrough<Uint8Array>(
diff --git a/packages/llamaindex/e2e/examples/nextjs-agent/src/actions/index.tsx b/packages/llamaindex/e2e/examples/nextjs-agent/src/actions/index.tsx
index f648a8d9d..b71e52a92 100644
--- a/packages/llamaindex/e2e/examples/nextjs-agent/src/actions/index.tsx
+++ b/packages/llamaindex/e2e/examples/nextjs-agent/src/actions/index.tsx
@@ -10,13 +10,11 @@ export async function chatWithAgent(
   const agent = new OpenAIAgent({
     tools: [],
   });
-  const responseStream = await agent.chat(
-    {
-      message: question,
-      chatHistory: prevMessages,
-    },
-    true,
-  );
+  const responseStream = await agent.chat({
+    stream: true,
+    message: question,
+    chatHistory: prevMessages,
+  });
   const uiStream = createStreamableUI(<div>loading...</div>);
   responseStream
     .pipeTo(
diff --git a/packages/llamaindex/e2e/node/openai.e2e.ts b/packages/llamaindex/e2e/node/openai.e2e.ts
index eb9918eef..401939085 100644
--- a/packages/llamaindex/e2e/node/openai.e2e.ts
+++ b/packages/llamaindex/e2e/node/openai.e2e.ts
@@ -322,12 +322,10 @@ await test("agent stream", async (t) => {
     tools: [sumNumbersTool, divideNumbersTool],
   });
 
-  const stream = await agent.chat(
-    {
-      message: "Divide 16 by 2 then add 20",
-    },
-    true,
-  );
+  const stream = await agent.chat({
+    message: "Divide 16 by 2 then add 20",
+    stream: true,
+  });
 
   let message = "";
 
diff --git a/packages/llamaindex/e2e/node/react.e2e.ts b/packages/llamaindex/e2e/node/react.e2e.ts
index 58170a57f..c0bb8c46c 100644
--- a/packages/llamaindex/e2e/node/react.e2e.ts
+++ b/packages/llamaindex/e2e/node/react.e2e.ts
@@ -20,6 +20,7 @@ await test("react agent", async (t) => {
     tools: [getWeatherTool],
   });
   const response = await agent.chat({
+    stream: false,
     message: "What is the weather like in San Francisco?",
   });
 
@@ -34,12 +35,10 @@ await test("react agent stream", async (t) => {
     tools: [getWeatherTool],
   });
 
-  const stream = await agent.chat(
-    {
-      message: "What is the weather like in San Francisco?",
-    },
-    true,
-  );
+  const stream = await agent.chat({
+    stream: true,
+    message: "What is the weather like in San Francisco?",
+  });
 
   let content = "";
   for await (const response of stream) {
diff --git a/packages/llamaindex/src/agent/anthropic.ts b/packages/llamaindex/src/agent/anthropic.ts
index e8c827b68..45d7c5ce3 100644
--- a/packages/llamaindex/src/agent/anthropic.ts
+++ b/packages/llamaindex/src/agent/anthropic.ts
@@ -1,4 +1,7 @@
-import type { ChatEngineParams } from "@llamaindex/core/chat-engine";
+import type {
+  NonStreamingChatEngineParams,
+  StreamingChatEngineParams,
+} from "@llamaindex/core/chat-engine";
 import type { EngineResponse } from "@llamaindex/core/schema";
 import { Settings } from "../Settings.js";
 import { Anthropic } from "../llm/anthropic.js";
@@ -21,9 +24,12 @@ export class AnthropicAgent extends LLMAgent {
     });
   }
 
-  async chat(params: ChatEngineParams, stream?: false): Promise<EngineResponse>;
-  async chat(params: ChatEngineParams, stream: true): Promise<never>;
-  override async chat(params: ChatEngineParams, stream?: boolean) {
+  async chat(params: NonStreamingChatEngineParams): Promise<EngineResponse>;
+  async chat(params: StreamingChatEngineParams): Promise<never>;
+  override async chat(
+    params: NonStreamingChatEngineParams | StreamingChatEngineParams,
+  ) {
+    const { stream } = params;
     if (stream) {
       // Anthropic does support this, but looks like it's not supported in the LITS LLM
       throw new Error("Anthropic does not support streaming");
diff --git a/packages/llamaindex/src/agent/base.ts b/packages/llamaindex/src/agent/base.ts
index 8141ceac4..2cc2233ca 100644
--- a/packages/llamaindex/src/agent/base.ts
+++ b/packages/llamaindex/src/agent/base.ts
@@ -1,6 +1,7 @@
 import {
   BaseChatEngine,
-  type ChatEngineParams,
+  type NonStreamingChatEngineParams,
+  type StreamingChatEngineParams,
 } from "@llamaindex/core/chat-engine";
 import type {
   BaseToolWithCall,
@@ -344,15 +345,13 @@ export abstract class AgentRunner<
     });
   }
 
-  async chat(params: ChatEngineParams, stream?: false): Promise<EngineResponse>;
+  async chat(params: NonStreamingChatEngineParams): Promise<EngineResponse>;
   async chat(
-    params: ChatEngineParams,
-    stream: true,
+    params: StreamingChatEngineParams,
   ): Promise<ReadableStream<EngineResponse>>;
   @wrapEventCaller
   async chat(
-    params: ChatEngineParams,
-    stream?: boolean,
+    params: NonStreamingChatEngineParams | StreamingChatEngineParams,
   ): Promise<EngineResponse | ReadableStream<EngineResponse>> {
     let chatHistory: ChatMessage<AdditionalMessageOptions>[] = [];
 
@@ -364,7 +363,12 @@ export abstract class AgentRunner<
         params.chatHistory as ChatMessage<AdditionalMessageOptions>[];
     }
 
-    const task = this.createTask(params.message, !!stream, false, chatHistory);
+    const task = this.createTask(
+      params.message,
+      !!params.stream,
+      false,
+      chatHistory,
+    );
     for await (const stepOutput of task) {
       // update chat history for each round
       this.#chatHistory = [...stepOutput.taskStep.context.store.messages];
diff --git a/packages/llamaindex/src/engines/chat/CondenseQuestionChatEngine.ts b/packages/llamaindex/src/engines/chat/CondenseQuestionChatEngine.ts
index 4ebff2f40..b63ca774a 100644
--- a/packages/llamaindex/src/engines/chat/CondenseQuestionChatEngine.ts
+++ b/packages/llamaindex/src/engines/chat/CondenseQuestionChatEngine.ts
@@ -1,6 +1,7 @@
 import {
   BaseChatEngine,
-  type ChatEngineParams,
+  type NonStreamingChatEngineParams,
+  type StreamingChatEngineParams,
 } from "@llamaindex/core/chat-engine";
 import type { ChatMessage, LLM } from "@llamaindex/core/llms";
 import { BaseMemory, ChatMemoryBuffer } from "@llamaindex/core/memory";
@@ -86,17 +87,15 @@ export class CondenseQuestionChatEngine extends BaseChatEngine {
     });
   }
 
-  chat(params: ChatEngineParams, stream?: false): Promise<EngineResponse>;
+  chat(params: NonStreamingChatEngineParams): Promise<EngineResponse>;
   chat(
-    params: ChatEngineParams,
-    stream: true,
+    params: StreamingChatEngineParams,
   ): Promise<AsyncIterable<EngineResponse>>;
   @wrapEventCaller
   async chat(
-    params: ChatEngineParams,
-    stream = false,
+    params: NonStreamingChatEngineParams | StreamingChatEngineParams,
   ): Promise<EngineResponse | AsyncIterable<EngineResponse>> {
-    const { message } = params;
+    const { message, stream } = params;
     const chatHistory = params.chatHistory
       ? new ChatMemoryBuffer({
           chatHistory:
@@ -112,12 +111,10 @@ export class CondenseQuestionChatEngine extends BaseChatEngine {
     chatHistory.put({ content: message, role: "user" });
 
     if (stream) {
-      const stream = await this.queryEngine.query(
-        {
-          query: condensedQuestion,
-        },
-        true,
-      );
+      const stream = await this.queryEngine.query({
+        query: condensedQuestion,
+        stream: true,
+      });
       return streamReducer({
         stream,
         initialValue: "",
diff --git a/packages/llamaindex/src/engines/chat/ContextChatEngine.ts b/packages/llamaindex/src/engines/chat/ContextChatEngine.ts
index 26baddf9c..81e6f419f 100644
--- a/packages/llamaindex/src/engines/chat/ContextChatEngine.ts
+++ b/packages/llamaindex/src/engines/chat/ContextChatEngine.ts
@@ -1,6 +1,7 @@
 import type {
   BaseChatEngine,
-  ChatEngineParams,
+  NonStreamingChatEngineParams,
+  StreamingChatEngineParams,
 } from "@llamaindex/core/chat-engine";
 import type {
   ChatMessage,
@@ -82,17 +83,15 @@ export class ContextChatEngine extends PromptMixin implements BaseChatEngine {
     };
   }
 
-  chat(params: ChatEngineParams, stream?: false): Promise<EngineResponse>;
+  chat(params: NonStreamingChatEngineParams): Promise<EngineResponse>;
   chat(
-    params: ChatEngineParams,
-    stream: true,
+    params: StreamingChatEngineParams,
   ): Promise<AsyncIterable<EngineResponse>>;
   @wrapEventCaller
   async chat(
-    params: ChatEngineParams,
-    stream = false,
+    params: StreamingChatEngineParams | NonStreamingChatEngineParams,
   ): Promise<EngineResponse | AsyncIterable<EngineResponse>> {
-    const { message } = params;
+    const { message, stream } = params;
     const chatHistory = params.chatHistory
       ? new ChatMemoryBuffer({
           chatHistory:
diff --git a/packages/llamaindex/src/engines/chat/SimpleChatEngine.ts b/packages/llamaindex/src/engines/chat/SimpleChatEngine.ts
index d3dc9f5d0..dee8b5170 100644
--- a/packages/llamaindex/src/engines/chat/SimpleChatEngine.ts
+++ b/packages/llamaindex/src/engines/chat/SimpleChatEngine.ts
@@ -1,6 +1,7 @@
 import type {
   BaseChatEngine,
-  ChatEngineParams,
+  NonStreamingChatEngineParams,
+  StreamingChatEngineParams,
 } from "@llamaindex/core/chat-engine";
 import type { LLM } from "@llamaindex/core/llms";
 import { BaseMemory, ChatMemoryBuffer } from "@llamaindex/core/memory";
@@ -29,17 +30,15 @@ export class SimpleChatEngine implements BaseChatEngine {
     this.llm = init?.llm ?? Settings.llm;
   }
 
-  chat(params: ChatEngineParams, stream?: false): Promise<EngineResponse>;
+  chat(params: NonStreamingChatEngineParams): Promise<EngineResponse>;
   chat(
-    params: ChatEngineParams,
-    stream: true,
+    params: StreamingChatEngineParams,
   ): Promise<AsyncIterable<EngineResponse>>;
   @wrapEventCaller
   async chat(
-    params: ChatEngineParams,
-    stream = false,
+    params: NonStreamingChatEngineParams | StreamingChatEngineParams,
   ): Promise<EngineResponse | AsyncIterable<EngineResponse>> {
-    const { message } = params;
+    const { message, stream } = params;
     const chatHistory = params.chatHistory
       ? new ChatMemoryBuffer({
          chatHistory:
diff --git a/packages/llamaindex/src/engines/query/RouterQueryEngine.ts b/packages/llamaindex/src/engines/query/RouterQueryEngine.ts
index 8d4f6058d..7f0a614f0 100644
--- a/packages/llamaindex/src/engines/query/RouterQueryEngine.ts
+++ b/packages/llamaindex/src/engines/query/RouterQueryEngine.ts
@@ -136,7 +136,9 @@ export class RouterQueryEngine extends BaseQueryEngine {
         }
 
         const selectedQueryEngine = this.queryEngines[engineInd.index]!;
-        responses.push(await selectedQueryEngine.query(query));
+        responses.push(
+          await selectedQueryEngine.query({ query, stream: false }),
+        );
       }
 
       if (responses.length > 1) {
diff --git a/packages/llamaindex/src/evaluation/Faithfulness.ts b/packages/llamaindex/src/evaluation/Faithfulness.ts
index b1e68f1f4..590e2e63d 100644
--- a/packages/llamaindex/src/evaluation/Faithfulness.ts
+++ b/packages/llamaindex/src/evaluation/Faithfulness.ts
@@ -103,7 +103,8 @@ export class FaithfulnessEvaluator
     });
 
     const responseObj = await queryEngine.query({
-      query: response,
+      query: { query: response },
+      stream: false,
     });
 
     const rawResponseTxt = responseObj.toString();
-- 
GitLab