diff --git a/apps/simple/chatEngine.ts b/apps/simple/chatEngine.ts
index ca1d51c552de1e7c81be6ebdc717d7a42e1e9fc9..97485b49701bf67ec9a1d96fbc6e02a5da340500 100644
--- a/apps/simple/chatEngine.ts
+++ b/apps/simple/chatEngine.ts
@@ -25,7 +25,7 @@ async function main() {
 
   while (true) {
     const query = await rl.question("Query: ");
-    const response = await chatEngine.achat(query);
+    const response = await chatEngine.chat(query);
     console.log(response);
   }
 }
diff --git a/packages/core/src/ChatEngine.ts b/packages/core/src/ChatEngine.ts
index 2f9a36de4bfac29f209a3782cdb952bb8eb4618d..94b9362e1cdd8aa9aa708b7bfec65fdf9ea0faf6 100644
--- a/packages/core/src/ChatEngine.ts
+++ b/packages/core/src/ChatEngine.ts
@@ -22,7 +22,7 @@ export interface ChatEngine {
    * @param message
    * @param chatHistory optional chat history if you want to customize the chat history
    */
-  achat(message: string, chatHistory?: ChatMessage[]): Promise<Response>;
+  chat(message: string, chatHistory?: ChatMessage[]): Promise<Response>;
 
   /**
    * Resets the chat history so that it's empty.
@@ -42,10 +42,10 @@ export class SimpleChatEngine implements ChatEngine {
     this.llm = init?.llm ?? new OpenAI();
   }
 
-  async achat(message: string, chatHistory?: ChatMessage[]): Promise<Response> {
+  async chat(message: string, chatHistory?: ChatMessage[]): Promise<Response> {
     chatHistory = chatHistory ?? this.chatHistory;
     chatHistory.push({ content: message, role: "user" });
-    const response = await this.llm.achat(chatHistory);
+    const response = await this.llm.chat(chatHistory);
     chatHistory.push(response.message);
     this.chatHistory = chatHistory;
     return new Response(response.message.content);
@@ -57,7 +57,7 @@ export class SimpleChatEngine implements ChatEngine {
 }
 
 /**
- * CondenseQuestionChatEngine is used in conjunction with a Index (for example VectorIndex).
+ * CondenseQuestionChatEngine is used in conjunction with an Index (for example VectorStoreIndex).
  * It does two steps on taking a user's chat message: first, it condenses the chat message
  * with the previous chat history into a question with more context.
  * Then, it queries the underlying Index using the new question with context and returns
@@ -86,13 +86,10 @@ export class CondenseQuestionChatEngine implements ChatEngine {
       init?.condenseMessagePrompt ?? defaultCondenseQuestionPrompt;
   }
 
-  private async acondenseQuestion(
-    chatHistory: ChatMessage[],
-    question: string
-  ) {
+  private async condenseQuestion(chatHistory: ChatMessage[], question: string) {
     const chatHistoryStr = messagesToHistoryStr(chatHistory);
 
-    return this.serviceContext.llmPredictor.apredict(
+    return this.serviceContext.llmPredictor.predict(
       defaultCondenseQuestionPrompt,
       {
         question: question,
@@ -101,18 +98,15 @@ export class CondenseQuestionChatEngine implements ChatEngine {
     );
   }
 
-  async achat(
+  async chat(
     message: string,
     chatHistory?: ChatMessage[] | undefined
   ): Promise<Response> {
     chatHistory = chatHistory ?? this.chatHistory;
 
-    const condensedQuestion = await this.acondenseQuestion(
-      chatHistory,
-      message
-    );
+    const condensedQuestion = await this.condenseQuestion(chatHistory, message);
 
-    const response = await this.queryEngine.aquery(condensedQuestion);
+    const response = await this.queryEngine.query(condensedQuestion);
 
     chatHistory.push({ content: message, role: "user" });
     chatHistory.push({ content: response.response, role: "assistant" });
@@ -150,7 +144,7 @@ export class ContextChatEngine implements ChatEngine {
     throw new Error("Method not implemented.");
   }
 
-  async achat(message: string, chatHistory?: ChatMessage[] | undefined) {
+  async chat(message: string, chatHistory?: ChatMessage[] | undefined) {
     chatHistory = chatHistory ?? this.chatHistory;
 
     const parentEvent: Event = {
@@ -158,7 +152,7 @@ export class ContextChatEngine implements ChatEngine {
       type: "wrapper",
       tags: ["final"],
     };
-    const sourceNodesWithScore = await this.retriever.aretrieve(
+    const sourceNodesWithScore = await this.retriever.retrieve(
       message,
       parentEvent
     );
@@ -174,7 +168,7 @@ export class ContextChatEngine implements ChatEngine {
 
     chatHistory.push({ content: message, role: "user" });
 
-    const response = await this.chatModel.achat(
+    const response = await this.chatModel.chat(
       [systemMessage, ...chatHistory],
       parentEvent
     );
diff --git a/packages/core/src/Embedding.ts b/packages/core/src/Embedding.ts
index d5719ac026418ed3c7c073cf5d0912cfdf46959b..8e3ded3a2f88390b69fc49a5f6757b38766be127 100644
--- a/packages/core/src/Embedding.ts
+++ b/packages/core/src/Embedding.ts
@@ -202,8 +202,8 @@ export abstract class BaseEmbedding {
     return similarity(embedding1, embedding2, mode);
   }
 
-  abstract aGetTextEmbedding(text: string): Promise<number[]>;
-  abstract aGetQueryEmbedding(query: string): Promise<number[]>;
+  abstract getTextEmbedding(text: string): Promise<number[]>;
+  abstract getQueryEmbedding(query: string): Promise<number[]>;
 }
 
 enum OpenAIEmbeddingModelType {
@@ -221,7 +221,7 @@ export class OpenAIEmbedding extends BaseEmbedding {
     this.model = OpenAIEmbeddingModelType.TEXT_EMBED_ADA_002;
   }
 
-  private async _aGetOpenAIEmbedding(input: string) {
+  private async getOpenAIEmbedding(input: string) {
     input = input.replace(/\n/g, " ");
     //^ NOTE this performance helper is in the OpenAI python library but may not be in the JS library
 
@@ -233,11 +233,11 @@ export class OpenAIEmbedding extends BaseEmbedding {
     return data.data[0].embedding;
   }
 
-  async aGetTextEmbedding(text: string): Promise<number[]> {
-    return this._aGetOpenAIEmbedding(text);
+  async getTextEmbedding(text: string): Promise<number[]> {
+    return this.getOpenAIEmbedding(text);
   }
 
-  async aGetQueryEmbedding(query: string): Promise<number[]> {
-    return this._aGetOpenAIEmbedding(query);
+  async getQueryEmbedding(query: string): Promise<number[]> {
+    return this.getOpenAIEmbedding(query);
   }
 }
diff --git a/packages/core/src/LLM.ts b/packages/core/src/LLM.ts
index f2ea0a11e182483acb6fef910a1ac3b05a7f9f8e..20b706610fa2677e0f377e84baa9acae35766cc0 100644
--- a/packages/core/src/LLM.ts
+++ b/packages/core/src/LLM.ts
@@ -32,13 +32,13 @@ export interface LLM {
    * Get a chat response from the LLM
    * @param messages
    */
-  achat(messages: ChatMessage[]): Promise<ChatResponse>;
+  chat(messages: ChatMessage[]): Promise<ChatResponse>;
 
   /**
    * Get a prompt completion from the LLM
    * @param prompt the prompt to complete
    */
-  acomplete(prompt: string): Promise<CompletionResponse>;
+  complete(prompt: string): Promise<CompletionResponse>;
 }
 
 export const GPT4_MODELS = {
@@ -100,7 +100,7 @@ export class OpenAI implements LLM {
     }
   }
 
-  async achat(
+  async chat(
     messages: ChatMessage[],
     parentEvent?: Event
   ): Promise<ChatResponse> {
@@ -142,10 +142,10 @@ export class OpenAI implements LLM {
     }
   }
 
-  async acomplete(
+  async complete(
     prompt: string,
     parentEvent?: Event
   ): Promise<CompletionResponse> {
-    return this.achat([{ content: prompt, role: "user" }], parentEvent);
+    return this.chat([{ content: prompt, role: "user" }], parentEvent);
   }
 }
diff --git a/packages/core/src/LLMPredictor.ts b/packages/core/src/LLMPredictor.ts
index 69a6aa5055290e398b179a62985ed676aa5d9d61..69c3c62843f33bc0975e410089ae75eba4fab468 100644
--- a/packages/core/src/LLMPredictor.ts
+++ b/packages/core/src/LLMPredictor.ts
@@ -7,7 +7,7 @@ import { CallbackManager, Event } from "./callbacks/CallbackManager";
  */
 export interface BaseLLMPredictor {
   getLlmMetadata(): Promise<any>;
-  apredict(
+  predict(
     prompt: string | SimplePrompt,
     input?: Record<string, string>,
     parentEvent?: Event
@@ -46,16 +46,16 @@ export class ChatGPTLLMPredictor implements BaseLLMPredictor {
     throw new Error("Not implemented yet");
   }
 
-  async apredict(
+  async predict(
     prompt: string | SimplePrompt,
     input?: Record<string, string>,
     parentEvent?: Event
   ): Promise<string> {
     if (typeof prompt === "string") {
-      const result = await this.languageModel.acomplete(prompt, parentEvent);
+      const result = await this.languageModel.complete(prompt, parentEvent);
       return result.message.content;
     } else {
-      return this.apredict(prompt(input ?? {}));
+      return this.predict(prompt(input ?? {}));
     }
   }
 }
diff --git a/packages/core/src/QueryEngine.ts b/packages/core/src/QueryEngine.ts
index e574c6cf5e110b8e8ac3228d890da05cee050e0b..b4a3bc115ec30d7eae87a3a4a66a8ab48f549b35 100644
--- a/packages/core/src/QueryEngine.ts
+++ b/packages/core/src/QueryEngine.ts
@@ -16,7 +16,12 @@ import { QueryEngineTool, ToolMetadata } from "./Tool";
  * A query engine is a question answerer that can use one or more steps.
  */
 export interface BaseQueryEngine {
-  aquery(query: string, parentEvent?: Event): Promise<Response>;
+  /**
+   * Query the query engine and get a response.
+   * @param query
+   * @param parentEvent
+   */
+  query(query: string, parentEvent?: Event): Promise<Response>;
 }
 
 /**
@@ -37,14 +42,14 @@ export class RetrieverQueryEngine implements BaseQueryEngine {
       responseSynthesizer || new ResponseSynthesizer({ serviceContext });
   }
 
-  async aquery(query: string, parentEvent?: Event) {
+  async query(query: string, parentEvent?: Event) {
     const _parentEvent: Event = parentEvent || {
       id: uuidv4(),
       type: "wrapper",
       tags: ["final"],
     };
-    const nodes = await this.retriever.aretrieve(query, _parentEvent);
-    return this.responseSynthesizer.asynthesize(query, nodes, _parentEvent);
+    const nodes = await this.retriever.retrieve(query, _parentEvent);
+    return this.responseSynthesizer.synthesize(query, nodes, _parentEvent);
   }
 }
 
@@ -98,11 +103,8 @@ export class SubQuestionQueryEngine implements BaseQueryEngine {
     });
   }
 
-  async aquery(query: string): Promise<Response> {
-    const subQuestions = await this.questionGen.agenerate(
-      this.metadatas,
-      query
-    );
+  async query(query: string): Promise<Response> {
+    const subQuestions = await this.questionGen.generate(this.metadatas, query);
 
     // groups final retrieval+synthesis operation
     const parentEvent: Event = {
@@ -120,16 +122,16 @@ export class SubQuestionQueryEngine implements BaseQueryEngine {
     };
 
     const subQNodes = await Promise.all(
-      subQuestions.map((subQ) => this.aquerySubQ(subQ, subQueryParentEvent))
+      subQuestions.map((subQ) => this.querySubQ(subQ, subQueryParentEvent))
     );
 
     const nodes = subQNodes
       .filter((node) => node !== null)
       .map((node) => node as NodeWithScore);
-    return this.responseSynthesizer.asynthesize(query, nodes, parentEvent);
+    return this.responseSynthesizer.synthesize(query, nodes, parentEvent);
   }
 
-  private async aquerySubQ(
+  private async querySubQ(
     subQ: SubQuestion,
     parentEvent?: Event
   ): Promise<NodeWithScore | null> {
@@ -137,7 +139,7 @@ export class SubQuestionQueryEngine implements BaseQueryEngine {
       const question = subQ.subQuestion;
       const queryEngine = this.queryEngines[subQ.toolName];
 
-      const response = await queryEngine.aquery(question, parentEvent);
+      const response = await queryEngine.query(question, parentEvent);
       const responseText = response.response;
       const nodeText = `Sub question: ${question}\nResponse: ${responseText}}`;
       const node = new TextNode({ text: nodeText });
diff --git a/packages/core/src/QuestionGenerator.ts b/packages/core/src/QuestionGenerator.ts
index bf014e9f98f0a1f2c5e8726c68ba5924591a5e26..46bdb60ff0f307e80302707a5d0cc5ed8868c1b5 100644
--- a/packages/core/src/QuestionGenerator.ts
+++ b/packages/core/src/QuestionGenerator.ts
@@ -20,7 +20,7 @@ export interface SubQuestion {
  * QuestionGenerators generate new questions for the LLM using tools and a user query.
  */
 export interface BaseQuestionGenerator {
-  agenerate(tools: ToolMetadata[], query: string): Promise<SubQuestion[]>;
+  generate(tools: ToolMetadata[], query: string): Promise<SubQuestion[]>;
 }
 
 /**
@@ -37,13 +37,10 @@ export class LLMQuestionGenerator implements BaseQuestionGenerator {
     this.outputParser = init?.outputParser ?? new SubQuestionOutputParser();
   }
 
-  async agenerate(
-    tools: ToolMetadata[],
-    query: string
-  ): Promise<SubQuestion[]> {
+  async generate(tools: ToolMetadata[], query: string): Promise<SubQuestion[]> {
     const toolsStr = buildToolsText(tools);
     const queryStr = query;
-    const prediction = await this.llmPredictor.apredict(this.prompt, {
+    const prediction = await this.llmPredictor.predict(this.prompt, {
       toolsStr,
       queryStr,
     });
diff --git a/packages/core/src/ResponseSynthesizer.ts b/packages/core/src/ResponseSynthesizer.ts
index beb97adf7c0a38410beabdeba851cb17ee8ec1c8..c0e5faaa893ebc0b9478e5a1d51e3713c147e522 100644
--- a/packages/core/src/ResponseSynthesizer.ts
+++ b/packages/core/src/ResponseSynthesizer.ts
@@ -20,7 +20,7 @@ interface BaseResponseBuilder {
    * @param textChunks
    * @param parentEvent
    */
-  agetResponse(
+  getResponse(
     query: string,
     textChunks: string[],
     parentEvent?: Event
@@ -40,7 +40,7 @@ export class SimpleResponseBuilder implements BaseResponseBuilder {
     this.textQATemplate = defaultTextQaPrompt;
   }
 
-  async agetResponse(
+  async getResponse(
     query: string,
     textChunks: string[],
     parentEvent?: Event
@@ -51,7 +51,7 @@ export class SimpleResponseBuilder implements BaseResponseBuilder {
     };
 
     const prompt = this.textQATemplate(input);
-    return this.llmPredictor.apredict(prompt, {}, parentEvent);
+    return this.llmPredictor.predict(prompt, {}, parentEvent);
   }
 }
 
@@ -73,7 +73,7 @@ export class Refine implements BaseResponseBuilder {
     this.refineTemplate = refineTemplate ?? defaultRefinePrompt;
   }
 
-  async agetResponse(
+  async getResponse(
     query: string,
     textChunks: string[],
     prevResponse?: any
@@ -106,7 +106,7 @@ export class Refine implements BaseResponseBuilder {
 
     for (const chunk of textChunks) {
       if (!response) {
-        response = await this.serviceContext.llmPredictor.apredict(
+        response = await this.serviceContext.llmPredictor.predict(
           textQATemplate,
           {
             context: chunk,
@@ -133,7 +133,7 @@ export class Refine implements BaseResponseBuilder {
     ]);
 
     for (const chunk of textChunks) {
-      response = await this.serviceContext.llmPredictor.apredict(
+      response = await this.serviceContext.llmPredictor.predict(
         refineTemplate,
         {
           context: chunk,
@@ -149,7 +149,7 @@
  * CompactAndRefine is a slight variation of Refine that first compacts the text chunks into the smallest possible number of chunks.
  */
 export class CompactAndRefine extends Refine {
-  async agetResponse(
+  async getResponse(
     query: string,
     textChunks: string[],
     prevResponse?: any
@@ -164,7 +164,7 @@ export class CompactAndRefine extends Refine {
       maxPrompt,
       textChunks
     );
-    const response = super.agetResponse(query, newTexts, prevResponse);
+    const response = super.getResponse(query, newTexts, prevResponse);
     return response;
   }
 }
@@ -178,7 +178,7 @@ export class TreeSummarize implements BaseResponseBuilder {
     this.serviceContext = serviceContext;
   }
 
-  async agetResponse(query: string, textChunks: string[]): Promise<string> {
+  async getResponse(query: string, textChunks: string[]): Promise<string> {
     const summaryTemplate: SimplePrompt = (input) =>
       defaultTextQaPrompt({ ...input, query: query });
 
@@ -192,19 +192,19 @@ export class TreeSummarize implements BaseResponseBuilder {
     );
 
     if (packedTextChunks.length === 1) {
-      return this.serviceContext.llmPredictor.apredict(summaryTemplate, {
+      return this.serviceContext.llmPredictor.predict(summaryTemplate, {
        context: packedTextChunks[0],
      });
     } else {
       const summaries = await Promise.all(
         packedTextChunks.map((chunk) =>
-          this.serviceContext.llmPredictor.apredict(summaryTemplate, {
+          this.serviceContext.llmPredictor.predict(summaryTemplate, {
             context: chunk,
           })
         )
       );
 
-      return this.agetResponse(query, summaries);
+      return this.getResponse(query, summaries);
     }
   }
 }
@@ -234,15 +234,11 @@ export class ResponseSynthesizer {
       responseBuilder ?? getResponseBuilder(this.serviceContext);
   }
 
-  async asynthesize(
-    query: string,
-    nodes: NodeWithScore[],
-    parentEvent?: Event
-  ) {
+  async synthesize(query: string, nodes: NodeWithScore[], parentEvent?: Event) {
     let textChunks: string[] = nodes.map((node) =>
       node.node.getContent(MetadataMode.NONE)
     );
-    const response = await this.responseBuilder.agetResponse(
+    const response = await this.responseBuilder.getResponse(
       query,
       textChunks,
       parentEvent
diff --git a/packages/core/src/Retriever.ts b/packages/core/src/Retriever.ts
index 66e672e9ef16e36e7b2e3e153d094babb5dcbc26..303d0fa541a9797e7049d9629a2eba45c8e81a20 100644
--- a/packages/core/src/Retriever.ts
+++ b/packages/core/src/Retriever.ts
@@ -6,6 +6,6 @@ import { Event } from "./callbacks/CallbackManager";
  * Retrievers retrieve the nodes that most closely match our query in similarity.
  */
 export interface BaseRetriever {
-  aretrieve(query: string, parentEvent?: Event): Promise<NodeWithScore[]>;
+  retrieve(query: string, parentEvent?: Event): Promise<NodeWithScore[]>;
   getServiceContext(): ServiceContext;
 }
diff --git a/packages/core/src/indices/list/ListIndexRetriever.ts b/packages/core/src/indices/list/ListIndexRetriever.ts
index 15b6d9c2e88a063a91167f78bb73a2061837fe25..d0761a269ae6fbee969d40658a23e6fe253c4070 100644
--- a/packages/core/src/indices/list/ListIndexRetriever.ts
+++ b/packages/core/src/indices/list/ListIndexRetriever.ts
@@ -23,10 +23,7 @@ export class ListIndexRetriever implements BaseRetriever {
     this.index = index;
   }
 
-  async aretrieve(
-    query: string,
-    parentEvent?: Event
-  ): Promise<NodeWithScore[]> {
+  async retrieve(query: string, parentEvent?: Event): Promise<NodeWithScore[]> {
     const nodeIds = this.index.indexStruct.nodes;
     const nodes = await this.index.docStore.getNodes(nodeIds);
     const result = nodes.map((node) => ({
@@ -81,10 +78,7 @@ export class ListIndexLLMRetriever implements BaseRetriever {
     this.serviceContext = serviceContext || index.serviceContext;
   }
 
-  async aretrieve(
-    query: string,
-    parentEvent?: Event
-  ): Promise<NodeWithScore[]> {
+  async retrieve(query: string, parentEvent?: Event): Promise<NodeWithScore[]> {
     const nodeIds = this.index.indexStruct.nodes;
     const results: NodeWithScore[] = [];
 
@@ -94,7 +88,7 @@ export class ListIndexLLMRetriever implements BaseRetriever {
       const fmtBatchStr = this.formatNodeBatchFn(nodesBatch);
       const input = { context: fmtBatchStr, query: query };
 
-      const rawResponse = await this.serviceContext.llmPredictor.apredict(
+      const rawResponse = await this.serviceContext.llmPredictor.predict(
        this.choiceSelectPrompt,
        input
      );
diff --git a/packages/core/src/indices/vectorStore/VectorIndexRetriever.ts b/packages/core/src/indices/vectorStore/VectorIndexRetriever.ts
index c8106b90f1e0cec5c7fca58bf2e2bfb955aabef7..862481640b1d8e6b509adaa7e1ce165e1998248d 100644
--- a/packages/core/src/indices/vectorStore/VectorIndexRetriever.ts
+++ b/packages/core/src/indices/vectorStore/VectorIndexRetriever.ts
@@ -24,12 +24,9 @@ export class VectorIndexRetriever implements BaseRetriever {
     this.serviceContext = this.index.serviceContext;
   }
 
-  async aretrieve(
-    query: string,
-    parentEvent?: Event
-  ): Promise<NodeWithScore[]> {
+  async retrieve(query: string, parentEvent?: Event): Promise<NodeWithScore[]> {
     const queryEmbedding =
-      await this.serviceContext.embedModel.aGetQueryEmbedding(query);
+      await this.serviceContext.embedModel.getQueryEmbedding(query);
 
     const q: VectorStoreQuery = {
       queryEmbedding: queryEmbedding,
diff --git a/packages/core/src/indices/vectorStore/VectorStoreIndex.ts b/packages/core/src/indices/vectorStore/VectorStoreIndex.ts
index a50e29fea3e509542a19108d261f95a63ce2b5c7..e9c9631e7eeeb60eece7d22c120c13db10382031 100644
--- a/packages/core/src/indices/vectorStore/VectorStoreIndex.ts
+++ b/packages/core/src/indices/vectorStore/VectorStoreIndex.ts
@@ -79,7 +79,7 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> {
    * @param logProgress log progress to console (useful for debugging)
    * @returns
    */
-  static async agetNodeEmbeddingResults(
+  static async getNodeEmbeddingResults(
     nodes: BaseNode[],
     serviceContext: ServiceContext,
     logProgress = false
@@ -91,7 +91,7 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> {
       if (logProgress) {
         console.log(`getting embedding for node ${i}/${nodes.length}`);
       }
-      const embedding = await serviceContext.embedModel.aGetTextEmbedding(
+      const embedding = await serviceContext.embedModel.getTextEmbedding(
         node.getContent(MetadataMode.EMBED)
       );
       nodesWithEmbeddings.push({ node, embedding });
@@ -112,7 +112,7 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> {
     serviceContext: ServiceContext,
     vectorStore: VectorStore
   ): Promise<IndexDict> {
-    const embeddingResults = await this.agetNodeEmbeddingResults(
+    const embeddingResults = await this.getNodeEmbeddingResults(
       nodes,
       serviceContext
     );
diff --git a/packages/core/src/tests/CallbackManager.test.ts b/packages/core/src/tests/CallbackManager.test.ts
index b32a05b7edab1cfae8a2db99cd2bcfdfca6d8fd3..b0dad0b54472c6e2ab4e5799514e8daf7e7109bf 100644
--- a/packages/core/src/tests/CallbackManager.test.ts
+++ b/packages/core/src/tests/CallbackManager.test.ts
@@ -72,7 +72,7 @@ describe("CallbackManager: onLLMStream and onRetrieve", () => {
     );
     const queryEngine = vectorStoreIndex.asQueryEngine();
     const query = "What is the author's name?";
-    const response = await queryEngine.aquery(query);
+    const response = await queryEngine.query(query);
     expect(response.toString()).toBe("MOCK_TOKEN_1-MOCK_TOKEN_2");
     expect(streamCallbackData).toEqual([
       {
@@ -152,7 +152,7 @@ describe("CallbackManager: onLLMStream and onRetrieve", () => {
       responseSynthesizer
     );
     const query = "What is the author's name?";
-    const response = await queryEngine.aquery(query);
+    const response = await queryEngine.query(query);
     expect(response.toString()).toBe("MOCK_TOKEN_1-MOCK_TOKEN_2");
     expect(streamCallbackData).toEqual([
       {
diff --git a/packages/core/src/tests/utility/mockOpenAI.ts b/packages/core/src/tests/utility/mockOpenAI.ts
index 21fd001c3f846ffc82df8207635fe3d88a219bb3..ccbadcf71a537a89fe64466c3a5286aa366660ac 100644
--- a/packages/core/src/tests/utility/mockOpenAI.ts
+++ b/packages/core/src/tests/utility/mockOpenAI.ts
@@ -11,7 +11,7 @@ export function mockLlmGeneration({
   callbackManager: CallbackManager;
 }) {
   jest
-    .spyOn(languageModel, "achat")
+    .spyOn(languageModel, "chat")
     .mockImplementation(
       async (messages: ChatMessage[], parentEvent?: Event) => {
         const text = "MOCK_TOKEN_1-MOCK_TOKEN_2";
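
Note for reviewers: every renamed method is still `async`; only the `a` prefix is dropped, so call sites change names but keep their `await`. Below is a minimal before/after sketch of caller code under this patch. The `llamaindex` entry point, `Document`, and `VectorStoreIndex.fromDocuments` are assumed from the repo's existing examples rather than shown in this diff; `SimpleChatEngine` and `asQueryEngine()` appear in the changed files above.

```ts
import { Document, SimpleChatEngine, VectorStoreIndex } from "llamaindex";

async function demo() {
  // Before: await chatEngine.achat("Hello!");
  // After: the method is chat(), still awaited.
  const chatEngine = new SimpleChatEngine(); // defaults to new OpenAI()
  const chatResponse = await chatEngine.chat("Hello!");
  console.log(chatResponse.toString());

  // Before: await queryEngine.aquery(question);
  // After: query(), otherwise the same signature.
  const index = await VectorStoreIndex.fromDocuments([
    new Document({ text: "Alice authored this example document." }),
  ]);
  const queryEngine = index.asQueryEngine();
  const queryResponse = await queryEngine.query("Who authored the document?");
  console.log(queryResponse.toString());
}

demo().catch(console.error);
```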