From 844029d8e58be152b138f052db8210f1400aab87 Mon Sep 17 00:00:00 2001
From: Marcus Schiesser <mail@marcusschiesser.de>
Date: Wed, 17 Jan 2024 14:29:27 +0700
Subject: [PATCH] feat: Add streaming support for QueryEngine (and unify
 streaming interface with ChatEngine) (#393)

---
 .changeset/spicy-rice-tan.md                 |   5 +
 .changeset/tidy-mayflies-provide.md          |   5 +
 .../docs/modules/high_level/chat_engine.md   |  11 +-
 .../docs/modules/high_level/query_engine.md  |  11 +-
 examples/astradb/query.ts                    |   4 +-
 examples/chatEngine.ts                       |   7 +-
 examples/chromadb/test.ts                    |   6 +-
 examples/huggingface.ts                      |  11 +-
 examples/keywordIndex.ts                     |   6 +-
 examples/markdown.ts                         |   4 +-
 examples/mistral.ts                          |   2 +-
 examples/mongo.ts                            |   2 +-
 examples/mongodb/3_query.ts                  |   6 +-
 examples/multimodal/rag.ts                   |   6 +-
 examples/pg-vector-store/query.ts            |   2 +-
 examples/pinecone-vector-store/query.ts      |   2 +-
 examples/readers/load-assemblyai.ts          |   2 +-
 examples/readers/load-csv.ts                 |   6 +-
 examples/readers/load-docx.ts                |   2 +-
 examples/readers/load-html.ts                |   6 +-
 examples/readers/load-md.ts                  |   2 +-
 examples/readers/load-notion.ts              |   2 +-
 examples/readers/load-pdf.ts                 |   4 +-
 examples/sentenceWindow.ts                   |   6 +-
 examples/storageContext.ts                   |  12 +-
 examples/subquestion.ts                      |   6 +-
 examples/summaryIndex.ts                     |   6 +-
 examples/together-ai/vector-index.ts         |   6 +-
 examples/vectorIndex.ts                      |   6 +-
 examples/vectorIndexAnthropic.ts             |   6 +-
 examples/vectorIndexCustomize.ts             |   6 +-
 examples/vectorIndexFromVectorStore.ts       |   4 +-
 examples/vectorIndexGPT4.ts                  |   6 +-
 packages/core/src/ChatEngine.ts              | 298 ++++++++----------
 packages/core/src/QueryEngine.ts             |  65 +++-
 packages/core/src/llm/LLM.ts                 |  16 +-
 packages/core/src/llm/utils.ts               |  18 +-
 .../core/src/tests/CallbackManager.test.ts   |   4 +-
 38 files changed, 324 insertions(+), 255 deletions(-)
 create mode 100644 .changeset/spicy-rice-tan.md
 create mode 100644 .changeset/tidy-mayflies-provide.md

diff --git a/.changeset/spicy-rice-tan.md b/.changeset/spicy-rice-tan.md
new file mode 100644
index 000000000..0a44da675
--- /dev/null
+++ b/.changeset/spicy-rice-tan.md
@@ -0,0 +1,5 @@
+---
+"llamaindex": patch
+---
+
+Add streaming support for QueryEngine (and unify streaming interface with ChatEngine)
diff --git a/.changeset/tidy-mayflies-provide.md b/.changeset/tidy-mayflies-provide.md
new file mode 100644
index 000000000..f2a2d3202
--- /dev/null
+++ b/.changeset/tidy-mayflies-provide.md
@@ -0,0 +1,5 @@
+---
+"llamaindex": patch
+---
+
+Breaking: Use parameter object for query and chat methods of ChatEngine and QueryEngine
diff --git a/apps/docs/docs/modules/high_level/chat_engine.md b/apps/docs/docs/modules/high_level/chat_engine.md
index 984e45251..41ce90720 100644
--- a/apps/docs/docs/modules/high_level/chat_engine.md
+++ b/apps/docs/docs/modules/high_level/chat_engine.md
@@ -11,7 +11,16 @@ const retriever = index.asRetriever();
 const chatEngine = new ContextChatEngine({ retriever });
 
 // start chatting
-const response = await chatEngine.chat(query);
+const response = await chatEngine.chat({ message: query });
+```
+
+The `chat` function also supports streaming; just add `stream: true` as an option:
+
+```typescript
+const stream = await chatEngine.chat({ message: query, stream: true });
+for await (const chunk of stream) {
+  process.stdout.write(chunk.response);
+}
 ```
 
 ## Api References
diff --git a/apps/docs/docs/modules/high_level/query_engine.md b/apps/docs/docs/modules/high_level/query_engine.md
index 3a02d0e4a..853641a17 100644
--- a/apps/docs/docs/modules/high_level/query_engine.md
+++ b/apps/docs/docs/modules/high_level/query_engine.md
@@ -8,7 +8,16 @@ A query engine wraps a `Retriever` and a `ResponseSynthesizer` into a pipeline,
 
 ```typescript
 const queryEngine = index.asQueryEngine();
-const response = await queryEngine.query("query string");
+const response = await queryEngine.query({ query: "query string" });
+```
+
+The `query` function also supports streaming; just add `stream: true` as an option:
+
+```typescript
+const stream = await queryEngine.query({ query: "query string", stream: true });
+for await (const chunk of stream) {
+  process.stdout.write(chunk.response);
+}
 ```
 
 ## Sub Question Query Engine
diff --git a/examples/astradb/query.ts b/examples/astradb/query.ts
index bbc0712d5..949cfd118 100644
--- a/examples/astradb/query.ts
+++ b/examples/astradb/query.ts
@@ -18,7 +18,9 @@ async function main() {
 
     const queryEngine = await index.asQueryEngine({ retriever });
 
-    const results = await queryEngine.query("What is the best reviewed movie?");
+    const results = await queryEngine.query({
+      query: "What is the best reviewed movie?",
+    });
 
     console.log(results.response);
   } catch (e) {
diff --git a/examples/chatEngine.ts b/examples/chatEngine.ts
index 728c6287d..52538944b 100644
--- a/examples/chatEngine.ts
+++ b/examples/chatEngine.ts
@@ -25,8 +25,11 @@ async function main() {
 
   while (true) {
     const query = await rl.question("Query: ");
-    const response = await chatEngine.chat(query);
-    console.log(response.toString());
+    const stream = await chatEngine.chat({ message: query, stream: true });
+    console.log();
+    for await (const chunk of stream) {
+      process.stdout.write(chunk.response);
+    }
   }
 }
 
diff --git a/examples/chromadb/test.ts b/examples/chromadb/test.ts
index 51d2b2692..c901e82b3 100644
--- a/examples/chromadb/test.ts
+++ b/examples/chromadb/test.ts
@@ -28,9 +28,9 @@ async function main() {
 
     console.log("Querying index");
     const queryEngine = index.asQueryEngine();
-    const response = await queryEngine.query(
-      "Tell me about Godfrey Cheshire's rating of La Sapienza.",
-    );
+    const response = await queryEngine.query({
+      query: "Tell me about Godfrey Cheshire's rating of La Sapienza.",
+    });
     console.log(response.toString());
   } catch (e) {
     console.error(e);
diff --git a/examples/huggingface.ts b/examples/huggingface.ts
index 1d02b43ab..c1e54e05f 100644
--- a/examples/huggingface.ts
+++ b/examples/huggingface.ts
@@ -32,12 +32,15 @@ async function main() {
 
   // Query the index
   const queryEngine = index.asQueryEngine();
-  const response = await queryEngine.query(
-    "What did the author do in college?",
-  );
+  const stream = await queryEngine.query({
+    query: "What did the author do in college?",
+    stream: true,
+  });
 
   // Output response
-  console.log(response.toString());
+  for await (const chunk of stream) {
+    process.stdout.write(chunk.response);
+  }
 }
 
 main().catch(console.error);
diff --git a/examples/keywordIndex.ts b/examples/keywordIndex.ts
index 4ee2cf4fa..ad6c7b1d3 100644
--- a/examples/keywordIndex.ts
+++ b/examples/keywordIndex.ts
@@ -20,9 +20,9 @@ async function main() {
         mode,
       }),
     });
-    const response = await queryEngine.query(
-      "What did the author do growing up?",
-    );
+    const response = await queryEngine.query({
+      query: "What did the author do growing up?",
+    });
     console.log(response.toString());
   });
 }
diff --git a/examples/markdown.ts b/examples/markdown.ts
index 01faf6b27..b2767713e 100644
--- a/examples/markdown.ts
+++ b/examples/markdown.ts
@@ -11,7 +11,9 @@ async function main() {
 
   // Query the index
   const queryEngine = index.asQueryEngine();
-  const response = await queryEngine.query("What does the example code do?");
+  const response = await queryEngine.query({
+    query: "What does the example code do?",
+  });
 
   // Output response
   console.log(response.toString());
diff --git a/examples/mistral.ts b/examples/mistral.ts
index 2047f3463..75b8b3c1a 100644
--- a/examples/mistral.ts
+++ b/examples/mistral.ts
@@ -27,7 +27,7 @@ async function rag(llm: LLM, embedModel: BaseEmbedding, query: string) {
 
   // Query the index
   const queryEngine = index.asQueryEngine();
-  const response = await queryEngine.query(query);
+  const response = await queryEngine.query({ query });
 
   return response.response;
 }
diff --git a/examples/mongo.ts b/examples/mongo.ts
index 5eb07a4db..6cabbfe35 100644
--- a/examples/mongo.ts
+++ b/examples/mongo.ts
@@ -54,7 +54,7 @@ async function main() {
       break;
     }
 
-    const response = await queryEngine.query(query);
+    const response = await queryEngine.query({ query });
 
     // Output response
     console.log(response.toString());
diff --git a/examples/mongodb/3_query.ts b/examples/mongodb/3_query.ts
index 32a352939..84920dd9d 100644
--- a/examples/mongodb/3_query.ts
+++ b/examples/mongodb/3_query.ts
@@ -24,9 +24,9 @@ async function query() {
 
   const retriever = index.asRetriever({ similarityTopK: 20 });
   const queryEngine = index.asQueryEngine({ retriever });
-  const result = await queryEngine.query(
-    "What does the author think of web frameworks?",
-  );
+  const result = await queryEngine.query({
+    query: "What does the author think of web frameworks?",
+  });
   console.log(result.response);
   await client.close();
 }
diff --git a/examples/multimodal/rag.ts b/examples/multimodal/rag.ts
index f0f775aaa..fec645265 100644
--- a/examples/multimodal/rag.ts
+++ b/examples/multimodal/rag.ts
@@ -46,9 +46,9 @@ async function main() {
     responseSynthesizer: new MultiModalResponseSynthesizer({ serviceContext }),
     retriever: index.asRetriever({ similarityTopK: 3, imageSimilarityTopK: 1 }),
   });
-  const result = await queryEngine.query(
-    "Tell me more about Vincent van Gogh's famous paintings",
-  );
+  const result = await queryEngine.query({
+    query: "Tell me more about Vincent van Gogh's famous paintings",
+  });
   console.log(result.response, "\n");
   images.forEach((image) =>
     console.log(`Image retrieved and used in inference: ${image.toString()}`),
diff --git a/examples/pg-vector-store/query.ts b/examples/pg-vector-store/query.ts
index 77545bab8..c46f46c9e 100755
--- a/examples/pg-vector-store/query.ts
+++ b/examples/pg-vector-store/query.ts
@@ -31,7 +31,7 @@ async function main() {
     }
 
     try {
-      const answer = await queryEngine.query(question);
+      const answer = await queryEngine.query({ query: question });
       console.log(answer.response);
     } catch (error) {
       console.error("Error:", error);
diff --git a/examples/pinecone-vector-store/query.ts b/examples/pinecone-vector-store/query.ts
index 41ca987b0..f0ee4b3c5 100755
--- a/examples/pinecone-vector-store/query.ts
+++ b/examples/pinecone-vector-store/query.ts
@@ -29,7 +29,7 @@ async function main() {
     }
 
     try {
-      const answer = await queryEngine.query(question);
+      const answer = await queryEngine.query({ query: question });
       console.log(answer.response);
     } catch (error) {
       console.error("Error:", error);
diff --git a/examples/readers/load-assemblyai.ts b/examples/readers/load-assemblyai.ts
index 986788196..42c944059 100644
--- a/examples/readers/load-assemblyai.ts
+++ b/examples/readers/load-assemblyai.ts
@@ -48,7 +48,7 @@ program
         break;
       }
 
-      const response = await queryEngine.query(query);
+      const response = await queryEngine.query({ query });
       console.log(response.toString());
     }
 
diff --git a/examples/readers/load-csv.ts b/examples/readers/load-csv.ts
index 1e0a11237..d16945197 100644
--- a/examples/readers/load-csv.ts
+++ b/examples/readers/load-csv.ts
@@ -38,9 +38,9 @@ Given the CSV file, generate me Typescript code to answer the question: ${query}
   const queryEngine = index.asQueryEngine({ responseSynthesizer });
 
   // Query the index
-  const response = await queryEngine.query(
-    "What is the correlation between survival and age?",
-  );
+  const response = await queryEngine.query({
+    query: "What is the correlation between survival and age?",
+  });
 
   // Output response
   console.log(response.toString());
diff --git a/examples/readers/load-docx.ts b/examples/readers/load-docx.ts
index 61a8b314a..459dad3ef 100644
--- a/examples/readers/load-docx.ts
+++ b/examples/readers/load-docx.ts
@@ -15,7 +15,7 @@ async function main() {
 
   // Test query
   const queryEngine = index.asQueryEngine();
-  const response = await queryEngine.query(SAMPLE_QUERY);
+  const response = await queryEngine.query({ query: SAMPLE_QUERY });
 
   console.log(`Test query > ${SAMPLE_QUERY}:\n`, response.toString());
 }
diff --git a/examples/readers/load-html.ts b/examples/readers/load-html.ts
index 766729861..87ea89ec6 100644
--- a/examples/readers/load-html.ts
+++ b/examples/readers/load-html.ts
@@ -10,9 +10,9 @@ async function main() {
 
   // Query the index
   const queryEngine = index.asQueryEngine();
-  const response = await queryEngine.query(
-    "What were the notable changes in 18.1?",
-  );
+  const response = await queryEngine.query({
+    query: "What were the notable changes in 18.1?",
+  });
 
   // Output response
   console.log(response.toString());
diff --git a/examples/readers/load-md.ts b/examples/readers/load-md.ts
index bebc7a2ec..5e6e300af 100644
--- a/examples/readers/load-md.ts
+++ b/examples/readers/load-md.ts
@@ -15,7 +15,7 @@ async function main() {
 
   // Test query
   const queryEngine = index.asQueryEngine();
-  const response = await queryEngine.query(SAMPLE_QUERY);
+  const response = await queryEngine.query({ query: SAMPLE_QUERY });
 
   console.log(`Test query > ${SAMPLE_QUERY}:\n`, response.toString());
 }
diff --git a/examples/readers/load-notion.ts b/examples/readers/load-notion.ts
index 7bbe065f4..04f1651ef 100644
--- a/examples/readers/load-notion.ts
+++ b/examples/readers/load-notion.ts
@@ -79,7 +79,7 @@ program
         break;
       }
 
-      const response = await queryEngine.query(query);
+      const response = await queryEngine.query({ query });
 
       // Output response
       console.log(response.toString());
diff --git a/examples/readers/load-pdf.ts b/examples/readers/load-pdf.ts
index afbe8f448..0255d908f 100644
--- a/examples/readers/load-pdf.ts
+++ b/examples/readers/load-pdf.ts
@@ -10,7 +10,9 @@ async function main() {
 
   // Query the index
   const queryEngine = index.asQueryEngine();
-  const response = await queryEngine.query("What mistakes did they make?");
+  const response = await queryEngine.query({
+    query: "What mistakes did they make?",
+  });
 
   // Output response
   console.log(response.toString());
diff --git a/examples/sentenceWindow.ts b/examples/sentenceWindow.ts
index d4c9cf182..fcb89d99d 100644
--- a/examples/sentenceWindow.ts
+++ b/examples/sentenceWindow.ts
@@ -31,9 +31,9 @@ async function main() {
   const queryEngine = index.asQueryEngine({
     nodePostprocessors: [new MetadataReplacementPostProcessor("window")],
   });
-  const response = await queryEngine.query(
-    "What did the author do in college?",
-  );
+  const response = await queryEngine.query({
+    query: "What did the author do in college?",
+  });
 
   // Output response
   console.log(response.toString());
diff --git a/examples/storageContext.ts b/examples/storageContext.ts
index 7326bbdba..74fbd43fb 100644
--- a/examples/storageContext.ts
+++ b/examples/storageContext.ts
@@ -20,9 +20,9 @@ async function main() {
 
   // Query the index
   const queryEngine = index.asQueryEngine();
-  const response = await queryEngine.query(
-    "What did the author do in college?",
-  );
+  const response = await queryEngine.query({
+    query: "What did the author do in college?",
+  });
 
   // Output response
   console.log(response.toString());
@@ -35,9 +35,9 @@ async function main() {
     storageContext: secondStorageContext,
   });
   const loadedQueryEngine = loadedIndex.asQueryEngine();
-  const loadedResponse = await loadedQueryEngine.query(
-    "What did the author do growing up?",
-  );
+  const loadedResponse = await loadedQueryEngine.query({
+    query: "What did the author do growing up?",
+  });
   console.log(loadedResponse.toString());
 }
 
diff --git a/examples/subquestion.ts b/examples/subquestion.ts
index 6d46d91b5..b1e692e1f 100644
--- a/examples/subquestion.ts
+++ b/examples/subquestion.ts
@@ -18,9 +18,9 @@ import essay from "./essay";
     ],
   });
 
-  const response = await queryEngine.query(
-    "How was Paul Grahams life different before and after YC?",
-  );
+  const response = await queryEngine.query({
+    query: "How was Paul Grahams life different before and after YC?",
+  });
 
   console.log(response.toString());
 })();
diff --git a/examples/summaryIndex.ts b/examples/summaryIndex.ts
index cc34e95d3..d11a47031 100644
--- a/examples/summaryIndex.ts
+++ b/examples/summaryIndex.ts
@@ -20,9 +20,9 @@ async function main() {
   const queryEngine = index.asQueryEngine({
     retriever: index.asRetriever({ mode: SummaryRetrieverMode.LLM }),
   });
-  const response = await queryEngine.query(
-    "What did the author do growing up?",
-  );
+  const response = await queryEngine.query({
+    query: "What did the author do growing up?",
+  });
   console.log(response.toString());
 }
 
diff --git a/examples/together-ai/vector-index.ts b/examples/together-ai/vector-index.ts
index fc460ca81..94b5c5762 100644
--- a/examples/together-ai/vector-index.ts
+++ b/examples/together-ai/vector-index.ts
@@ -29,9 +29,9 @@ async function main() {
 
   const queryEngine = index.asQueryEngine();
 
-  const response = await queryEngine.query(
-    "What did the author do in college?",
-  );
+  const response = await queryEngine.query({
+    query: "What did the author do in college?",
+  });
 
   console.log(response.toString());
 }
diff --git a/examples/vectorIndex.ts b/examples/vectorIndex.ts
index ff2691624..afa95e540 100644
--- a/examples/vectorIndex.ts
+++ b/examples/vectorIndex.ts
@@ -16,9 +16,9 @@ async function main() {
 
   // Query the index
   const queryEngine = index.asQueryEngine();
-  const response = await queryEngine.query(
-    "What did the author do in college?",
-  );
+  const response = await queryEngine.query({
+    query: "What did the author do in college?",
+  });
 
   // Output response
   console.log(response.toString());
diff --git a/examples/vectorIndexAnthropic.ts b/examples/vectorIndexAnthropic.ts
index c979900b2..c04c516c7 100644
--- a/examples/vectorIndexAnthropic.ts
+++ b/examples/vectorIndexAnthropic.ts
@@ -35,9 +35,9 @@ async function main() {
 
   // Query the index
   const queryEngine = index.asQueryEngine({ responseSynthesizer });
-  const response = await queryEngine.query(
-    "What did the author do in college?",
-  );
+  const response = await queryEngine.query({
+    query: "What did the author do in college?",
+  });
 
   // Output response
   console.log(response.toString());
diff --git a/examples/vectorIndexCustomize.ts b/examples/vectorIndexCustomize.ts
index b24e91416..e9013a0e6 100644
--- a/examples/vectorIndexCustomize.ts
+++ b/examples/vectorIndexCustomize.ts
@@ -33,9 +33,9 @@ async function main() {
     [nodePostprocessor],
   );
 
-  const response = await queryEngine.query(
-    "What did the author do growing up?",
-  );
+  const response = await queryEngine.query({
+    query: "What did the author do growing up?",
+  });
   console.log(response.response);
 }
 
diff --git a/examples/vectorIndexFromVectorStore.ts b/examples/vectorIndexFromVectorStore.ts
index 311bc8c72..d539ae2ad 100644
--- a/examples/vectorIndexFromVectorStore.ts
+++ b/examples/vectorIndexFromVectorStore.ts
@@ -189,7 +189,9 @@ async function main() {
     },
   });
 
-  const response = await queryEngine.query("How many results do you have?");
+  const response = await queryEngine.query({
+    query: "How many results do you have?",
+  });
 
   console.log(response.toString());
 }
diff --git a/examples/vectorIndexGPT4.ts b/examples/vectorIndexGPT4.ts
index 6a7516dd5..ed1ed20fc 100644
--- a/examples/vectorIndexGPT4.ts
+++ b/examples/vectorIndexGPT4.ts
@@ -25,9 +25,9 @@ async function main() {
 
   // Query the index
   const queryEngine = index.asQueryEngine();
-  const response = await queryEngine.query(
-    "What did the author do in college?",
-  );
+  const response = await queryEngine.query({
+    query: "What did the author do in college?",
+  });
 
   // Output response
   console.log(response.toString());
diff --git a/packages/core/src/ChatEngine.ts b/packages/core/src/ChatEngine.ts
index c18609080..2f18a15a4 100644
--- a/packages/core/src/ChatEngine.ts
+++ b/packages/core/src/ChatEngine.ts
@@ -1,5 +1,5 @@
 import { randomUUID } from "node:crypto";
-import { ChatHistory } from "./ChatHistory";
+import { ChatHistory, SimpleChatHistory } from "./ChatHistory";
 import { NodeWithScore, TextNode } from "./Node";
 import {
   CondenseQuestionPrompt,
@@ -13,27 +13,40 @@ import { Response } from "./Response";
 import { BaseRetriever } from "./Retriever";
 import { ServiceContext, serviceContextFromDefaults } from "./ServiceContext";
 import { Event } from "./callbacks/CallbackManager";
-import { ChatMessage, LLM, OpenAI } from "./llm";
+import { ChatMessage, ChatResponseChunk, LLM, OpenAI } from "./llm";
+import { streamConverter, streamReducer } from "./llm/utils";
 import { BaseNodePostprocessor } from "./postprocessors";
 
+/**
+ * Represents the base parameters for ChatEngine.
+ */
+export interface ChatEngineParamsBase {
+  message: MessageContent;
+  /**
+   * Optional chat history if you want to customize the chat history.
+   */
+  chatHistory?: ChatMessage[];
+  history?: ChatHistory;
+}
+
+export interface ChatEngineParamsStreaming extends ChatEngineParamsBase {
+  stream: true;
+}
+
+export interface ChatEngineParamsNonStreaming extends ChatEngineParamsBase {
+  stream?: false | null;
+}
+
 /**
  * A ChatEngine is used to handle back and forth chats between the application and the LLM.
  */
 export interface ChatEngine {
   /**
    * Send message along with the class's current chat history to the LLM.
-   * @param message
-   * @param chatHistory optional chat history if you want to customize the chat history
-   * @param streaming optional streaming flag, which auto-sets the return value if True.
+   * @param params
    */
-  chat<
-    T extends boolean | undefined = undefined,
-    R = T extends true ? AsyncGenerator<string, void, unknown> : Response,
-  >(
-    message: MessageContent,
-    chatHistory?: ChatMessage[],
-    streaming?: T,
-  ): Promise<R>;
+  chat(params: ChatEngineParamsStreaming): Promise<AsyncIterable<Response>>;
+  chat(params: ChatEngineParamsNonStreaming): Promise<Response>;
 
   /**
    * Resets the chat history so that it's empty.
@@ -53,48 +66,37 @@ export class SimpleChatEngine implements ChatEngine {
     this.llm = init?.llm ?? new OpenAI();
   }
 
-  async chat<
-    T extends boolean | undefined = undefined,
-    R = T extends true ? AsyncGenerator<string, void, unknown> : Response,
-  >(
-    message: MessageContent,
-    chatHistory?: ChatMessage[],
-    streaming?: T,
-  ): Promise<R> {
-    //Streaming option
-    if (streaming) {
-      return this.streamChat(message, chatHistory) as R;
-    }
-
-    //Non-streaming option
-    chatHistory = chatHistory ?? this.chatHistory;
-    chatHistory.push({ content: message, role: "user" });
-    const response = await this.llm.chat({ messages: chatHistory });
-    chatHistory.push(response.message);
-    this.chatHistory = chatHistory;
-    return new Response(response.message.content) as R;
-  }
+  chat(params: ChatEngineParamsStreaming): Promise<AsyncIterable<Response>>;
+  chat(params: ChatEngineParamsNonStreaming): Promise<Response>;
+  async chat(
+    params: ChatEngineParamsStreaming | ChatEngineParamsNonStreaming,
+  ): Promise<Response | AsyncIterable<Response>> {
+    const { message, stream } = params;
 
-  protected async *streamChat(
-    message: MessageContent,
-    chatHistory?: ChatMessage[],
-  ): AsyncGenerator<string, void, unknown> {
-    chatHistory = chatHistory ?? this.chatHistory;
+    const chatHistory = params.chatHistory ?? this.chatHistory;
     chatHistory.push({ content: message, role: "user" });
-    const response_generator = await this.llm.chat({
-      messages: chatHistory,
-      stream: true,
-    });
-    var accumulator: string = "";
-    for await (const part of response_generator) {
-      accumulator += part.delta;
-      yield part.delta;
+    if (stream) {
+      const stream = await this.llm.chat({
+        messages: chatHistory,
+        stream: true,
+      });
+      return streamConverter(
+        streamReducer({
+          stream,
+          initialValue: "",
+          reducer: (accumulator, part) => (accumulator += part.delta),
+          finished: (accumulator) => {
+            chatHistory.push({ content: accumulator, role: "assistant" });
+          },
+        }),
+        (r: ChatResponseChunk) => new Response(r.delta),
+      );
     }
-    chatHistory.push({ content: accumulator, role: "assistant" });
-    this.chatHistory = chatHistory;
-    return;
+    const response = await this.llm.chat({ messages: chatHistory });
+    chatHistory.push(response.message);
+    return new Response(response.message.content);
   }
 
   reset() {
@@ -115,7 +117,7 @@ export class SimpleChatEngine implements ChatEngine {
 export class CondenseQuestionChatEngine implements ChatEngine {
   queryEngine: BaseQueryEngine;
   chatHistory: ChatMessage[];
-  serviceContext: ServiceContext;
+  llm: LLM;
   condenseMessagePrompt: CondenseQuestionPrompt;
 
   constructor(init: {
@@ -126,8 +128,7 @@ export class CondenseQuestionChatEngine implements ChatEngine {
   }) {
     this.queryEngine = init.queryEngine;
     this.chatHistory = init?.chatHistory ?? [];
-    this.serviceContext =
-      init?.serviceContext ?? serviceContextFromDefaults({});
+    this.llm = init?.serviceContext?.llm ?? serviceContextFromDefaults().llm;
     this.condenseMessagePrompt =
       init?.condenseMessagePrompt ?? defaultCondenseQuestionPrompt;
   }
@@ -135,7 +136,7 @@ export class CondenseQuestionChatEngine implements ChatEngine {
   private async condenseQuestion(chatHistory: ChatMessage[], question: string) {
     const chatHistoryStr = messagesToHistoryStr(chatHistory);
 
-    return this.serviceContext.llm.complete({
+    return this.llm.complete({
       prompt: defaultCondenseQuestionPrompt({
         question: question,
         chatHistory: chatHistoryStr,
@@ -143,26 +144,39 @@ export class CondenseQuestionChatEngine implements ChatEngine {
     });
   }
 
-  async chat<
-    T extends boolean | undefined = undefined,
-    R = T extends true ? AsyncGenerator<string, void, unknown> : Response,
-  >(
-    message: MessageContent,
-    chatHistory?: ChatMessage[] | undefined,
-    streaming?: T,
-  ): Promise<R> {
-    chatHistory = chatHistory ?? this.chatHistory;
+  chat(params: ChatEngineParamsStreaming): Promise<AsyncIterable<Response>>;
+  chat(params: ChatEngineParamsNonStreaming): Promise<Response>;
+  async chat(
+    params: ChatEngineParamsStreaming | ChatEngineParamsNonStreaming,
+  ): Promise<Response | AsyncIterable<Response>> {
+    const { message, stream } = params;
+    const chatHistory = params.chatHistory ?? this.chatHistory;
 
     const condensedQuestion = (
       await this.condenseQuestion(chatHistory, extractText(message))
     ).text;
-
-    const response = await this.queryEngine.query(condensedQuestion);
-
     chatHistory.push({ content: message, role: "user" });
+
+    if (stream) {
+      const stream = await this.queryEngine.query({
+        query: condensedQuestion,
+        stream: true,
+      });
+      return streamReducer({
+        stream,
+        initialValue: "",
+        reducer: (accumulator, part) => (accumulator += part.response),
+        finished: (accumulator) => {
+          chatHistory.push({ content: accumulator, role: "assistant" });
+        },
+      });
+    }
+    const response = await this.queryEngine.query({
+      query: condensedQuestion,
+    });
     chatHistory.push({ content: response.response, role: "assistant" });
 
-    return response as R;
+    return response;
   }
 
   reset() {
@@ -255,21 +269,13 @@ export class ContextChatEngine implements ChatEngine {
     });
   }
 
-  async chat<
-    T extends boolean | undefined = undefined,
-    R = T extends true ? AsyncGenerator<string, void, unknown> : Response,
-  >(
-    message: MessageContent,
-    chatHistory?: ChatMessage[] | undefined,
-    streaming?: T,
-  ): Promise<R> {
-    chatHistory = chatHistory ?? this.chatHistory;
-
-    //Streaming option
-    if (streaming) {
-      return this.streamChat(message, chatHistory) as R;
-    }
-
+  chat(params: ChatEngineParamsStreaming): Promise<AsyncIterable<Response>>;
+  chat(params: ChatEngineParamsNonStreaming): Promise<Response>;
+  async chat(
+    params: ChatEngineParamsStreaming | ChatEngineParamsNonStreaming,
+  ): Promise<Response | AsyncIterable<Response>> {
+    const { message, stream } = params;
+    const chatHistory = params.chatHistory ?? this.chatHistory;
     const parentEvent: Event = {
       id: randomUUID(),
       type: "wrapper",
@@ -279,57 +285,33 @@ export class ContextChatEngine implements ChatEngine {
       extractText(message),
       parentEvent,
     );
-
+    const nodes = context.nodes.map((r) => r.node);
     chatHistory.push({ content: message, role: "user" });
 
+    if (stream) {
+      const stream = await this.chatModel.chat({
+        messages: [context.message, ...chatHistory],
+        parentEvent,
+        stream: true,
+      });
+      return streamConverter(
+        streamReducer({
+          stream,
+          initialValue: "",
+          reducer: (accumulator, part) => (accumulator += part.delta),
+          finished: (accumulator) => {
+            chatHistory.push({ content: accumulator, role: "assistant" });
+          },
+        }),
+        (r: ChatResponseChunk) => new Response(r.delta, nodes),
+      );
+    }
     const response = await this.chatModel.chat({
       messages: [context.message, ...chatHistory],
      parentEvent,
     });
     chatHistory.push(response.message);
-
-    this.chatHistory = chatHistory;
-
-    return new Response(
-      response.message.content,
-      context.nodes.map((r) => r.node),
-    ) as R;
-  }
-
-  protected async *streamChat(
-    message: MessageContent,
-    chatHistory?: ChatMessage[] | undefined,
-  ): AsyncGenerator<string, void, unknown> {
-    chatHistory = chatHistory ?? this.chatHistory;
-
-    const parentEvent: Event = {
-      id: randomUUID(),
-      type: "wrapper",
-      tags: ["final"],
-    };
-    const context = await this.contextGenerator.generate(
-      extractText(message),
-      parentEvent,
-    );
-
-    chatHistory.push({ content: message, role: "user" });
-
-    const response_stream = await this.chatModel.chat({
-      messages: [context.message, ...chatHistory],
-      parentEvent,
-      stream: true,
-    });
-    var accumulator: string = "";
-    for await (const part of response_stream) {
-      accumulator += part.delta;
-      yield part.delta;
-    }
-
-    chatHistory.push({ content: accumulator, role: "assistant" });
-
-    this.chatHistory = chatHistory;
-
-    return;
+    return new Response(response.message.content, nodes);
   }
 
   reset() {
@@ -382,50 +364,38 @@ export class HistoryChatEngine {
     this.contextGenerator = init?.contextGenerator;
   }
 
-  async chat<
-    T extends boolean | undefined = undefined,
-    R = T extends true ? AsyncGenerator<string, void, unknown> : Response,
-  >(
-    message: MessageContent,
-    chatHistory: ChatHistory,
-    streaming?: T,
-  ): Promise<R> {
-    //Streaming option
-    if (streaming) {
-      return this.streamChat(message, chatHistory) as R;
-    }
-    const requestMessages = await this.prepareRequestMessages(
-      message,
-      chatHistory,
-    );
-    const response = await this.llm.chat({ messages: requestMessages });
-    chatHistory.addMessage(response.message);
-    return new Response(response.message.content) as R;
-  }
-
-  protected async *streamChat(
-    message: MessageContent,
-    chatHistory: ChatHistory,
-  ): AsyncGenerator<string, void, unknown> {
+  chat(params: ChatEngineParamsStreaming): Promise<AsyncIterable<Response>>;
+  chat(params: ChatEngineParamsNonStreaming): Promise<Response>;
+  async chat(
+    params: ChatEngineParamsStreaming | ChatEngineParamsNonStreaming,
+  ): Promise<Response | AsyncIterable<Response>> {
+    const { message, stream, history } = params;
+    const chatHistory = history ?? new SimpleChatHistory();
     const requestMessages = await this.prepareRequestMessages(
       message,
      chatHistory,
    );
 
-    const response_stream = await this.llm.chat({
-      messages: requestMessages,
-      stream: true,
-    });
-    var accumulator = "";
-    for await (const part of response_stream) {
-      accumulator += part.delta;
-      yield part.delta;
+    if (stream) {
+      const stream = await this.llm.chat({
+        messages: requestMessages,
+        stream: true,
+      });
+      return streamConverter(
+        streamReducer({
+          stream,
+          initialValue: "",
+          reducer: (accumulator, part) => (accumulator += part.delta),
+          finished: (accumulator) => {
+            chatHistory.addMessage({ content: accumulator, role: "assistant" });
+          },
+        }),
+        (r: ChatResponseChunk) => new Response(r.delta),
+      );
     }
-    chatHistory.addMessage({
-      content: accumulator,
-      role: "assistant",
-    });
-    return;
+    const response = await this.llm.chat({ messages: requestMessages });
+    chatHistory.addMessage(response.message);
+    return new Response(response.message.content);
   }
 
   private async prepareRequestMessages(
diff --git a/packages/core/src/QueryEngine.ts b/packages/core/src/QueryEngine.ts
index fb059459f..0c709504b 100644
--- a/packages/core/src/QueryEngine.ts
+++ b/packages/core/src/QueryEngine.ts
@@ -17,16 +17,32 @@ import {
   ResponseSynthesizer,
 } from "./synthesizers";
 
+/**
+ * Parameters for sending a query.
+ */
+export interface QueryEngineParamsBase {
+  query: string;
+  parentEvent?: Event;
+}
+
+export interface QueryEngineParamsStreaming extends QueryEngineParamsBase {
+  stream: true;
+}
+
+export interface QueryEngineParamsNonStreaming extends QueryEngineParamsBase {
+  stream?: false | null;
+}
+
 /**
  * A query engine is a question answerer that can use one or more steps.
 */
 export interface BaseQueryEngine {
   /**
    * Query the query engine and get a response.
-   * @param query
-   * @param parentEvent
+   * @param params
    */
-  query(query: string, parentEvent?: Event): Promise<Response>;
+  query(params: QueryEngineParamsStreaming): Promise<AsyncIterable<Response>>;
+  query(params: QueryEngineParamsNonStreaming): Promise<Response>;
 }
 
 /**
@@ -70,17 +86,30 @@ export class RetrieverQueryEngine implements BaseQueryEngine {
     return this.applyNodePostprocessors(nodes);
   }
 
-  async query(query: string, parentEvent?: Event) {
-    const _parentEvent: Event = parentEvent || {
+  query(params: QueryEngineParamsStreaming): Promise<AsyncIterable<Response>>;
+  query(params: QueryEngineParamsNonStreaming): Promise<Response>;
+  async query(
+    params: QueryEngineParamsStreaming | QueryEngineParamsNonStreaming,
+  ): Promise<Response | AsyncIterable<Response>> {
+    const { query, stream } = params;
+    const parentEvent: Event = params.parentEvent || {
       id: randomUUID(),
       type: "wrapper",
       tags: ["final"],
     };
-    const nodesWithScore = await this.retrieve(query, _parentEvent);
+    const nodesWithScore = await this.retrieve(query, parentEvent);
+    if (stream) {
+      return this.responseSynthesizer.synthesize({
+        query,
+        nodesWithScore,
+        parentEvent,
+        stream: true,
+      });
+    }
     return this.responseSynthesizer.synthesize({
       query,
       nodesWithScore,
-      parentEvent: _parentEvent,
+      parentEvent,
     });
   }
 }
@@ -135,11 +164,16 @@ export class SubQuestionQueryEngine implements BaseQueryEngine {
     });
   }
 
-  async query(query: string): Promise<Response> {
+  query(params: QueryEngineParamsStreaming): Promise<AsyncIterable<Response>>;
+  query(params: QueryEngineParamsNonStreaming): Promise<Response>;
+  async query(
+    params: QueryEngineParamsStreaming | QueryEngineParamsNonStreaming,
+  ): Promise<Response | AsyncIterable<Response>> {
+    const { query, stream } = params;
     const subQuestions = await this.questionGen.generate(this.metadatas, query);
 
     // groups final retrieval+synthesis operation
-    const parentEvent: Event = {
+    const parentEvent: Event = params.parentEvent || {
       id: randomUUID(),
       type: "wrapper",
       tags: ["final"],
@@ -160,6 +194,14 @@ export class SubQuestionQueryEngine implements BaseQueryEngine {
     const nodesWithScore = subQNodes
       .filter((node) => node !== null)
       .map((node) => node as NodeWithScore);
+    if (stream) {
+      return this.responseSynthesizer.synthesize({
+        query,
+        nodesWithScore,
+        parentEvent,
+        stream: true,
+      });
+    }
     return this.responseSynthesizer.synthesize({
       query,
       nodesWithScore,
@@ -175,7 +217,10 @@ export class SubQuestionQueryEngine implements BaseQueryEngine {
       const question = subQ.subQuestion;
       const queryEngine = this.queryEngines[subQ.toolName];
 
-      const response = await queryEngine.query(question, parentEvent);
+      const response = await queryEngine.query({
+        query: question,
+        parentEvent,
+      });
       const responseText = response.response;
       const nodeText = `Sub question: ${question}\nResponse: ${responseText}`;
       const node = new TextNode({ text: nodeText });
diff --git a/packages/core/src/llm/LLM.ts b/packages/core/src/llm/LLM.ts
index a57435cce..e6ee82e4d 100644
--- a/packages/core/src/llm/LLM.ts
+++ b/packages/core/src/llm/LLM.ts
@@ -27,6 +27,7 @@ import {
 import { OpenAISession, getOpenAISession } from "./openai";
 import { PortkeySession, getPortkeySession } from "./portkey";
 import { ReplicateSession } from "./replicate";
+import { streamConverter } from "./utils";
 
 export type MessageType =
   | "user"
@@ -127,14 +128,6 @@ export interface LLM {
 export abstract class BaseLLM implements LLM {
   abstract metadata: LLMMetadata;
 
-  private async *chatToComplete(
-    stream: AsyncIterable<ChatResponseChunk>,
-  ): AsyncIterable<CompletionResponse> {
-    for await (const chunk of stream) {
-      yield { text: chunk.delta };
-    }
-  }
-
   complete(
     params: LLMCompletionParamsStreaming,
   ): Promise<AsyncIterable<CompletionResponse>>;
@@ -151,7 +144,11 @@ export abstract class BaseLLM implements LLM {
         parentEvent,
         stream: true,
       });
-      return this.chatToComplete(stream);
+      return streamConverter(stream, (chunk) => {
+        return {
+          text: chunk.delta,
+        };
+      });
     }
     const chatResponse = await this.chat({
       messages: [{ content: prompt, role: "user" }],
@@ -392,6 +389,7 @@ export class OpenAI extends BaseLLM {
       type: "llmPredict" as EventType,
     };
 
+    // TODO: add callback to streamConverter and use streamConverter here
     //Indices
     var idx_counter: number = 0;
     for await (const part of chunk_stream) {
diff --git a/packages/core/src/llm/utils.ts b/packages/core/src/llm/utils.ts
index 3a693a9b7..55fc5f78d 100644
--- a/packages/core/src/llm/utils.ts
+++ b/packages/core/src/llm/utils.ts
@@ -1,5 +1,3 @@
-// TODO: use for LLM.ts
-
 export async function* streamConverter<S, D>(
   stream: AsyncIterable<S>,
   converter: (s: S) => D,
@@ -8,3 +6,19 @@ export async function* streamConverter<S, D>(
     yield converter(data);
   }
 }
+
+export async function* streamReducer<S, D>(params: {
+  stream: AsyncIterable<S>;
+  reducer: (previousValue: D, currentValue: S) => D;
+  initialValue: D;
+  finished?: (value: D | undefined) => void;
+}): AsyncIterable<S> {
+  let value = params.initialValue;
+  for await (const data of params.stream) {
+    value = params.reducer(value, data);
+    yield data;
+  }
+  if (params.finished) {
+    params.finished(value);
+  }
+}
diff --git a/packages/core/src/tests/CallbackManager.test.ts b/packages/core/src/tests/CallbackManager.test.ts
index 38a95028b..da56efd1c 100644
--- a/packages/core/src/tests/CallbackManager.test.ts
+++ b/packages/core/src/tests/CallbackManager.test.ts
@@ -67,7 +67,7 @@ describe("CallbackManager: onLLMStream and onRetrieve", () => {
     });
     const queryEngine = vectorStoreIndex.asQueryEngine();
     const query = "What is the author's name?";
-    const response = await queryEngine.query(query);
+    const response = await queryEngine.query({ query });
     expect(response.toString()).toBe("MOCK_TOKEN_1-MOCK_TOKEN_2");
     expect(streamCallbackData).toEqual([
       {
@@ -145,7 +145,7 @@ describe("CallbackManager: onLLMStream and onRetrieve", () => {
       responseSynthesizer,
     });
     const query = "What is the author's name?";
-    const response = await queryEngine.query(query);
+    const response = await queryEngine.query({ query });
    expect(response.toString()).toBe("MOCK_TOKEN_1-MOCK_TOKEN_2");
     expect(streamCallbackData).toEqual([
       {
-- 
GitLab
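
As a quick illustration of the API this patch introduces (not part of the diff itself), here is a minimal before/after sketch for migrating a call site; it assumes an existing `index`, e.g. one built with `VectorStoreIndex.fromDocuments` as in the examples above:

```typescript
// Before this patch (positional arguments, now removed):
//   const response = await queryEngine.query("What did the author do in college?");

// After this patch: both engines take a single parameter object.
const queryEngine = index.asQueryEngine();

// Non-streaming call: resolves to a single Response.
const response = await queryEngine.query({
  query: "What did the author do in college?",
});
console.log(response.toString());

// Streaming call: `stream: true` switches the return type to AsyncIterable<Response>.
const stream = await queryEngine.query({
  query: "What did the author do in college?",
  stream: true,
});
for await (const chunk of stream) {
  process.stdout.write(chunk.response);
}
```

The same pattern applies to `ChatEngine.chat`, with `message` in place of `query`.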