diff --git a/.changeset/spicy-rice-tan.md b/.changeset/spicy-rice-tan.md new file mode 100644 index 0000000000000000000000000000000000000000..0a44da675891dea4c856fceab31fb625f0c10a86 --- /dev/null +++ b/.changeset/spicy-rice-tan.md @@ -0,0 +1,5 @@ +--- +"llamaindex": patch +--- + +Add streaming support for QueryEngine (and unify streaming interface with ChatEngine) diff --git a/.changeset/tidy-mayflies-provide.md b/.changeset/tidy-mayflies-provide.md new file mode 100644 index 0000000000000000000000000000000000000000..f2a2d320261c133fae7b02c14b20e6c0b3163716 --- /dev/null +++ b/.changeset/tidy-mayflies-provide.md @@ -0,0 +1,5 @@ +--- +"llamaindex": patch +--- + +Breaking: Use parameter object for query and chat methods of ChatEngine and QueryEngine diff --git a/apps/docs/docs/modules/high_level/chat_engine.md b/apps/docs/docs/modules/high_level/chat_engine.md index 984e4525172b5c5110e98c533ed071f9d786f5ab..41ce90720b092ca1ee21c744a3fccf359037405c 100644 --- a/apps/docs/docs/modules/high_level/chat_engine.md +++ b/apps/docs/docs/modules/high_level/chat_engine.md @@ -11,7 +11,16 @@ const retriever = index.asRetriever(); const chatEngine = new ContextChatEngine({ retriever }); // start chatting -const response = await chatEngine.chat(query); +const response = await chatEngine.chat({ message: query }); +``` + +The `chat` function also supports streaming, just add `stream: true` as an option: + +```typescript +const stream = await chatEngine.chat({ message: query, stream: true }); +for await (const chunk of stream) { + process.stdout.write(chunk.response); +} ``` ## Api References diff --git a/apps/docs/docs/modules/high_level/query_engine.md b/apps/docs/docs/modules/high_level/query_engine.md index 3a02d0e4af3371c3485ed3ab41cc99249fbf21fb..853641a17ea6a0c61061c31be5f6dd2144e67bf5 100644 --- a/apps/docs/docs/modules/high_level/query_engine.md +++ b/apps/docs/docs/modules/high_level/query_engine.md @@ -8,7 +8,16 @@ A query engine wraps a `Retriever` and a `ResponseSynthesizer` into a pipeline, ```typescript const queryEngine = index.asQueryEngine(); -const response = await queryEngine.query("query string"); +const response = await queryEngine.query({ query: "query string" }); +``` + +The `query` function also supports streaming, just add `stream: true` as an option: + +```typescript +const stream = await queryEngine.query({ query: "query string", stream: true }); +for await (const chunk of stream) { + process.stdout.write(chunk.response); +} ``` ## Sub Question Query Engine diff --git a/examples/astradb/query.ts b/examples/astradb/query.ts index bbc0712d537eaa7bbaeba82967fb89f5ab0ff4b3..949cfd1188b7d10fab7276639a8b1f8a3574ad90 100644 --- a/examples/astradb/query.ts +++ b/examples/astradb/query.ts @@ -18,7 +18,9 @@ async function main() { const queryEngine = await index.asQueryEngine({ retriever }); - const results = await queryEngine.query("What is the best reviewed movie?"); + const results = await queryEngine.query({ + query: "What is the best reviewed movie?", + }); console.log(results.response); } catch (e) { diff --git a/examples/chatEngine.ts b/examples/chatEngine.ts index 728c6287d6e30f11794199b2cb55bf09615b46b9..52538944b5f95c9f6de426841590d5e0599a9c4a 100644 --- a/examples/chatEngine.ts +++ b/examples/chatEngine.ts @@ -25,8 +25,11 @@ async function main() { while (true) { const query = await rl.question("Query: "); - const response = await chatEngine.chat(query); - console.log(response.toString()); + const stream = await chatEngine.chat({ message: query, stream: true }); + console.log(); + for await (const chunk of stream) { + process.stdout.write(chunk.response); + } } } diff --git a/examples/chromadb/test.ts b/examples/chromadb/test.ts index 51d2b2692f9025f12d37b027a4a915a82ae29787..c901e82b36bb127183f0e98476b4c93e781e96ce 100644 --- a/examples/chromadb/test.ts +++ b/examples/chromadb/test.ts @@ -28,9 +28,9 @@ async function main() { console.log("Querying index"); const queryEngine = index.asQueryEngine(); - const response = await queryEngine.query( - "Tell me about Godfrey Cheshire's rating of La Sapienza.", - ); + const response = await queryEngine.query({ + query: "Tell me about Godfrey Cheshire's rating of La Sapienza.", + }); console.log(response.toString()); } catch (e) { console.error(e); diff --git a/examples/huggingface.ts b/examples/huggingface.ts index 1d02b43ab98f2486dffc5fe94b14fcc8128aa656..c1e54e05f53d3d1f972c0f85b78d029d92f7170c 100644 --- a/examples/huggingface.ts +++ b/examples/huggingface.ts @@ -32,12 +32,15 @@ async function main() { // Query the index const queryEngine = index.asQueryEngine(); - const response = await queryEngine.query( - "What did the author do in college?", - ); + const stream = await queryEngine.query({ + query: "What did the author do in college?", + stream: true, + }); // Output response - console.log(response.toString()); + for await (const chunk of stream) { + process.stdout.write(chunk.response); + } } main().catch(console.error); diff --git a/examples/keywordIndex.ts b/examples/keywordIndex.ts index 4ee2cf4fae0a84d3d9bb530f4fb36120bb33dc40..ad6c7b1d3642083fd67be9b6790c68d164a85f8f 100644 --- a/examples/keywordIndex.ts +++ b/examples/keywordIndex.ts @@ -20,9 +20,9 @@ async function main() { mode, }), }); - const response = await queryEngine.query( - "What did the author do growing up?", - ); + const response = await queryEngine.query({ + query: "What did the author do growing up?", + }); console.log(response.toString()); }); } diff --git a/examples/markdown.ts b/examples/markdown.ts index 01faf6b27aea4664e3de379a89dbf7d4e8f46d57..b2767713e6b8eb1989f27a9799dd224b17b0b468 100644 --- a/examples/markdown.ts +++ b/examples/markdown.ts @@ -11,7 +11,9 @@ async function main() { // Query the index const queryEngine = index.asQueryEngine(); - const response = await queryEngine.query("What does the example code do?"); + const response = await queryEngine.query({ + query: "What does the example code do?", + }); // Output response console.log(response.toString()); diff --git a/examples/mistral.ts b/examples/mistral.ts index 2047f34636b49d9b77ac04f1f079057d645fde47..75b8b3c1a052e86a3828a851629fab825b6b95d6 100644 --- a/examples/mistral.ts +++ b/examples/mistral.ts @@ -27,7 +27,7 @@ async function rag(llm: LLM, embedModel: BaseEmbedding, query: string) { // Query the index const queryEngine = index.asQueryEngine(); - const response = await queryEngine.query(query); + const response = await queryEngine.query({ query }); return response.response; } diff --git a/examples/mongo.ts b/examples/mongo.ts index 5eb07a4db4d2ce5f34cc7c7f3615a55d7fbd8e9d..6cabbfe35abbf2c74ca69ba66e86f4aeb53cbd10 100644 --- a/examples/mongo.ts +++ b/examples/mongo.ts @@ -54,7 +54,7 @@ async function main() { break; } - const response = await queryEngine.query(query); + const response = await queryEngine.query({ query }); // Output response console.log(response.toString()); diff --git a/examples/mongodb/3_query.ts b/examples/mongodb/3_query.ts index 32a35293942a99230427c2ea19026db6984244dd..84920dd9d238027055083bbed862405ac688a80d 100644 --- a/examples/mongodb/3_query.ts +++ b/examples/mongodb/3_query.ts @@ -24,9 +24,9 @@ async function query() { const retriever = index.asRetriever({ similarityTopK: 20 }); const queryEngine = index.asQueryEngine({ retriever }); - const result = await queryEngine.query( - "What does the author think of web frameworks?", - ); + const result = await queryEngine.query({ + query: "What does the author think of web frameworks?", + }); console.log(result.response); await client.close(); } diff --git a/examples/multimodal/rag.ts b/examples/multimodal/rag.ts index f0f775aaaa5e1368e3ff91a25e30c752495b60f9..fec645265159af6807eaccf735ccee9d7f36288c 100644 --- a/examples/multimodal/rag.ts +++ b/examples/multimodal/rag.ts @@ -46,9 +46,9 @@ async function main() { responseSynthesizer: new MultiModalResponseSynthesizer({ serviceContext }), retriever: index.asRetriever({ similarityTopK: 3, imageSimilarityTopK: 1 }), }); - const result = await queryEngine.query( - "Tell me more about Vincent van Gogh's famous paintings", - ); + const result = await queryEngine.query({ + query: "Tell me more about Vincent van Gogh's famous paintings", + }); console.log(result.response, "\n"); images.forEach((image) => console.log(`Image retrieved and used in inference: ${image.toString()}`), diff --git a/examples/pg-vector-store/query.ts b/examples/pg-vector-store/query.ts index 77545bab860b09ce1ca97830d4ecf5c89ab1dd00..c46f46c9e0ad31d7fb0d9f3b35f5a196395b8467 100755 --- a/examples/pg-vector-store/query.ts +++ b/examples/pg-vector-store/query.ts @@ -31,7 +31,7 @@ async function main() { } try { - const answer = await queryEngine.query(question); + const answer = await queryEngine.query({ query: question }); console.log(answer.response); } catch (error) { console.error("Error:", error); diff --git a/examples/pinecone-vector-store/query.ts b/examples/pinecone-vector-store/query.ts index 41ca987b05d9a2bafc053c57c014b87438fd28d1..f0ee4b3c57a861299a84e2541b36ebd25e994330 100755 --- a/examples/pinecone-vector-store/query.ts +++ b/examples/pinecone-vector-store/query.ts @@ -29,7 +29,7 @@ async function main() { } try { - const answer = await queryEngine.query(question); + const answer = await queryEngine.query({ query: question }); console.log(answer.response); } catch (error) { console.error("Error:", error); diff --git a/examples/readers/load-assemblyai.ts b/examples/readers/load-assemblyai.ts index 986788196ab50b946827fc61bb40d03dbf80607b..42c944059baf4bf67fae8b413d72bf578fcb5c9f 100644 --- a/examples/readers/load-assemblyai.ts +++ b/examples/readers/load-assemblyai.ts @@ -48,7 +48,7 @@ program break; } - const response = await queryEngine.query(query); + const response = await queryEngine.query({ query }); console.log(response.toString()); } diff --git a/examples/readers/load-csv.ts b/examples/readers/load-csv.ts index 1e0a11237297b086fbc0c2d20b38ef6bea68c818..d1694519749ee4d7ec5e648b1b46930a7c601ad7 100644 --- a/examples/readers/load-csv.ts +++ b/examples/readers/load-csv.ts @@ -38,9 +38,9 @@ Given the CSV file, generate me Typescript code to answer the question: ${query} const queryEngine = index.asQueryEngine({ responseSynthesizer }); // Query the index - const response = await queryEngine.query( - "What is the correlation between survival and age?", - ); + const response = await queryEngine.query({ + query: "What is the correlation between survival and age?", + }); // Output response console.log(response.toString()); diff --git a/examples/readers/load-docx.ts b/examples/readers/load-docx.ts index 61a8b314a6038e35bb7551493f7bb828c78d5a66..459dad3efda69f7f1936440802d10cb0fefa4af0 100644 --- a/examples/readers/load-docx.ts +++ b/examples/readers/load-docx.ts @@ -15,7 +15,7 @@ async function main() { // Test query const queryEngine = index.asQueryEngine(); - const response = await queryEngine.query(SAMPLE_QUERY); + const response = await queryEngine.query({ query: SAMPLE_QUERY }); console.log(`Test query > ${SAMPLE_QUERY}:\n`, response.toString()); } diff --git a/examples/readers/load-html.ts b/examples/readers/load-html.ts index 76672986180590d13a4b42c9912428ffe7598bb1..87ea89ec6df5120fcda6acd42bee1e2c09ba69db 100644 --- a/examples/readers/load-html.ts +++ b/examples/readers/load-html.ts @@ -10,9 +10,9 @@ async function main() { // Query the index const queryEngine = index.asQueryEngine(); - const response = await queryEngine.query( - "What were the notable changes in 18.1?", - ); + const response = await queryEngine.query({ + query: "What were the notable changes in 18.1?", + }); // Output response console.log(response.toString()); diff --git a/examples/readers/load-md.ts b/examples/readers/load-md.ts index bebc7a2ec14c04e2109712bc6f3a46223c952ccb..5e6e300afba782daaaab12e83925db8b5e4e79af 100644 --- a/examples/readers/load-md.ts +++ b/examples/readers/load-md.ts @@ -15,7 +15,7 @@ async function main() { // Test query const queryEngine = index.asQueryEngine(); - const response = await queryEngine.query(SAMPLE_QUERY); + const response = await queryEngine.query({ query: SAMPLE_QUERY }); console.log(`Test query > ${SAMPLE_QUERY}:\n`, response.toString()); } diff --git a/examples/readers/load-notion.ts b/examples/readers/load-notion.ts index 7bbe065f46be2bc9703672d2ce95ce607080d86a..04f1651ef707cfd96f7c9c18c285e199b1914d95 100644 --- a/examples/readers/load-notion.ts +++ b/examples/readers/load-notion.ts @@ -79,7 +79,7 @@ program break; } - const response = await queryEngine.query(query); + const response = await queryEngine.query({ query }); // Output response console.log(response.toString()); diff --git a/examples/readers/load-pdf.ts b/examples/readers/load-pdf.ts index afbe8f44801846ed8a517fda5cfbf3d02ad3f10f..0255d908f27a98bc79f1c966ef9a17c58d5ba0d2 100644 --- a/examples/readers/load-pdf.ts +++ b/examples/readers/load-pdf.ts @@ -10,7 +10,9 @@ async function main() { // Query the index const queryEngine = index.asQueryEngine(); - const response = await queryEngine.query("What mistakes did they make?"); + const response = await queryEngine.query({ + query: "What mistakes did they make?", + }); // Output response console.log(response.toString()); diff --git a/examples/sentenceWindow.ts b/examples/sentenceWindow.ts index d4c9cf182834ee1641f9f3c3d8488a670d50206d..fcb89d99d7b946cf0f08da381f3b6bff76b80d72 100644 --- a/examples/sentenceWindow.ts +++ b/examples/sentenceWindow.ts @@ -31,9 +31,9 @@ async function main() { const queryEngine = index.asQueryEngine({ nodePostprocessors: [new MetadataReplacementPostProcessor("window")], }); - const response = await queryEngine.query( - "What did the author do in college?", - ); + const response = await queryEngine.query({ + query: "What did the author do in college?", + }); // Output response console.log(response.toString()); diff --git a/examples/storageContext.ts b/examples/storageContext.ts index 7326bbdbaf1c93082e9f0b5b343d4fc5b17d1f18..74fbd43fba868bb46ba11edf8347e1e6a6d97548 100644 --- a/examples/storageContext.ts +++ b/examples/storageContext.ts @@ -20,9 +20,9 @@ async function main() { // Query the index const queryEngine = index.asQueryEngine(); - const response = await queryEngine.query( - "What did the author do in college?", - ); + const response = await queryEngine.query({ + query: "What did the author do in college?", + }); // Output response console.log(response.toString()); @@ -35,9 +35,9 @@ async function main() { storageContext: secondStorageContext, }); const loadedQueryEngine = loadedIndex.asQueryEngine(); - const loadedResponse = await loadedQueryEngine.query( - "What did the author do growing up?", - ); + const loadedResponse = await loadedQueryEngine.query({ + query: "What did the author do growing up?", + }); console.log(loadedResponse.toString()); } diff --git a/examples/subquestion.ts b/examples/subquestion.ts index 6d46d91b5dfdba187606788aa0b174331a3112bc..b1e692e1f305668a21f8ac12076ea247cc7046e8 100644 --- a/examples/subquestion.ts +++ b/examples/subquestion.ts @@ -18,9 +18,9 @@ import essay from "./essay"; ], }); - const response = await queryEngine.query( - "How was Paul Grahams life different before and after YC?", - ); + const response = await queryEngine.query({ + query: "How was Paul Grahams life different before and after YC?", + }); console.log(response.toString()); })(); diff --git a/examples/summaryIndex.ts b/examples/summaryIndex.ts index cc34e95d392c167d441ccfa4ad115cf0d39bf7b8..d11a47031af81146ba8a3b45fc8250d82f4b4614 100644 --- a/examples/summaryIndex.ts +++ b/examples/summaryIndex.ts @@ -20,9 +20,9 @@ async function main() { const queryEngine = index.asQueryEngine({ retriever: index.asRetriever({ mode: SummaryRetrieverMode.LLM }), }); - const response = await queryEngine.query( - "What did the author do growing up?", - ); + const response = await queryEngine.query({ + query: "What did the author do growing up?", + }); console.log(response.toString()); } diff --git a/examples/together-ai/vector-index.ts b/examples/together-ai/vector-index.ts index fc460ca81d42c45087b306a75897d1685a865ab2..94b5c5762f01237a052defb54bff12f8d05b653d 100644 --- a/examples/together-ai/vector-index.ts +++ b/examples/together-ai/vector-index.ts @@ -29,9 +29,9 @@ async function main() { const queryEngine = index.asQueryEngine(); - const response = await queryEngine.query( - "What did the author do in college?", - ); + const response = await queryEngine.query({ + query: "What did the author do in college?", + }); console.log(response.toString()); } diff --git a/examples/vectorIndex.ts b/examples/vectorIndex.ts index ff2691624842c2c61d2dd48e7dd449cea4e4a108..afa95e540798c7ca2e52bb19f73d59b2091f75b5 100644 --- a/examples/vectorIndex.ts +++ b/examples/vectorIndex.ts @@ -16,9 +16,9 @@ async function main() { // Query the index const queryEngine = index.asQueryEngine(); - const response = await queryEngine.query( - "What did the author do in college?", - ); + const response = await queryEngine.query({ + query: "What did the author do in college?", + }); // Output response console.log(response.toString()); diff --git a/examples/vectorIndexAnthropic.ts b/examples/vectorIndexAnthropic.ts index c979900b285d6807c2dd6f8f030810973be70c70..c04c516c7dec4d620c4bbf71d70e23b141dabd41 100644 --- a/examples/vectorIndexAnthropic.ts +++ b/examples/vectorIndexAnthropic.ts @@ -35,9 +35,9 @@ async function main() { // Query the index const queryEngine = index.asQueryEngine({ responseSynthesizer }); - const response = await queryEngine.query( - "What did the author do in college?", - ); + const response = await queryEngine.query({ + query: "What did the author do in college?", + }); // Output response console.log(response.toString()); diff --git a/examples/vectorIndexCustomize.ts b/examples/vectorIndexCustomize.ts index b24e91416b0de23f97feb2659af0c9b55fbe57e7..e9013a0e6510b0cea292b53bd959bbcc787132b3 100644 --- a/examples/vectorIndexCustomize.ts +++ b/examples/vectorIndexCustomize.ts @@ -33,9 +33,9 @@ async function main() { [nodePostprocessor], ); - const response = await queryEngine.query( - "What did the author do growing up?", - ); + const response = await queryEngine.query({ + query: "What did the author do growing up?", + }); console.log(response.response); } diff --git a/examples/vectorIndexFromVectorStore.ts b/examples/vectorIndexFromVectorStore.ts index 311bc8c72225e73b1b2cc6f116b1899de22ed796..d539ae2ad122fc7c2ca09b728de67cea6c315c60 100644 --- a/examples/vectorIndexFromVectorStore.ts +++ b/examples/vectorIndexFromVectorStore.ts @@ -189,7 +189,9 @@ async function main() { }, }); - const response = await queryEngine.query("How many results do you have?"); + const response = await queryEngine.query({ + query: "How many results do you have?", + }); console.log(response.toString()); } diff --git a/examples/vectorIndexGPT4.ts b/examples/vectorIndexGPT4.ts index 6a7516dd5940a564cf98b70c2e9302571b46c18b..ed1ed20fcf9a651e58dc2b8d1870e731887dd82b 100644 --- a/examples/vectorIndexGPT4.ts +++ b/examples/vectorIndexGPT4.ts @@ -25,9 +25,9 @@ async function main() { // Query the index const queryEngine = index.asQueryEngine(); - const response = await queryEngine.query( - "What did the author do in college?", - ); + const response = await queryEngine.query({ + query: "What did the author do in college?", + }); // Output response console.log(response.toString()); diff --git a/packages/core/src/ChatEngine.ts b/packages/core/src/ChatEngine.ts index c186090806615ca28609bab81d5eb38a755709f0..2f18a15a49511a419c9f7d0260eeb49bf5a494af 100644 --- a/packages/core/src/ChatEngine.ts +++ b/packages/core/src/ChatEngine.ts @@ -1,5 +1,5 @@ import { randomUUID } from "node:crypto"; -import { ChatHistory } from "./ChatHistory"; +import { ChatHistory, SimpleChatHistory } from "./ChatHistory"; import { NodeWithScore, TextNode } from "./Node"; import { CondenseQuestionPrompt, @@ -13,27 +13,40 @@ import { Response } from "./Response"; import { BaseRetriever } from "./Retriever"; import { ServiceContext, serviceContextFromDefaults } from "./ServiceContext"; import { Event } from "./callbacks/CallbackManager"; -import { ChatMessage, LLM, OpenAI } from "./llm"; +import { ChatMessage, ChatResponseChunk, LLM, OpenAI } from "./llm"; +import { streamConverter, streamReducer } from "./llm/utils"; import { BaseNodePostprocessor } from "./postprocessors"; +/** + * Represents the base parameters for ChatEngine. + */ +export interface ChatEngineParamsBase { + message: MessageContent; + /** + * Optional chat history if you want to customize the chat history. + */ + chatHistory?: ChatMessage[]; + history?: ChatHistory; +} + +export interface ChatEngineParamsStreaming extends ChatEngineParamsBase { + stream: true; +} + +export interface ChatEngineParamsNonStreaming extends ChatEngineParamsBase { + stream?: false | null; +} + /** * A ChatEngine is used to handle back and forth chats between the application and the LLM. */ export interface ChatEngine { /** * Send message along with the class's current chat history to the LLM. - * @param message - * @param chatHistory optional chat history if you want to customize the chat history - * @param streaming optional streaming flag, which auto-sets the return value if True. + * @param params */ - chat< - T extends boolean | undefined = undefined, - R = T extends true ? AsyncGenerator<string, void, unknown> : Response, - >( - message: MessageContent, - chatHistory?: ChatMessage[], - streaming?: T, - ): Promise<R>; + chat(params: ChatEngineParamsStreaming): Promise<AsyncIterable<Response>>; + chat(params: ChatEngineParamsNonStreaming): Promise<Response>; /** * Resets the chat history so that it's empty. @@ -53,48 +66,37 @@ export class SimpleChatEngine implements ChatEngine { this.llm = init?.llm ?? new OpenAI(); } - async chat< - T extends boolean | undefined = undefined, - R = T extends true ? AsyncGenerator<string, void, unknown> : Response, - >( - message: MessageContent, - chatHistory?: ChatMessage[], - streaming?: T, - ): Promise<R> { - //Streaming option - if (streaming) { - return this.streamChat(message, chatHistory) as R; - } - - //Non-streaming option - chatHistory = chatHistory ?? this.chatHistory; - chatHistory.push({ content: message, role: "user" }); - const response = await this.llm.chat({ messages: chatHistory }); - chatHistory.push(response.message); - this.chatHistory = chatHistory; - return new Response(response.message.content) as R; - } + chat(params: ChatEngineParamsStreaming): Promise<AsyncIterable<Response>>; + chat(params: ChatEngineParamsNonStreaming): Promise<Response>; + async chat( + params: ChatEngineParamsStreaming | ChatEngineParamsNonStreaming, + ): Promise<Response | AsyncIterable<Response>> { + const { message, stream } = params; - protected async *streamChat( - message: MessageContent, - chatHistory?: ChatMessage[], - ): AsyncGenerator<string, void, unknown> { - chatHistory = chatHistory ?? this.chatHistory; + const chatHistory = params.chatHistory ?? this.chatHistory; chatHistory.push({ content: message, role: "user" }); - const response_generator = await this.llm.chat({ - messages: chatHistory, - stream: true, - }); - var accumulator: string = ""; - for await (const part of response_generator) { - accumulator += part.delta; - yield part.delta; + if (stream) { + const stream = await this.llm.chat({ + messages: chatHistory, + stream: true, + }); + return streamConverter( + streamReducer({ + stream, + initialValue: "", + reducer: (accumulator, part) => (accumulator += part.delta), + finished: (accumulator) => { + chatHistory.push({ content: accumulator, role: "assistant" }); + }, + }), + (r: ChatResponseChunk) => new Response(r.delta), + ); } - chatHistory.push({ content: accumulator, role: "assistant" }); - this.chatHistory = chatHistory; - return; + const response = await this.llm.chat({ messages: chatHistory }); + chatHistory.push(response.message); + return new Response(response.message.content); } reset() { @@ -115,7 +117,7 @@ export class SimpleChatEngine implements ChatEngine { export class CondenseQuestionChatEngine implements ChatEngine { queryEngine: BaseQueryEngine; chatHistory: ChatMessage[]; - serviceContext: ServiceContext; + llm: LLM; condenseMessagePrompt: CondenseQuestionPrompt; constructor(init: { @@ -126,8 +128,7 @@ export class CondenseQuestionChatEngine implements ChatEngine { }) { this.queryEngine = init.queryEngine; this.chatHistory = init?.chatHistory ?? []; - this.serviceContext = - init?.serviceContext ?? serviceContextFromDefaults({}); + this.llm = init?.serviceContext?.llm ?? serviceContextFromDefaults().llm; this.condenseMessagePrompt = init?.condenseMessagePrompt ?? defaultCondenseQuestionPrompt; } @@ -135,7 +136,7 @@ export class CondenseQuestionChatEngine implements ChatEngine { private async condenseQuestion(chatHistory: ChatMessage[], question: string) { const chatHistoryStr = messagesToHistoryStr(chatHistory); - return this.serviceContext.llm.complete({ + return this.llm.complete({ prompt: defaultCondenseQuestionPrompt({ question: question, chatHistory: chatHistoryStr, @@ -143,26 +144,39 @@ export class CondenseQuestionChatEngine implements ChatEngine { }); } - async chat< - T extends boolean | undefined = undefined, - R = T extends true ? AsyncGenerator<string, void, unknown> : Response, - >( - message: MessageContent, - chatHistory?: ChatMessage[] | undefined, - streaming?: T, - ): Promise<R> { - chatHistory = chatHistory ?? this.chatHistory; + chat(params: ChatEngineParamsStreaming): Promise<AsyncIterable<Response>>; + chat(params: ChatEngineParamsNonStreaming): Promise<Response>; + async chat( + params: ChatEngineParamsStreaming | ChatEngineParamsNonStreaming, + ): Promise<Response | AsyncIterable<Response>> { + const { message, stream } = params; + const chatHistory = params.chatHistory ?? this.chatHistory; const condensedQuestion = ( await this.condenseQuestion(chatHistory, extractText(message)) ).text; - - const response = await this.queryEngine.query(condensedQuestion); - chatHistory.push({ content: message, role: "user" }); + + if (stream) { + const stream = await this.queryEngine.query({ + query: condensedQuestion, + stream: true, + }); + return streamReducer({ + stream, + initialValue: "", + reducer: (accumulator, part) => (accumulator += part.response), + finished: (accumulator) => { + chatHistory.push({ content: accumulator, role: "assistant" }); + }, + }); + } + const response = await this.queryEngine.query({ + query: condensedQuestion, + }); chatHistory.push({ content: response.response, role: "assistant" }); - return response as R; + return response; } reset() { @@ -255,21 +269,13 @@ export class ContextChatEngine implements ChatEngine { }); } - async chat< - T extends boolean | undefined = undefined, - R = T extends true ? AsyncGenerator<string, void, unknown> : Response, - >( - message: MessageContent, - chatHistory?: ChatMessage[] | undefined, - streaming?: T, - ): Promise<R> { - chatHistory = chatHistory ?? this.chatHistory; - - //Streaming option - if (streaming) { - return this.streamChat(message, chatHistory) as R; - } - + chat(params: ChatEngineParamsStreaming): Promise<AsyncIterable<Response>>; + chat(params: ChatEngineParamsNonStreaming): Promise<Response>; + async chat( + params: ChatEngineParamsStreaming | ChatEngineParamsNonStreaming, + ): Promise<Response | AsyncIterable<Response>> { + const { message, stream } = params; + const chatHistory = params.chatHistory ?? this.chatHistory; const parentEvent: Event = { id: randomUUID(), type: "wrapper", @@ -279,57 +285,33 @@ export class ContextChatEngine implements ChatEngine { extractText(message), parentEvent, ); - + const nodes = context.nodes.map((r) => r.node); chatHistory.push({ content: message, role: "user" }); + if (stream) { + const stream = await this.chatModel.chat({ + messages: [context.message, ...chatHistory], + parentEvent, + stream: true, + }); + return streamConverter( + streamReducer({ + stream, + initialValue: "", + reducer: (accumulator, part) => (accumulator += part.delta), + finished: (accumulator) => { + chatHistory.push({ content: accumulator, role: "assistant" }); + }, + }), + (r: ChatResponseChunk) => new Response(r.delta, nodes), + ); + } const response = await this.chatModel.chat({ messages: [context.message, ...chatHistory], parentEvent, }); chatHistory.push(response.message); - - this.chatHistory = chatHistory; - - return new Response( - response.message.content, - context.nodes.map((r) => r.node), - ) as R; - } - - protected async *streamChat( - message: MessageContent, - chatHistory?: ChatMessage[] | undefined, - ): AsyncGenerator<string, void, unknown> { - chatHistory = chatHistory ?? this.chatHistory; - - const parentEvent: Event = { - id: randomUUID(), - type: "wrapper", - tags: ["final"], - }; - const context = await this.contextGenerator.generate( - extractText(message), - parentEvent, - ); - - chatHistory.push({ content: message, role: "user" }); - - const response_stream = await this.chatModel.chat({ - messages: [context.message, ...chatHistory], - parentEvent, - stream: true, - }); - var accumulator: string = ""; - for await (const part of response_stream) { - accumulator += part.delta; - yield part.delta; - } - - chatHistory.push({ content: accumulator, role: "assistant" }); - - this.chatHistory = chatHistory; - - return; + return new Response(response.message.content, nodes); } reset() { @@ -382,50 +364,38 @@ export class HistoryChatEngine { this.contextGenerator = init?.contextGenerator; } - async chat< - T extends boolean | undefined = undefined, - R = T extends true ? AsyncGenerator<string, void, unknown> : Response, - >( - message: MessageContent, - chatHistory: ChatHistory, - streaming?: T, - ): Promise<R> { - //Streaming option - if (streaming) { - return this.streamChat(message, chatHistory) as R; - } - const requestMessages = await this.prepareRequestMessages( - message, - chatHistory, - ); - const response = await this.llm.chat({ messages: requestMessages }); - chatHistory.addMessage(response.message); - return new Response(response.message.content) as R; - } - - protected async *streamChat( - message: MessageContent, - chatHistory: ChatHistory, - ): AsyncGenerator<string, void, unknown> { + chat(params: ChatEngineParamsStreaming): Promise<AsyncIterable<Response>>; + chat(params: ChatEngineParamsNonStreaming): Promise<Response>; + async chat( + params: ChatEngineParamsStreaming | ChatEngineParamsNonStreaming, + ): Promise<Response | AsyncIterable<Response>> { + const { message, stream, history } = params; + const chatHistory = history ?? new SimpleChatHistory(); const requestMessages = await this.prepareRequestMessages( message, chatHistory, ); - const response_stream = await this.llm.chat({ - messages: requestMessages, - stream: true, - }); - var accumulator = ""; - for await (const part of response_stream) { - accumulator += part.delta; - yield part.delta; + if (stream) { + const stream = await this.llm.chat({ + messages: requestMessages, + stream: true, + }); + return streamConverter( + streamReducer({ + stream, + initialValue: "", + reducer: (accumulator, part) => (accumulator += part.delta), + finished: (accumulator) => { + chatHistory.addMessage({ content: accumulator, role: "assistant" }); + }, + }), + (r: ChatResponseChunk) => new Response(r.delta), + ); } - chatHistory.addMessage({ - content: accumulator, - role: "assistant", - }); - return; + const response = await this.llm.chat({ messages: requestMessages }); + chatHistory.addMessage(response.message); + return new Response(response.message.content); } private async prepareRequestMessages( diff --git a/packages/core/src/QueryEngine.ts b/packages/core/src/QueryEngine.ts index fb059459ffd1c23a9b73a66af948b4023a6f4133..0c709504baf1d2a6d0e43b219e9c9fde3c9cb15c 100644 --- a/packages/core/src/QueryEngine.ts +++ b/packages/core/src/QueryEngine.ts @@ -17,16 +17,32 @@ import { ResponseSynthesizer, } from "./synthesizers"; +/** + * Parameters for sending a query. + */ +export interface QueryEngineParamsBase { + query: string; + parentEvent?: Event; +} + +export interface QueryEngineParamsStreaming extends QueryEngineParamsBase { + stream: true; +} + +export interface QueryEngineParamsNonStreaming extends QueryEngineParamsBase { + stream?: false | null; +} + /** * A query engine is a question answerer that can use one or more steps. */ export interface BaseQueryEngine { /** * Query the query engine and get a response. - * @param query - * @param parentEvent + * @param params */ - query(query: string, parentEvent?: Event): Promise<Response>; + query(params: QueryEngineParamsStreaming): Promise<AsyncIterable<Response>>; + query(params: QueryEngineParamsNonStreaming): Promise<Response>; } /** @@ -70,17 +86,30 @@ export class RetrieverQueryEngine implements BaseQueryEngine { return this.applyNodePostprocessors(nodes); } - async query(query: string, parentEvent?: Event) { - const _parentEvent: Event = parentEvent || { + query(params: QueryEngineParamsStreaming): Promise<AsyncIterable<Response>>; + query(params: QueryEngineParamsNonStreaming): Promise<Response>; + async query( + params: QueryEngineParamsStreaming | QueryEngineParamsNonStreaming, + ): Promise<Response | AsyncIterable<Response>> { + const { query, stream } = params; + const parentEvent: Event = params.parentEvent || { id: randomUUID(), type: "wrapper", tags: ["final"], }; - const nodesWithScore = await this.retrieve(query, _parentEvent); + const nodesWithScore = await this.retrieve(query, parentEvent); + if (stream) { + return this.responseSynthesizer.synthesize({ + query, + nodesWithScore, + parentEvent, + stream: true, + }); + } return this.responseSynthesizer.synthesize({ query, nodesWithScore, - parentEvent: _parentEvent, + parentEvent, }); } } @@ -135,11 +164,16 @@ export class SubQuestionQueryEngine implements BaseQueryEngine { }); } - async query(query: string): Promise<Response> { + query(params: QueryEngineParamsStreaming): Promise<AsyncIterable<Response>>; + query(params: QueryEngineParamsNonStreaming): Promise<Response>; + async query( + params: QueryEngineParamsStreaming | QueryEngineParamsNonStreaming, + ): Promise<Response | AsyncIterable<Response>> { + const { query, stream } = params; const subQuestions = await this.questionGen.generate(this.metadatas, query); // groups final retrieval+synthesis operation - const parentEvent: Event = { + const parentEvent: Event = params.parentEvent || { id: randomUUID(), type: "wrapper", tags: ["final"], @@ -160,6 +194,14 @@ export class SubQuestionQueryEngine implements BaseQueryEngine { const nodesWithScore = subQNodes .filter((node) => node !== null) .map((node) => node as NodeWithScore); + if (stream) { + return this.responseSynthesizer.synthesize({ + query, + nodesWithScore, + parentEvent, + stream: true, + }); + } return this.responseSynthesizer.synthesize({ query, nodesWithScore, @@ -175,7 +217,10 @@ export class SubQuestionQueryEngine implements BaseQueryEngine { const question = subQ.subQuestion; const queryEngine = this.queryEngines[subQ.toolName]; - const response = await queryEngine.query(question, parentEvent); + const response = await queryEngine.query({ + query: question, + parentEvent, + }); const responseText = response.response; const nodeText = `Sub question: ${question}\nResponse: ${responseText}`; const node = new TextNode({ text: nodeText }); diff --git a/packages/core/src/llm/LLM.ts b/packages/core/src/llm/LLM.ts index a57435cceca6d0456ecfb92317f131ca363d37bd..e6ee82e4daaf09017ee5b5092606743920a7dd40 100644 --- a/packages/core/src/llm/LLM.ts +++ b/packages/core/src/llm/LLM.ts @@ -27,6 +27,7 @@ import { import { OpenAISession, getOpenAISession } from "./openai"; import { PortkeySession, getPortkeySession } from "./portkey"; import { ReplicateSession } from "./replicate"; +import { streamConverter } from "./utils"; export type MessageType = | "user" @@ -127,14 +128,6 @@ export interface LLM { export abstract class BaseLLM implements LLM { abstract metadata: LLMMetadata; - private async *chatToComplete( - stream: AsyncIterable<ChatResponseChunk>, - ): AsyncIterable<CompletionResponse> { - for await (const chunk of stream) { - yield { text: chunk.delta }; - } - } - complete( params: LLMCompletionParamsStreaming, ): Promise<AsyncIterable<CompletionResponse>>; @@ -151,7 +144,11 @@ export abstract class BaseLLM implements LLM { parentEvent, stream: true, }); - return this.chatToComplete(stream); + return streamConverter(stream, (chunk) => { + return { + text: chunk.delta, + }; + }); } const chatResponse = await this.chat({ messages: [{ content: prompt, role: "user" }], @@ -392,6 +389,7 @@ export class OpenAI extends BaseLLM { type: "llmPredict" as EventType, }; + // TODO: add callback to streamConverter and use streamConverter here //Indices var idx_counter: number = 0; for await (const part of chunk_stream) { diff --git a/packages/core/src/llm/utils.ts b/packages/core/src/llm/utils.ts index 3a693a9b764231c7a0ffa2f773dc5aa0a1ce3ca4..55fc5f78d4765bb7f7c78775075ff81fb5f2c18e 100644 --- a/packages/core/src/llm/utils.ts +++ b/packages/core/src/llm/utils.ts @@ -1,5 +1,3 @@ -// TODO: use for LLM.ts - export async function* streamConverter<S, D>( stream: AsyncIterable<S>, converter: (s: S) => D, @@ -8,3 +6,19 @@ export async function* streamConverter<S, D>( yield converter(data); } } + +export async function* streamReducer<S, D>(params: { + stream: AsyncIterable<S>; + reducer: (previousValue: D, currentValue: S) => D; + initialValue: D; + finished?: (value: D | undefined) => void; +}): AsyncIterable<S> { + let value = params.initialValue; + for await (const data of params.stream) { + value = params.reducer(value, data); + yield data; + } + if (params.finished) { + params.finished(value); + } +} diff --git a/packages/core/src/tests/CallbackManager.test.ts b/packages/core/src/tests/CallbackManager.test.ts index 38a95028b9242943d9d30167ad0c52e116630c42..da56efd1c35f92a8d602a5d76e8a843a9ca39ce5 100644 --- a/packages/core/src/tests/CallbackManager.test.ts +++ b/packages/core/src/tests/CallbackManager.test.ts @@ -67,7 +67,7 @@ describe("CallbackManager: onLLMStream and onRetrieve", () => { }); const queryEngine = vectorStoreIndex.asQueryEngine(); const query = "What is the author's name?"; - const response = await queryEngine.query(query); + const response = await queryEngine.query({ query }); expect(response.toString()).toBe("MOCK_TOKEN_1-MOCK_TOKEN_2"); expect(streamCallbackData).toEqual([ { @@ -145,7 +145,7 @@ describe("CallbackManager: onLLMStream and onRetrieve", () => { responseSynthesizer, }); const query = "What is the author's name?"; - const response = await queryEngine.query(query); + const response = await queryEngine.query({ query }); expect(response.toString()).toBe("MOCK_TOKEN_1-MOCK_TOKEN_2"); expect(streamCallbackData).toEqual([ {