diff --git a/.gitignore b/.gitignore
index d1595af422c385474a45c132cfb3bdea332b3467..2641a07fa5489a4a28b76c5e83474c459749f8c9 100644
--- a/.gitignore
+++ b/.gitignore
@@ -34,3 +34,5 @@ yarn-error.log*
 
 # vercel
 .vercel
+
+storage/
diff --git a/apps/simple/chatEngine.ts b/apps/simple/chatEngine.ts
new file mode 100644
index 0000000000000000000000000000000000000000..9b31847d9327547e09cf5cec64e23ab3e8a78a76
--- /dev/null
+++ b/apps/simple/chatEngine.ts
@@ -0,0 +1,31 @@
+// @ts-ignore
+import * as readline from "node:readline/promises";
+// @ts-ignore
+import { stdin as input, stdout as output } from "node:process";
+import { Document } from "@llamaindex/core/src/Node";
+import { VectorStoreIndex } from "@llamaindex/core/src/BaseIndex";
+import { ContextChatEngine } from "@llamaindex/core/src/ChatEngine";
+import essay from "./essay";
+import { serviceContextFromDefaults } from "@llamaindex/core/src/ServiceContext";
+
+async function main() {
+  const document = new Document({ text: essay });
+  const serviceContext = serviceContextFromDefaults({ chunkSize: 512 });
+  const index = await VectorStoreIndex.fromDocuments(
+    [document],
+    undefined,
+    serviceContext
+  );
+  const retriever = index.asRetriever();
+  retriever.similarityTopK = 5;
+  const chatEngine = new ContextChatEngine({ retriever });
+  const rl = readline.createInterface({ input, output });
+
+  while (true) {
+    const query = await rl.question("Query: ");
+    const response = await chatEngine.achat(query);
+    console.log(response);
+  }
+}
+
+main().catch(console.error);
diff --git a/apps/simple/listIndex.ts b/apps/simple/listIndex.ts
new file mode 100644
index 0000000000000000000000000000000000000000..5b7a5203b36431a07ce367016d3f144dce13748d
--- /dev/null
+++ b/apps/simple/listIndex.ts
@@ -0,0 +1,17 @@
+import { Document } from "@llamaindex/core/src/Node";
+import { ListIndex } from "@llamaindex/core/src/index/list";
+import essay from "./essay";
+
+async function main() {
+  const document = new Document({ text: essay });
+  const index = await ListIndex.fromDocuments([document]);
+  const queryEngine = index.asQueryEngine();
+  const response = await queryEngine.aquery(
+    "What did the author do growing up?"
+  );
+  console.log(response.toString());
+}
+
+main().catch((e: Error) => {
+  console.error(e, e.stack);
+});
diff --git a/apps/simple/openai.ts b/apps/simple/openai.ts
index 0977b8d359114b5b72c4af1c01c3e10421449085..200ed8ea39539614633613b65191970527b1a3a4 100644
--- a/apps/simple/openai.ts
+++ b/apps/simple/openai.ts
@@ -1,3 +1,4 @@
+// @ts-ignore
 import process from "node:process";
 
 import { Configuration, OpenAIWrapper } from "@llamaindex/core/src/openai";
diff --git a/apps/simple/simple.txt b/apps/simple/simple.txt
deleted file mode 100644
index 7cd89b8d0d51646af616cc9f784f61fcb1469bf1..0000000000000000000000000000000000000000
--- a/apps/simple/simple.txt
+++ /dev/null
@@ -1,9 +0,0 @@
-Simple flow:
-
-Get document list, in this case one document.
-Split each document into nodes, in this case sentences or lines.
-Embed each of the nodes and get vectors. Store them in memory for now.
-Embed query.
-Compare query with nodes and get the top n
-Put the top n nodes into the prompt.
-Execute prompt, get result.
diff --git a/apps/simple/subquestion.ts b/apps/simple/subquestion.ts
new file mode 100644
index 0000000000000000000000000000000000000000..a3a85273dd4ef8e812ef92fdf1756607324d4a92
--- /dev/null
+++ b/apps/simple/subquestion.ts
@@ -0,0 +1,60 @@
+// from llama_index import SimpleDirectoryReader, VectorStoreIndex
+// from llama_index.query_engine import SubQuestionQueryEngine
+// from llama_index.tools import QueryEngineTool, ToolMetadata
+
+// # load data
+// pg_essay = SimpleDirectoryReader(
+//     input_dir="docs/examples/data/paul_graham/"
+// ).load_data()
+
+// # build index and query engine
+// query_engine = VectorStoreIndex.from_documents(pg_essay).as_query_engine()
+
+// # setup base query engine as tool
+// query_engine_tools = [
+//     QueryEngineTool(
+//         query_engine=query_engine,
+//         metadata=ToolMetadata(
+//             name="pg_essay", description="Paul Graham essay on What I Worked On"
+//         ),
+//     )
+// ]
+
+// query_engine = SubQuestionQueryEngine.from_defaults(
+//     query_engine_tools=query_engine_tools
+// )
+
+// response = query_engine.query(
+//     "How was Paul Grahams life different before and after YC?"
+// )
+
+// print(response)
+
+import { Document } from "@llamaindex/core/src/Node";
+import { VectorStoreIndex } from "@llamaindex/core/src/BaseIndex";
+import { SubQuestionQueryEngine } from "@llamaindex/core/src/QueryEngine";
+
+import essay from "./essay";
+
+(async () => {
+  const document = new Document({ text: essay });
+  const index = await VectorStoreIndex.fromDocuments([document]);
+
+  const queryEngine = SubQuestionQueryEngine.fromDefaults({
+    queryEngineTools: [
+      {
+        queryEngine: index.asQueryEngine(),
+        metadata: {
+          name: "pg_essay",
+          description: "Paul Graham essay on What I Worked On",
+        },
+      },
+    ],
+  });
+
+  const response = await queryEngine.aquery(
+    "How was Paul Grahams life different before and after YC?"
+  );
+
+  console.log(response);
+})();
diff --git a/apps/simple/index.ts b/apps/simple/vectorIndex.ts
similarity index 88%
rename from apps/simple/index.ts
rename to apps/simple/vectorIndex.ts
index 733bb7f07fedcf7c8b74309fe384b214f7cec6cc..d05b5874984e3846ee2de6a55ceb1c29875a827a 100644
--- a/apps/simple/index.ts
+++ b/apps/simple/vectorIndex.ts
@@ -2,7 +2,7 @@ import { Document } from "@llamaindex/core/src/Node";
 import { VectorStoreIndex } from "@llamaindex/core/src/BaseIndex";
 import essay from "./essay";
 
-(async () => {
+async function main() {
   const document = new Document({ text: essay });
   const index = await VectorStoreIndex.fromDocuments([document]);
   const queryEngine = index.asQueryEngine();
@@ -10,4 +10,6 @@ import essay from "./essay";
     "What did the author do growing up?"
   );
   console.log(response.toString());
-})();
+}
+
+main().catch(console.error);
diff --git a/packages/core/package.json b/packages/core/package.json
index b303d036048822c443c8fe92241dd8c3f6a79e97..f0f5c71ed1cb4a98cf735be3502ae5aae9923e56 100644
--- a/packages/core/package.json
+++ b/packages/core/package.json
@@ -10,6 +10,9 @@
     "uuid": "^9.0.0",
     "wink-nlp": "^1.14.1"
   },
+  "engines": {
+    "node": ">=18.0.0"
+  },
   "main": "src/index.ts",
   "types": "src/index.ts",
   "scripts": {
diff --git a/packages/core/src/BaseIndex.ts b/packages/core/src/BaseIndex.ts
index dd6e7461b20ea86a6133aad94dd0adf71a014748..0889be7585703289580c1e6b4921e9f1e8eee6e9 100644
--- a/packages/core/src/BaseIndex.ts
+++ b/packages/core/src/BaseIndex.ts
@@ -9,11 +9,11 @@ import {
 } from "./storage/StorageContext";
 import { BaseDocumentStore } from "./storage/docStore/types";
 import { VectorStore } from "./storage/vectorStore/types";
-export class IndexDict {
+import { BaseIndexStore } from "./storage/indexStore/types";
+
+export abstract class IndexStruct {
   indexId: string;
   summary?: string;
-  nodesDict: Record<string, BaseNode> = {};
-  docStore: Record<string, Document> = {}; // FIXME: this should be implemented in storageContext
 
   constructor(indexId = uuidv4(), summary = undefined) {
     this.indexId = indexId;
@@ -26,6 +26,18 @@ export class IndexDict {
     }
     return this.summary;
   }
+}
+
+export class IndexDict extends IndexStruct {
+  nodesDict: Record<string, BaseNode> = {};
+  docStore: Record<string, Document> = {}; // FIXME: this should be implemented in storageContext
+
+  getSummary(): string {
+    if (this.summary === undefined) {
+      throw new Error("summary field of the index dict is not set");
+    }
+    return this.summary;
+  }
 
   addNode(node: BaseNode, textId?: string) {
     const vectorId = textId ?? node.id_;
@@ -33,18 +45,28 @@ export class IndexDict {
   }
 }
 
+export class IndexList extends IndexStruct {
+  nodes: string[] = [];
+
+  addNode(node: BaseNode) {
+    this.nodes.push(node.id_);
+  }
+}
+
 export interface BaseIndexInit<T> {
   serviceContext: ServiceContext;
   storageContext: StorageContext;
   docStore: BaseDocumentStore;
-  vectorStore: VectorStore;
+  vectorStore?: VectorStore;
+  indexStore?: BaseIndexStore;
   indexStruct: T;
 }
 
 export abstract class BaseIndex<T> {
   serviceContext: ServiceContext;
   storageContext: StorageContext;
   docStore: BaseDocumentStore;
-  vectorStore: VectorStore;
+  vectorStore?: VectorStore;
+  indexStore?: BaseIndexStore;
   indexStruct: T;
 
   constructor(init: BaseIndexInit<T>) {
@@ -52,6 +74,7 @@ export abstract class BaseIndex<T> {
     this.storageContext = init.storageContext;
     this.docStore = init.docStore;
     this.vectorStore = init.vectorStore;
+    this.indexStore = init.indexStore;
     this.indexStruct = init.indexStruct;
   }
 
@@ -65,9 +88,16 @@ export interface VectorIndexOptions {
   storageContext?: StorageContext;
 }
 
+interface VectorIndexConstructorProps extends BaseIndexInit<IndexDict> {
+  vectorStore: VectorStore;
+}
+
 export class VectorStoreIndex extends BaseIndex<IndexDict> {
-  private constructor(init: BaseIndexInit<IndexDict>) {
+  vectorStore: VectorStore;
+
+  private constructor(init: VectorIndexConstructorProps) {
     super(init);
+    this.vectorStore = init.vectorStore;
   }
 
   static async init(options: VectorIndexOptions): Promise<VectorStoreIndex> {
diff --git a/packages/core/src/ChatEngine.ts b/packages/core/src/ChatEngine.ts
new file mode 100644
index 0000000000000000000000000000000000000000..9d7af413d6007f2b3639a13beec885859e642d63
--- /dev/null
+++ b/packages/core/src/ChatEngine.ts
@@ -0,0 +1,177 @@
+import { BaseChatModel, BaseMessage, ChatOpenAI } from "./LanguageModel";
+import { TextNode } from "./Node";
+import {
+  SimplePrompt,
+  contextSystemPrompt,
+  defaultCondenseQuestionPrompt,
+  messagesToHistoryStr,
+} from "./Prompt";
+import { BaseQueryEngine } from "./QueryEngine";
+import { Response } from "./Response";
+import { BaseRetriever } from "./Retriever";
+import { ServiceContext, serviceContextFromDefaults } from "./ServiceContext";
+import { v4 as uuidv4 } from "uuid";
+import { Event } from "./callbacks/CallbackManager";
+
+interface ChatEngine {
+  chatRepl(): void;
+
+  achat(message: string, chatHistory?: BaseMessage[]): Promise<Response>;
+
+  reset(): void;
+}
+
+export class SimpleChatEngine implements ChatEngine {
+  chatHistory: BaseMessage[];
+  llm: BaseChatModel;
+
+  constructor(init?: Partial<SimpleChatEngine>) {
+    this.chatHistory = init?.chatHistory ?? [];
+    this.llm = init?.llm ?? new ChatOpenAI({ model: "gpt-3.5-turbo" });
+  }
+
+  chatRepl() {
+    throw new Error("Method not implemented.");
+  }
+
+  async achat(message: string, chatHistory?: BaseMessage[]): Promise<Response> {
+    chatHistory = chatHistory ?? this.chatHistory;
+    chatHistory.push({ content: message, type: "human" });
+    const response = await this.llm.agenerate(chatHistory);
+    chatHistory.push({ content: response.generations[0][0].text, type: "ai" });
+    this.chatHistory = chatHistory;
+    return new Response(response.generations[0][0].text);
+  }
+
+  reset() {
+    this.chatHistory = [];
+  }
+}
+
+export class CondenseQuestionChatEngine implements ChatEngine {
+  queryEngine: BaseQueryEngine;
+  chatHistory: BaseMessage[];
+  serviceContext: ServiceContext;
+  condenseMessagePrompt: SimplePrompt;
+
+  constructor(init: {
+    queryEngine: BaseQueryEngine;
+    chatHistory: BaseMessage[];
+    serviceContext?: ServiceContext;
+    condenseMessagePrompt?: SimplePrompt;
+  }) {
+    this.queryEngine = init.queryEngine;
+    this.chatHistory = init?.chatHistory ?? [];
+    this.serviceContext =
+      init?.serviceContext ?? serviceContextFromDefaults({});
+    this.condenseMessagePrompt =
+      init?.condenseMessagePrompt ?? defaultCondenseQuestionPrompt;
+  }
+
+  private async acondenseQuestion(
+    chatHistory: BaseMessage[],
+    question: string
+  ) {
+    const chatHistoryStr = messagesToHistoryStr(chatHistory);
+
+    return this.serviceContext.llmPredictor.apredict(
+      this.condenseMessagePrompt,
+      {
+        question: question,
+        chatHistory: chatHistoryStr,
+      }
+    );
+  }
+
+  async achat(
+    message: string,
+    chatHistory?: BaseMessage[] | undefined
+  ): Promise<Response> {
+    chatHistory = chatHistory ?? this.chatHistory;
+
+    const condensedQuestion = await this.acondenseQuestion(
+      chatHistory,
+      message
+    );
+
+    const response = await this.queryEngine.aquery(condensedQuestion);
+
+    chatHistory.push({ content: message, type: "human" });
+    chatHistory.push({ content: response.response, type: "ai" });
+
+    return response;
+  }
+
+  chatRepl() {
+    throw new Error("Method not implemented.");
+  }
+
+  reset() {
+    this.chatHistory = [];
+  }
+}
+
+export class ContextChatEngine implements ChatEngine {
+  retriever: BaseRetriever;
+  chatModel: BaseChatModel;
+  chatHistory: BaseMessage[];
+
+  constructor(init: {
+    retriever: BaseRetriever;
+    chatModel?: BaseChatModel;
+    chatHistory?: BaseMessage[];
+  }) {
+    this.retriever = init.retriever;
+    this.chatModel =
+      init.chatModel ?? new ChatOpenAI({ model: "gpt-3.5-turbo-16k" });
+    this.chatHistory = init?.chatHistory ?? [];
+  }
+
+  chatRepl() {
+    throw new Error("Method not implemented.");
+  }
+
+  async achat(message: string, chatHistory?: BaseMessage[] | undefined) {
+    chatHistory = chatHistory ?? this.chatHistory;
+
+    const parentEvent: Event = {
+      id: uuidv4(),
+      type: "wrapper",
+      tags: ["final"],
+    };
+    const sourceNodesWithScore = await this.retriever.aretrieve(
+      message,
+      parentEvent
+    );
+
+    const systemMessage: BaseMessage = {
+      content: contextSystemPrompt({
+        context: sourceNodesWithScore
+          .map((r) => (r.node as TextNode).text)
+          .join("\n\n"),
+      }),
+      type: "system",
+    };
+
+    chatHistory.push({ content: message, type: "human" });
+
+    const response = await this.chatModel.agenerate(
+      [systemMessage, ...chatHistory],
+      parentEvent
+    );
+    const text = response.generations[0][0].text;
+
+    chatHistory.push({ content: text, type: "ai" });
+
+    this.chatHistory = chatHistory;
+
+    return new Response(
+      text,
+      sourceNodesWithScore.map((r) => r.node)
+    );
+  }
+
+  reset() {
+    this.chatHistory = [];
+  }
+}
diff --git a/packages/core/src/Embedding.ts b/packages/core/src/Embedding.ts
index 7d2208669b3627bf66b16d67ca221b8c61b03455..bb824bc0efe5a929b922bf035601ae97a22e03db 100644
--- a/packages/core/src/Embedding.ts
+++ b/packages/core/src/Embedding.ts
@@ -174,24 +174,12 @@ export function getTopKMMREmbeddings(
 }
 
 export abstract class BaseEmbedding {
-  static similarity(
+  similarity(
     embedding1: number[],
     embedding2: number[],
-    mode: SimilarityType = SimilarityType.DOT_PRODUCT
+    mode: SimilarityType = SimilarityType.DEFAULT
   ): number {
-    if (embedding1.length !== embedding2.length) {
-      throw new Error("Embedding length mismatch");
-    }
-
-    if (mode === SimilarityType.DOT_PRODUCT) {
-      let result = 0;
-      for (let i = 0; i < embedding1.length; i++) {
-        result += embedding1[i] * embedding2[i];
-      }
-      return result;
-    } else {
-      throw new Error("Not implemented yet");
-    }
+    return similarity(embedding1, embedding2, mode);
   }
 
   abstract aGetTextEmbedding(text: string): Promise<number[]>;
diff --git a/packages/core/src/GlobalsHelper.ts b/packages/core/src/GlobalsHelper.ts
index dfcafe21b2dcfb54e5a653daa11217d0b4a1c755..58a154d33a598cbffa0b39c8d55cba21684e8ea4 100644
--- a/packages/core/src/GlobalsHelper.ts
+++ b/packages/core/src/GlobalsHelper.ts
@@ -1,4 +1,4 @@
-import { Trace } from "./callbacks/CallbackManager";
+import { Event, EventTag, EventType } from "./callbacks/CallbackManager";
 import { v4 as uuidv4 } from "uuid";
 
 class GlobalsHelper {
@@ -17,10 +17,21 @@ class GlobalsHelper {
     return this.defaultTokenizer;
   }
 
-  createTrace({ parentTrace }: { parentTrace?: Trace }): Trace {
+  createEvent({
+    parentEvent,
+    type,
+    tags,
+  }: {
+    parentEvent?: Event;
+    type: EventType;
+    tags?: EventTag[];
+  }): Event {
     return {
       id: uuidv4(),
-      parentId: parentTrace?.id,
+      type,
+      // inherit parent tags if tags not set
+      tags: tags || parentEvent?.tags,
+      parentId: parentEvent?.id,
     };
   }
 }
diff --git a/packages/core/src/LLMPredictor.ts b/packages/core/src/LLMPredictor.ts
index 0a0f74725544efecb4ffce5a737c2dcbe25a1fdb..93a47cd8060fd5e91a17930bcfe6d6e3cff31c74 100644
--- a/packages/core/src/LLMPredictor.ts
+++ b/packages/core/src/LLMPredictor.ts
@@ -1,16 +1,18 @@
 import { ChatOpenAI } from "./LanguageModel";
 import { SimplePrompt } from "./Prompt";
-import { CallbackManager, Trace } from "./callbacks/CallbackManager";
+import { CallbackManager, Event } from "./callbacks/CallbackManager";
 
+// TODO change this to LLM class
 export interface BaseLLMPredictor {
   getLlmMetadata(): Promise<any>;
   apredict(
     prompt: string | SimplePrompt,
     input?: Record<string, string>,
-    parentTrace?: Trace
+    parentEvent?: Event
   ): Promise<string>;
 }
 
+// TODO change this to LLM class
 export class ChatGPTLLMPredictor implements BaseLLMPredictor {
   model: string;
   retryOnThrottling: boolean;
@@ -52,7 +54,7 @@ export class ChatGPTLLMPredictor implements BaseLLMPredictor {
   async apredict(
     prompt: string | SimplePrompt,
     input?: Record<string, string>,
-    parentTrace?: Trace
+    parentEvent?: Event
   ): Promise<string> {
     if (typeof prompt === "string") {
       const result = await this.languageModel.agenerate(
@@ -62,7 +64,7 @@ export class ChatGPTLLMPredictor implements BaseLLMPredictor {
           type: "human",
         },
       ],
-      parentTrace
+      parentEvent
     );
     return result.generations[0][0].text;
   } else {
diff --git a/packages/core/src/LanguageModel.ts b/packages/core/src/LanguageModel.ts
index 1505d30258e35227ab88155fa65a0aa51faed577..9f9c676e91023742a41135998a2ff8a6c403e93e 100644
--- a/packages/core/src/LanguageModel.ts
+++ b/packages/core/src/LanguageModel.ts
@@ -1,4 +1,4 @@
-import { CallbackManager, Trace } from "./callbacks/CallbackManager";
+import { CallbackManager, Event } from "./callbacks/CallbackManager";
 import { aHandleOpenAIStream } from "./callbacks/utility/aHandleOpenAIStream";
 import {
   ChatCompletionRequestMessageRoleEnum,
@@ -25,9 +25,11 @@ export interface LLMResult {
   generations: Generation[][]; // Each input can have more than one generations
 }
 
-export class BaseChatModel implements BaseLanguageModel {}
+export interface BaseChatModel extends BaseLanguageModel {
+  agenerate(messages: BaseMessage[], parentEvent?: Event): Promise<LLMResult>;
+}
 
-export class ChatOpenAI extends BaseChatModel {
+export class ChatOpenAI implements BaseChatModel {
   model: string;
   temperature: number = 0.7;
   openAIKey: string | null = null;
@@ -45,7 +47,6 @@ export class ChatOpenAI extends BaseChatModel {
     model: string;
     callbackManager?: CallbackManager;
   }) {
-    super();
     this.model = model;
    this.callbackManager = callbackManager;
     this.session = getOpenAISession();
@@ -70,7 +71,7 @@ export class ChatOpenAI extends BaseChatModel {
 
   async agenerate(
     messages: BaseMessage[],
-    parentTrace?: Trace
+    parentEvent?: Event
   ): Promise<LLMResult> {
     const baseRequestParams: CreateChatCompletionRequest = {
       model: this.model,
@@ -94,7 +95,7 @@ export class ChatOpenAI extends BaseChatModel {
       const fullResponse = await aHandleOpenAIStream({
         response,
         onLLMStream: this.callbackManager.onLLMStream,
-        parentTrace,
+        parentEvent,
       });
       return { generations: [[{ text: fullResponse }]] };
     }
diff --git a/packages/core/src/NodeParser.ts b/packages/core/src/NodeParser.ts
index 52db5dfbae7bfc1471d98099e0d7d646cb3c2756..f3e173f94b2449695177c005dce1b671a73923ee 100644
--- a/packages/core/src/NodeParser.ts
+++ b/packages/core/src/NodeParser.ts
@@ -1,5 +1,6 @@
 import { Document, NodeRelationship, TextNode } from "./Node";
 import { SentenceSplitter } from "./TextSplitter";
+import { DEFAULT_CHUNK_OVERLAP, DEFAULT_CHUNK_SIZE } from "./constants";
 
 export function getTextSplitsFromDocument(
   document: Document,
@@ -13,18 +14,36 @@ export function getTextSplitsFromDocument(
 
 export function getNodesFromDocument(
   document: Document,
-  textSplitter: SentenceSplitter
+  textSplitter: SentenceSplitter,
+  includeMetadata: boolean = true,
+  includePrevNextRel: boolean = true
 ) {
   let nodes: TextNode[] = [];
 
   const textSplits = getTextSplitsFromDocument(document, textSplitter);
 
-  textSplits.forEach((textSplit, index) => {
-    const node = new TextNode({ text: textSplit });
+  textSplits.forEach((textSplit) => {
+    const node = new TextNode({
+      text: textSplit,
+      metadata: includeMetadata ? document.metadata : {},
+    });
     node.relationships[NodeRelationship.SOURCE] = document.asRelatedNodeInfo();
     nodes.push(node);
   });
 
+  if (includePrevNextRel) {
+    nodes.forEach((node, index) => {
+      if (index > 0) {
+        node.relationships[NodeRelationship.PREVIOUS] =
+          nodes[index - 1].asRelatedNodeInfo();
+      }
+      if (index < nodes.length - 1) {
+        node.relationships[NodeRelationship.NEXT] =
+          nodes[index + 1].asRelatedNodeInfo();
+      }
+    });
+  }
+
   return nodes;
 }
 
@@ -33,17 +52,34 @@ export interface NodeParser {
 }
 
 export class SimpleNodeParser implements NodeParser {
   textSplitter: SentenceSplitter;
+  includeMetadata: boolean;
+  includePrevNextRel: boolean;
+
+  constructor(init?: {
+    textSplitter?: SentenceSplitter;
+    includeMetadata?: boolean;
+    includePrevNextRel?: boolean;
-  constructor(
-    textSplitter: any = null,
-    includeExtraInfo: boolean = true,
-    includePrevNextRel: boolean = true
-  ) {
-    this.textSplitter = textSplitter ?? new SentenceSplitter();
+    chunkSize?: number;
+    chunkOverlap?: number;
+  }) {
+    this.textSplitter =
+      init?.textSplitter ??
+      new SentenceSplitter(
+        init?.chunkSize ?? DEFAULT_CHUNK_SIZE,
+        init?.chunkOverlap ?? DEFAULT_CHUNK_OVERLAP
+      );
+    this.includeMetadata = init?.includeMetadata ?? true;
+    this.includePrevNextRel = init?.includePrevNextRel ?? true;
   }
 
-  static fromDefaults(): SimpleNodeParser {
-    return new SimpleNodeParser();
+  static fromDefaults(init?: {
+    chunkSize?: number;
+    chunkOverlap?: number;
+    includeMetadata?: boolean;
+    includePrevNextRel?: boolean;
+  }): SimpleNodeParser {
+    return new SimpleNodeParser(init);
   }
 
   /**
diff --git a/packages/core/src/OutputParser.ts b/packages/core/src/OutputParser.ts
new file mode 100644
index 0000000000000000000000000000000000000000..a5e0dded714507b64b8ff61de5d168673fd0fbd4
--- /dev/null
+++ b/packages/core/src/OutputParser.ts
@@ -0,0 +1,80 @@
+import { SubQuestion } from "./QuestionGenerator";
+
+export interface BaseOutputParser<T> {
+  parse(output: string): T;
+  format(output: string): string;
+}
+
+export interface StructuredOutput<T> {
+  rawOutput: string;
+  parsedOutput: T;
+}
+
+class OutputParserError extends Error {
+  cause: Error | undefined;
+  output: string | undefined;
+
+  constructor(
+    message: string,
+    options: { cause?: Error; output?: string } = {}
+  ) {
+    // @ts-ignore
+    super(message, options); // https://github.com/tc39/proposal-error-cause
+    this.name = "OutputParserError";
+
+    if (!this.cause) {
+      // Need to check for those environments that have implemented the proposal
+      this.cause = options.cause;
+    }
+    this.output = options.output;
+
+    // This line is to maintain proper stack trace in V8
+    // (https://v8.dev/docs/stack-trace-api)
+    if (Error.captureStackTrace) {
+      Error.captureStackTrace(this, OutputParserError);
+    }
+  }
+}
+
+function parseJsonMarkdown(text: string) {
+  text = text.trim();
+
+  const beginDelimiter = "```json";
+  const endDelimiter = "```";
+
+  const beginIndex = text.indexOf(beginDelimiter);
+  const endIndex = text.indexOf(
+    endDelimiter,
+    beginIndex + beginDelimiter.length
+  );
+  if (beginIndex === -1 || endIndex === -1) {
+    throw new OutputParserError("Not a json markdown", { output: text });
+  }
+
+  const jsonText = text.substring(beginIndex + beginDelimiter.length, endIndex);
+
+  try {
+    return JSON.parse(jsonText);
+  } catch (e) {
+    throw new OutputParserError("Not a valid json", {
+      cause: e as Error,
+      output: text,
+    });
+  }
+}
+
+export class SubQuestionOutputParser
+  implements BaseOutputParser<StructuredOutput<SubQuestion[]>>
+{
+  parse(output: string): StructuredOutput<SubQuestion[]> {
+    const parsed = parseJsonMarkdown(output);
+
+    // TODO add zod validation
+
+    return { rawOutput: output, parsedOutput: parsed };
+  }
+
+  format(output: string): string {
+    return output;
+  }
+}
diff --git a/packages/core/src/Prompt.ts b/packages/core/src/Prompt.ts
index baa2f3f0a7d0fb9d52830b06f534c8b944132e31..3cfb6fac4a04c032963863482804447427ab5aa4 100644
--- a/packages/core/src/Prompt.ts
+++ b/packages/core/src/Prompt.ts
@@ -1,3 +1,7 @@
+import { BaseMessage } from "./LanguageModel";
+import { SubQuestion } from "./QuestionGenerator";
+import { ToolMetadata } from "./Tool";
+
 /**
  * A SimplePrompt is a function that takes a dictionary of inputs and returns a string.
  * NOTE this is a different interface compared to LlamaIndex Python
@@ -80,3 +84,236 @@ ${context}
 ------------
 Given the new context, refine the original answer to better answer the question. If the context isn't useful, return the original answer.`;
 };
+
+export const defaultChoiceSelectPrompt: SimplePrompt = (input) => {
+  const { context = "", query = "" } = input;
+
+  return `A list of documents is shown below. Each document has a number next to it along
+with a summary of the document. A question is also provided.
+Respond with the numbers of the documents
+you should consult to answer the question, in order of relevance, as well
+as the relevance score. The relevance score is a number from 1-10 based on
+how relevant you think the document is to the question.
+Do not include any documents that are not relevant to the question.
+Example format:
+Document 1:
+<summary of document 1>
+
+Document 2:
+<summary of document 2>
+
+...
+
+Document 10:\n<summary of document 10>
+
+Question: <question>
+Answer:
+Doc: 9, Relevance: 7
+Doc: 3, Relevance: 4
+Doc: 7, Relevance: 3
+
+Let's try this now:
+
+${context}
+Question: ${query}
+Answer:`;
+};
+
+/*
+PREFIX = """\
+Given a user question, and a list of tools, output a list of relevant sub-questions \
+that when composed can help answer the full user question:
+
+"""
+
+
+example_query_str = (
+    "Compare and contrast the revenue growth and EBITDA of Uber and Lyft for year 2021"
+)
+example_tools = [
+    ToolMetadata(
+        name="uber_10k",
+        description="Provides information about Uber financials for year 2021",
+    ),
+    ToolMetadata(
+        name="lyft_10k",
+        description="Provides information about Lyft financials for year 2021",
+    ),
+]
+example_tools_str = build_tools_text(example_tools)
+example_output = [
+    SubQuestion(
+        sub_question="What is the revenue growth of Uber", tool_name="uber_10k"
+    ),
+    SubQuestion(sub_question="What is the EBITDA of Uber", tool_name="uber_10k"),
+    SubQuestion(
+        sub_question="What is the revenue growth of Lyft", tool_name="lyft_10k"
+    ),
+    SubQuestion(sub_question="What is the EBITDA of Lyft", tool_name="lyft_10k"),
+]
+example_output_str = json.dumps([x.dict() for x in example_output], indent=4)
+
+EXAMPLES = (
+    """\
+# Example 1
+<Tools>
+```json
+{tools_str}
+```
+
+<User Question>
+{query_str}
+
+
+<Output>
+```json
+{output_str}
+```
+
+""".format(
+        query_str=example_query_str,
+        tools_str=example_tools_str,
+        output_str=example_output_str,
+    )
+    .replace("{", "{{")
+    .replace("}", "}}")
+)
+
+SUFFIX = """\
+# Example 2
+<Tools>
+```json
+{tools_str}
+```
+
+<User Question>
+{query_str}
+
+<Output>
+"""
+
+DEFAULT_SUB_QUESTION_PROMPT_TMPL = PREFIX + EXAMPLES + SUFFIX
+*/
+
+export function buildToolsText(tools: ToolMetadata[]) {
+  const toolsObj = tools.reduce<Record<string, string>>((acc, tool) => {
+    acc[tool.name] = tool.description;
+    return acc;
+  }, {});
+
+  return JSON.stringify(toolsObj, null, 4);
+}
+
+const exampleTools: ToolMetadata[] = [
+  {
+    name: "uber_10k",
+    description: "Provides information about Uber financials for year 2021",
+  },
+  {
+    name: "lyft_10k",
+    description: "Provides information about Lyft financials for year 2021",
+  },
+];
+
+const exampleQueryStr = `Compare and contrast the revenue growth and EBITDA of Uber and Lyft for year 2021`;
+
+const exampleOutput: SubQuestion[] = [
+  {
+    subQuestion: "What is the revenue growth of Uber",
+    toolName: "uber_10k",
+  },
+  {
+    subQuestion: "What is the EBITDA of Uber",
+    toolName: "uber_10k",
+  },
+  {
+    subQuestion: "What is the revenue growth of Lyft",
+    toolName: "lyft_10k",
+  },
+  {
+    subQuestion: "What is the EBITDA of Lyft",
+    toolName: "lyft_10k",
+  },
+];
+
+export const defaultSubQuestionPrompt: SimplePrompt = (input) => {
+  const { toolsStr, queryStr } = input;
+
+  return `Given a user question, and a list of tools, output a list of relevant sub-questions that when composed can help answer the full user question:
+
+# Example 1
+<Tools>
+\`\`\`json
+${buildToolsText(exampleTools)}
+\`\`\`
+
+<User Question>
+${exampleQueryStr}
+
+<Output>
+\`\`\`json
+${JSON.stringify(exampleOutput, null, 4)}
+\`\`\`
+
+# Example 2
+<Tools>
+\`\`\`json
+${toolsStr}
+\`\`\`
+
+<User Question>
+${queryStr}
+
+<Output>
+`;
+};
+
+// DEFAULT_TEMPLATE = """\
+// Given a conversation (between Human and Assistant) and a follow up message from Human, \
+// rewrite the message to be a standalone question that captures all relevant context \
+// from the conversation.
+
+// <Chat History>
+// {chat_history}
+
+// <Follow Up Message>
+// {question}
+
+// <Standalone question>
+// """
+
+export const defaultCondenseQuestionPrompt: SimplePrompt = (input) => {
+  const { chatHistory, question } = input;
+
+  return `Given a conversation (between Human and Assistant) and a follow up message from Human, rewrite the message to be a standalone question that captures all relevant context from the conversation.
+
+<Chat History>
+${chatHistory}
+
+<Follow Up Message>
+${question}
+
+<Standalone question>
+`;
+};
+
+export function messagesToHistoryStr(messages: BaseMessage[]) {
+  return messages.reduce((acc, message) => {
+    acc += acc ? "\n" : "";
+    if (message.type === "human") {
+      acc += `Human: ${message.content}`;
+    } else {
+      acc += `Assistant: ${message.content}`;
+    }
+    return acc;
+  }, "");
+}
+
+export const contextSystemPrompt: SimplePrompt = (input) => {
+  const { context } = input;
+
+  return `Context information is below.
+---------------------
+${context}
+---------------------`;
+};
diff --git a/packages/core/src/QueryEngine.ts b/packages/core/src/QueryEngine.ts
index 1f57d3601cc959e6385eb024eb2dd0e5107f57fa..3c622f0a4d58bcd30fff113f4b46694d65d70de4 100644
--- a/packages/core/src/QueryEngine.ts
+++ b/packages/core/src/QueryEngine.ts
@@ -1,15 +1,22 @@
+import { NodeWithScore, TextNode } from "./Node";
+import {
+  BaseQuestionGenerator,
+  LLMQuestionGenerator,
+  SubQuestion,
+} from "./QuestionGenerator";
 import { Response } from "./Response";
-import { ResponseSynthesizer } from "./ResponseSynthesizer";
+import { CompactAndRefine, ResponseSynthesizer } from "./ResponseSynthesizer";
 import { BaseRetriever } from "./Retriever";
-import { ServiceContext } from "./ServiceContext";
 import { v4 as uuidv4 } from "uuid";
-import { Trace } from "./callbacks/CallbackManager";
+import { Event } from "./callbacks/CallbackManager";
+import { ServiceContext, serviceContextFromDefaults } from "./ServiceContext";
+import { QueryEngineTool, ToolMetadata } from "./Tool";
 
 export interface BaseQueryEngine {
-  aquery(query: string): Promise<Response>;
+  aquery(query: string, parentEvent?: Event): Promise<Response>;
 }
 
-export class RetrieverQueryEngine {
+export class RetrieverQueryEngine implements BaseQueryEngine {
   retriever: BaseRetriever;
   responseSynthesizer: ResponseSynthesizer;
 
@@ -17,14 +24,113 @@ export class RetrieverQueryEngine {
     this.retriever = retriever;
     const serviceContext: ServiceContext | undefined =
       this.retriever.getServiceContext();
-    this.responseSynthesizer = new ResponseSynthesizer(serviceContext);
+    this.responseSynthesizer = new ResponseSynthesizer({ serviceContext });
   }
 
-  async aquery(query: string) {
-    const parentTrace: Trace = {
+  async aquery(query: string, parentEvent?: Event) {
+    const _parentEvent: Event = parentEvent || {
       id: uuidv4(),
+      type: "wrapper",
+      tags: ["final"],
     };
-    const nodes = await this.retriever.aretrieve(query, parentTrace);
-    return this.responseSynthesizer.asynthesize(query, nodes, parentTrace);
+    const nodes = await this.retriever.aretrieve(query, _parentEvent);
+    return this.responseSynthesizer.asynthesize(query, nodes, _parentEvent);
+  }
+}
+
+export class SubQuestionQueryEngine implements BaseQueryEngine {
+  responseSynthesizer: ResponseSynthesizer;
+  questionGen: BaseQuestionGenerator;
+  queryEngines: Record<string, BaseQueryEngine>;
+  metadatas: ToolMetadata[];
+
+  constructor(init: {
+    questionGen: BaseQuestionGenerator;
+    responseSynthesizer: ResponseSynthesizer;
+    queryEngineTools: QueryEngineTool[];
+  }) {
+    this.questionGen = init.questionGen;
+    this.responseSynthesizer =
+      init.responseSynthesizer ?? new ResponseSynthesizer();
+    this.queryEngines = init.queryEngineTools.reduce<
+      Record<string, BaseQueryEngine>
+    >((acc, tool) => {
+      acc[tool.metadata.name] = tool.queryEngine;
+      return acc;
+    }, {});
+    this.metadatas = init.queryEngineTools.map((tool) => tool.metadata);
+  }
+
+  static fromDefaults(init: {
+    queryEngineTools: QueryEngineTool[];
+    questionGen?: BaseQuestionGenerator;
+    responseSynthesizer?: ResponseSynthesizer;
+    serviceContext?: ServiceContext;
+  }) {
+    const serviceContext =
+      init.serviceContext ?? serviceContextFromDefaults({});
+
+    const questionGen = init.questionGen ?? new LLMQuestionGenerator();
+    const responseSynthesizer =
+      init.responseSynthesizer ??
+      new ResponseSynthesizer({
+        responseBuilder: new CompactAndRefine(serviceContext),
+        serviceContext,
+      });
+
+    return new SubQuestionQueryEngine({
+      questionGen,
+      responseSynthesizer,
+      queryEngineTools: init.queryEngineTools,
+    });
+  }
+
+  async aquery(query: string): Promise<Response> {
+    const subQuestions = await this.questionGen.agenerate(
+      this.metadatas,
+      query
+    );
+
+    // groups final retrieval+synthesis operation
+    const parentEvent: Event = {
+      id: uuidv4(),
+      type: "wrapper",
+      tags: ["final"],
+    };
+
+    // groups all sub-queries
+    const subQueryParentEvent: Event = {
+      id: uuidv4(),
+      parentId: parentEvent.id,
+      type: "wrapper",
+      tags: ["intermediate"],
+    };
+
+    const subQNodes = await Promise.all(
+      subQuestions.map((subQ) => this.aquerySubQ(subQ, subQueryParentEvent))
+    );
+
+    const nodes = subQNodes
+      .filter((node) => node !== null)
+      .map((node) => node as NodeWithScore);
+    return this.responseSynthesizer.asynthesize(query, nodes, parentEvent);
+  }
+
+  private async aquerySubQ(
+    subQ: SubQuestion,
+    parentEvent?: Event
+  ): Promise<NodeWithScore | null> {
+    try {
+      const question = subQ.subQuestion;
+      const queryEngine = this.queryEngines[subQ.toolName];
+
+      const response = await queryEngine.aquery(question, parentEvent);
+      const responseText = response.response;
+      const nodeText = `Sub question: ${question}\nResponse: ${responseText}`;
+      const node = new TextNode({ text: nodeText });
+      return { node, score: 0 };
+    } catch (error) {
+      return null;
+    }
   }
 }
diff --git a/packages/core/src/QuestionGenerator.ts b/packages/core/src/QuestionGenerator.ts
new file mode 100644
index 0000000000000000000000000000000000000000..fad1f9732c952095034e7b07ba22ea6e03b2ebf0
--- /dev/null
+++ b/packages/core/src/QuestionGenerator.ts
@@ -0,0 +1,49 @@
+import { BaseLLMPredictor, ChatGPTLLMPredictor } from "./LLMPredictor";
+import {
+  BaseOutputParser,
+  StructuredOutput,
+  SubQuestionOutputParser,
+} from "./OutputParser";
+import {
+  SimplePrompt,
+  buildToolsText,
+  defaultSubQuestionPrompt,
+} from "./Prompt";
+import { ToolMetadata } from "./Tool";
+
+export interface SubQuestion {
+  subQuestion: string;
+  toolName: string;
+}
+
+export interface BaseQuestionGenerator {
+  agenerate(tools: ToolMetadata[], query: string): Promise<SubQuestion[]>;
+}
+
+export class LLMQuestionGenerator implements BaseQuestionGenerator {
+  llmPredictor: BaseLLMPredictor;
+  prompt: SimplePrompt;
+  outputParser: BaseOutputParser<StructuredOutput<SubQuestion[]>>;
+
+  constructor(init?: Partial<LLMQuestionGenerator>) {
+    this.llmPredictor = init?.llmPredictor ?? new ChatGPTLLMPredictor();
+    this.prompt = init?.prompt ?? defaultSubQuestionPrompt;
+    this.outputParser = init?.outputParser ?? new SubQuestionOutputParser();
+  }
+
+  async agenerate(
+    tools: ToolMetadata[],
+    query: string
+  ): Promise<SubQuestion[]> {
+    const toolsStr = buildToolsText(tools);
+    const queryStr = query;
+    const prediction = await this.llmPredictor.apredict(this.prompt, {
+      toolsStr,
+      queryStr,
+    });
+
+    const structuredOutput = this.outputParser.parse(prediction);
+
+    return structuredOutput.parsedOutput;
+  }
+}
diff --git a/packages/core/src/Response.ts b/packages/core/src/Response.ts
index 9ef1b81e4e2f4774d57954a9db4bd4d48c736323..6351ac90fe89ed72aea12158958118054b23e0a0 100644
--- a/packages/core/src/Response.ts
+++ b/packages/core/src/Response.ts
@@ -1,10 +1,10 @@
 import { BaseNode } from "./Node";
 
 export class Response {
-  response?: string;
-  sourceNodes: BaseNode[];
+  response: string;
+  sourceNodes?: BaseNode[];
 
-  constructor(response?: string, sourceNodes?: BaseNode[]) {
+  constructor(response: string, sourceNodes?: BaseNode[]) {
     this.response = response;
     this.sourceNodes = sourceNodes || [];
   }
diff --git a/packages/core/src/ResponseSynthesizer.ts b/packages/core/src/ResponseSynthesizer.ts
index 2404a585c011f98a15b450b15581ea24290db208..4aeedb55e7ed19e8e4cd3eab0fa850e089aad3fd 100644
--- a/packages/core/src/ResponseSynthesizer.ts
+++ b/packages/core/src/ResponseSynthesizer.ts
@@ -8,10 +8,14 @@ import {
 import { getBiggestPrompt } from "./PromptHelper";
 import { Response } from "./Response";
 import { ServiceContext } from "./ServiceContext";
-import { Trace } from "./callbacks/CallbackManager";
+import { Event } from "./callbacks/CallbackManager";
 
 interface BaseResponseBuilder {
-  agetResponse(query: string, textChunks: string[]): Promise<string>;
+  agetResponse(
+    query: string,
+    textChunks: string[],
+    parentEvent?: Event
+  ): Promise<string>;
 }
 
 export class SimpleResponseBuilder implements BaseResponseBuilder {
@@ -27,7 +31,7 @@ export class SimpleResponseBuilder implements BaseResponseBuilder {
   async agetResponse(
     query: string,
     textChunks: string[],
-    parentTrace?: Trace
+    parentEvent?: Event
   ): Promise<string> {
     const input = {
       query,
@@ -35,7 +39,7 @@ export class SimpleResponseBuilder implements BaseResponseBuilder {
     };
 
     const prompt = this.textQATemplate(input);
-    return this.llmPredictor.apredict(prompt, {}, parentTrace);
+    return this.llmPredictor.apredict(prompt, {}, parentEvent);
   }
 }
 
@@ -190,19 +194,27 @@ export function getResponseBuilder(
   return new SimpleResponseBuilder(serviceContext);
 }
 
+// TODO replace with Logan's new response_synthesizers/factory.py
 export class ResponseSynthesizer {
-  responseBuilder: SimpleResponseBuilder;
+  responseBuilder: BaseResponseBuilder;
   serviceContext?: ServiceContext;
 
-  constructor(serviceContext?: ServiceContext) {
+  constructor({
+    responseBuilder,
+    serviceContext,
+  }: {
+    responseBuilder?: BaseResponseBuilder;
+    serviceContext?: ServiceContext;
+  } = {}) {
     this.serviceContext = serviceContext;
-    this.responseBuilder = getResponseBuilder(this.serviceContext);
+    this.responseBuilder =
+      responseBuilder ?? getResponseBuilder(this.serviceContext);
   }
 
   async asynthesize(
     query: string,
     nodes: NodeWithScore[],
-    parentTrace?: Trace
+    parentEvent?: Event
   ) {
     let textChunks: string[] = nodes.map((node) =>
       node.node.getContent(MetadataMode.NONE)
@@ -210,7 +222,7 @@ export class ResponseSynthesizer {
     const response = await this.responseBuilder.agetResponse(
       query,
       textChunks,
-      parentTrace
+      parentEvent
     );
     return new Response(
       response,
diff --git a/packages/core/src/Retriever.ts b/packages/core/src/Retriever.ts
index b63cda9a9fd36dbdda90e041aed524bee891f4f1..428be6bb68e36dee6d02c44d0f733a82a5ec364b 100644
--- a/packages/core/src/Retriever.ts
+++ b/packages/core/src/Retriever.ts
@@ -2,7 +2,7 @@ import { VectorStoreIndex } from "./BaseIndex";
 import { globalsHelper } from "./GlobalsHelper";
 import { NodeWithScore } from "./Node";
 import { ServiceContext } from "./ServiceContext";
-import { Trace } from "./callbacks/CallbackManager";
+import { Event } from "./callbacks/CallbackManager";
 import { DEFAULT_SIMILARITY_TOP_K } from "./constants";
 import {
   VectorStoreQuery,
@@ -10,7 +10,7 @@ import {
 } from "./storage/vectorStore/types";
 
 export interface BaseRetriever {
-  aretrieve(query: string, parentTrace?: Trace): Promise<any>;
+  aretrieve(query: string, parentEvent?: Event): Promise<NodeWithScore[]>;
   getServiceContext(): ServiceContext;
 }
 
@@ -26,7 +26,7 @@ export class VectorIndexRetriever implements BaseRetriever {
 
   async aretrieve(
     query: string,
-    parentTrace?: Trace
+    parentEvent?: Event
   ): Promise<NodeWithScore[]> {
     const queryEmbedding =
       await this.serviceContext.embedModel.aGetQueryEmbedding(query);
@@ -51,7 +51,10 @@ export class VectorIndexRetriever implements BaseRetriever {
       this.serviceContext.callbackManager.onRetrieve({
         query,
         nodes: nodesWithScores,
-        trace: globalsHelper.createTrace({ parentTrace }),
+        event: globalsHelper.createEvent({
+          parentEvent,
+          type: "retrieve",
+        }),
       });
     }
 
diff --git a/packages/core/src/ServiceContext.ts b/packages/core/src/ServiceContext.ts
index 8d3eb72d42ab66492d821763c30b768c6be129e5..58570afd168789d5045757a433b4d3f2f14af2c0 100644
--- a/packages/core/src/ServiceContext.ts
+++ b/packages/core/src/ServiceContext.ts
@@ -26,15 +26,20 @@ export interface ServiceContextOptions {
   chunkOverlap?: number;
 }
 
-export function serviceContextFromDefaults(options: ServiceContextOptions) {
-  const callbackManager = options.callbackManager ?? new CallbackManager();
+export function serviceContextFromDefaults(options?: ServiceContextOptions) {
+  const callbackManager = options?.callbackManager ?? new CallbackManager();
 
   const serviceContext: ServiceContext = {
     llmPredictor:
-      options.llmPredictor ??
-      new ChatGPTLLMPredictor({ callbackManager, languageModel: options.llm }),
-    embedModel: options.embedModel ?? new OpenAIEmbedding(),
-    nodeParser: options.nodeParser ?? new SimpleNodeParser(),
-    promptHelper: options.promptHelper ?? new PromptHelper(),
+      options?.llmPredictor ??
+      new ChatGPTLLMPredictor({ callbackManager, languageModel: options?.llm }),
+    embedModel: options?.embedModel ?? new OpenAIEmbedding(),
+    nodeParser:
+      options?.nodeParser ??
+      new SimpleNodeParser({
+        chunkSize: options?.chunkSize,
+        chunkOverlap: options?.chunkOverlap,
+      }),
+    promptHelper: options?.promptHelper ?? new PromptHelper(),
     callbackManager,
   };
 
diff --git a/packages/core/src/Tool.ts b/packages/core/src/Tool.ts
new file mode 100644
index 0000000000000000000000000000000000000000..5eaecb125cacd0e45848607229eadbdb9d4ed48b
--- /dev/null
+++ b/packages/core/src/Tool.ts
@@ -0,0 +1,14 @@
+import { BaseQueryEngine } from "./QueryEngine";
+
+export interface ToolMetadata {
+  description: string;
+  name: string;
+}
+
+export interface BaseTool {
+  metadata: ToolMetadata;
+}
+
+export interface QueryEngineTool extends BaseTool {
+  queryEngine: BaseQueryEngine;
+}
diff --git a/packages/core/src/callbacks/CallbackManager.ts b/packages/core/src/callbacks/CallbackManager.ts
index 0a2f35f9a23a7b03f55d0a2694911a7fb2c2d6d5..a35e061c10a313b18a3503566d81d2dbdb04a6d0 100644
--- a/packages/core/src/callbacks/CallbackManager.ts
+++ b/packages/core/src/callbacks/CallbackManager.ts
@@ -2,18 +2,23 @@ import { ChatCompletionResponseMessageRoleEnum } from "openai";
 import { NodeWithScore } from "../Node";
 
 /*
-  A trace is a wrapper that allows grouping
-  related operations. For example, during retrieve and synthesize,
-  a parent trace wraps both operations, and each operation has it's own
-  trace. In this case, both operations will share a parentId.
+  An event is a wrapper that groups related operations.
+  For example, during retrieve and synthesize,
+  a parent event wraps both operations, and each operation has its own
+  event. In this case, both sub-events will share a parentId.
 */
-export interface Trace {
+
+export type EventTag = "intermediate" | "final";
+export type EventType = "retrieve" | "llmPredict" | "wrapper";
+export interface Event {
   id: string;
+  type: EventType;
+  tags?: EventTag[];
   parentId?: string;
 }
 
 interface BaseCallbackResponse {
-  trace: Trace;
+  event: Event;
 }
 
 export interface StreamToken {
diff --git a/packages/core/src/callbacks/utility/aHandleOpenAIStream.ts b/packages/core/src/callbacks/utility/aHandleOpenAIStream.ts
index 9f01d6345b71f5ff83c4891eecf6851f939f0dad..b477806086ff09c75f55f4eb3e71ac33ba3c82ed 100644
--- a/packages/core/src/callbacks/utility/aHandleOpenAIStream.ts
+++ b/packages/core/src/callbacks/utility/aHandleOpenAIStream.ts
@@ -1,17 +1,20 @@
 import { globalsHelper } from "../../GlobalsHelper";
-import { StreamCallbackResponse, Trace } from "../CallbackManager";
+import { StreamCallbackResponse, Event } from "../CallbackManager";
 import { StreamToken } from "../CallbackManager";
 
 export async function aHandleOpenAIStream({
   response,
   onLLMStream,
-  parentTrace,
+  parentEvent,
 }: {
   response: any;
   onLLMStream: (data: StreamCallbackResponse) => void;
-  parentTrace?: Trace;
+  parentEvent?: Event;
 }): Promise<string> {
-  const trace = globalsHelper.createTrace({ parentTrace });
+  const event = globalsHelper.createEvent({
+    parentEvent,
+    type: "llmPredict",
+  });
   const stream = __astreamCompletion(response.data as any);
   let index = 0;
   let cumulativeText = "";
@@ -23,10 +26,10 @@ export async function aHandleOpenAIStream({
       continue;
     }
     cumulativeText += content;
-    onLLMStream?.({ trace, index, token });
+    onLLMStream?.({ event, index, token });
     index++;
   }
-  onLLMStream?.({ trace, index, isDone: true });
+  onLLMStream?.({ event, index, isDone: true });
   return cumulativeText;
 }
 
diff --git a/packages/core/src/index/list/ListIndex.ts b/packages/core/src/index/list/ListIndex.ts
new file mode 100644
index 0000000000000000000000000000000000000000..56a17e0fe07c4d8348c3e503b89b4b77bddd0f65
--- /dev/null
+++ b/packages/core/src/index/list/ListIndex.ts
@@ -0,0 +1,166 @@
Document } from "../../Node"; +import { BaseIndex, BaseIndexInit, IndexList } from "../../BaseIndex"; +import { BaseQueryEngine, RetrieverQueryEngine } from "../../QueryEngine"; +import { + StorageContext, + storageContextFromDefaults, +} from "../../storage/StorageContext"; +import { BaseRetriever } from "../../Retriever"; +import { ListIndexRetriever } from "./ListIndexRetriever"; +import { + ServiceContext, + serviceContextFromDefaults, +} from "../../ServiceContext"; +import { BaseDocumentStore, RefDocInfo } from "../../storage/docStore/types"; +import _ from "lodash"; + +export enum ListRetrieverMode { + DEFAULT = "default", + // EMBEDDING = "embedding", + LLM = "llm", +} + +export interface ListIndexOptions { + nodes?: BaseNode[]; + indexStruct?: IndexList; + serviceContext?: ServiceContext; + storageContext?: StorageContext; +} + +export class ListIndex extends BaseIndex<IndexList> { + constructor(init: BaseIndexInit<IndexList>) { + super(init); + } + + static async init(options: ListIndexOptions): Promise<ListIndex> { + const storageContext = + options.storageContext ?? (await storageContextFromDefaults({})); + const serviceContext = + options.serviceContext ?? serviceContextFromDefaults({}); + const { docStore, indexStore } = storageContext; + + let indexStruct: IndexList; + if (options.indexStruct) { + if (options.nodes) { + throw new Error( + "Cannot initialize VectorStoreIndex with both nodes and indexStruct" + ); + } + indexStruct = options.indexStruct; + } else { + if (!options.nodes) { + throw new Error( + "Cannot initialize VectorStoreIndex without nodes or indexStruct" + ); + } + indexStruct = ListIndex._buildIndexFromNodes( + options.nodes, + storageContext.docStore + ); + } + + return new ListIndex({ + storageContext, + serviceContext, + docStore, + indexStore, + indexStruct, + }); + } + + static async fromDocuments( + documents: Document[], + storageContext?: StorageContext, + serviceContext?: ServiceContext + ): Promise<ListIndex> { + storageContext = storageContext ?? (await storageContextFromDefaults({})); + serviceContext = serviceContext ?? 
+    const docStore = storageContext.docStore;
+
+    docStore.addDocuments(documents, true);
+    for (const doc of documents) {
+      docStore.setDocumentHash(doc.id_, doc.hash);
+    }
+
+    const nodes = serviceContext.nodeParser.getNodesFromDocuments(documents);
+    const index = await ListIndex.init({
+      nodes,
+      storageContext,
+      serviceContext,
+    });
+    return index;
+  }
+
+  asRetriever(
+    mode: ListRetrieverMode = ListRetrieverMode.DEFAULT
+  ): BaseRetriever {
+    switch (mode) {
+      case ListRetrieverMode.DEFAULT:
+        return new ListIndexRetriever(this);
+      case ListRetrieverMode.LLM:
+        throw new Error(`Support for LLM retriever mode is not implemented`);
+      default:
+        throw new Error(`Unknown retriever mode: ${mode}`);
+    }
+  }
+
+  asQueryEngine(
+    mode: ListRetrieverMode = ListRetrieverMode.DEFAULT
+  ): BaseQueryEngine {
+    return new RetrieverQueryEngine(this.asRetriever());
+  }
+
+  static _buildIndexFromNodes(
+    nodes: BaseNode[],
+    docStore: BaseDocumentStore,
+    indexStruct?: IndexList
+  ): IndexList {
+    indexStruct = indexStruct || new IndexList();
+
+    docStore.addDocuments(nodes, true);
+    for (const node of nodes) {
+      indexStruct.addNode(node);
+    }
+
+    return indexStruct;
+  }
+
+  protected _insert(nodes: BaseNode[]): void {
+    for (const node of nodes) {
+      this.indexStruct.addNode(node);
+    }
+  }
+
+  protected _deleteNode(nodeId: string): void {
+    this.indexStruct.nodes = this.indexStruct.nodes.filter(
+      (existingNodeId: string) => existingNodeId !== nodeId
+    );
+  }
+
+  async getRefDocInfo(): Promise<Record<string, RefDocInfo>> {
+    const nodeDocIds = this.indexStruct.nodes;
+    const nodes = await this.docStore.getNodes(nodeDocIds);
+
+    const refDocInfoMap: Record<string, RefDocInfo> = {};
+
+    for (const node of nodes) {
+      const refNode = node.sourceNode;
+      if (_.isNil(refNode)) {
+        continue;
+      }
+
+      const refDocInfo = await this.docStore.getRefDocInfo(refNode.nodeId);
+
+      if (_.isNil(refDocInfo)) {
+        continue;
+      }
+
+      refDocInfoMap[refNode.nodeId] = refDocInfo;
+    }
+
+    return refDocInfoMap;
+  }
+}
+
+// Legacy
+export type GPTListIndex = ListIndex;
diff --git a/packages/core/src/index/list/ListIndexRetriever.ts b/packages/core/src/index/list/ListIndexRetriever.ts
new file mode 100644
index 0000000000000000000000000000000000000000..15b6d9c2e88a063a91167f78bb73a2061837fe25
--- /dev/null
+++ b/packages/core/src/index/list/ListIndexRetriever.ts
@@ -0,0 +1,137 @@
+import { BaseRetriever } from "../../Retriever";
+import { NodeWithScore } from "../../Node";
+import { ListIndex } from "./ListIndex";
+import { ServiceContext } from "../../ServiceContext";
+import {
+  NodeFormatterFunction,
+  ChoiceSelectParserFunction,
+  defaultFormatNodeBatchFn,
+  defaultParseChoiceSelectAnswerFn,
+} from "./utils";
+import { SimplePrompt, defaultChoiceSelectPrompt } from "../../Prompt";
+import _ from "lodash";
+import { globalsHelper } from "../../GlobalsHelper";
+import { Event } from "../../callbacks/CallbackManager";
+
+/**
+ * Simple retriever for ListIndex that returns all nodes
+ */
+export class ListIndexRetriever implements BaseRetriever {
+  index: ListIndex;
+
+  constructor(index: ListIndex) {
+    this.index = index;
+  }
+
+  async aretrieve(
+    query: string,
+    parentEvent?: Event
+  ): Promise<NodeWithScore[]> {
+    const nodeIds = this.index.indexStruct.nodes;
+    const nodes = await this.index.docStore.getNodes(nodeIds);
+    const result = nodes.map((node) => ({
+      node: node,
+      score: 1,
+    }));
+
+    if (this.index.serviceContext.callbackManager.onRetrieve) {
+      this.index.serviceContext.callbackManager.onRetrieve({
+        query,
+        nodes: result,
+        event: globalsHelper.createEvent({
+          parentEvent,
+          type: "retrieve",
+        }),
+      });
+    }
+
+    return result;
+  }
+
+  getServiceContext(): ServiceContext {
+    return this.index.serviceContext;
+  }
+}
+
+/**
+ * LLM retriever for ListIndex.
+ */
+export class ListIndexLLMRetriever implements BaseRetriever {
+  index: ListIndex;
+  choiceSelectPrompt: SimplePrompt;
+  choiceBatchSize: number;
+  formatNodeBatchFn: NodeFormatterFunction;
+  parseChoiceSelectAnswerFn: ChoiceSelectParserFunction;
+  serviceContext: ServiceContext;
+
+  constructor(
+    index: ListIndex,
+    choiceSelectPrompt?: SimplePrompt,
+    choiceBatchSize: number = 10,
+    formatNodeBatchFn?: NodeFormatterFunction,
+    parseChoiceSelectAnswerFn?: ChoiceSelectParserFunction,
+    serviceContext?: ServiceContext
+  ) {
+    this.index = index;
+    this.choiceSelectPrompt = choiceSelectPrompt || defaultChoiceSelectPrompt;
+    this.choiceBatchSize = choiceBatchSize;
+    this.formatNodeBatchFn = formatNodeBatchFn || defaultFormatNodeBatchFn;
+    this.parseChoiceSelectAnswerFn =
+      parseChoiceSelectAnswerFn || defaultParseChoiceSelectAnswerFn;
+    this.serviceContext = serviceContext || index.serviceContext;
+  }
+
+  async aretrieve(
+    query: string,
+    parentEvent?: Event
+  ): Promise<NodeWithScore[]> {
+    const nodeIds = this.index.indexStruct.nodes;
+    const results: NodeWithScore[] = [];
+
+    for (let idx = 0; idx < nodeIds.length; idx += this.choiceBatchSize) {
+      const nodeIdsBatch = nodeIds.slice(idx, idx + this.choiceBatchSize);
+      const nodesBatch = await this.index.docStore.getNodes(nodeIdsBatch);
+
+      const fmtBatchStr = this.formatNodeBatchFn(nodesBatch);
+      const input = { context: fmtBatchStr, query: query };
+      const rawResponse = await this.serviceContext.llmPredictor.apredict(
+        this.choiceSelectPrompt,
+        input
+      );
+
+      // parseResult is a map from doc number (1-based) to relevance score
+      const parseResult = this.parseChoiceSelectAnswerFn(
+        rawResponse,
+        nodesBatch.length
+      );
+      const choiceNodeIds = nodeIdsBatch.filter((nodeId, idx) => {
+        return `${idx + 1}` in parseResult;
+      });
+
+      const choiceNodes = await this.index.docStore.getNodes(choiceNodeIds);
+      const nodeWithScores = choiceNodes.map((node, i) => ({
+        node: node,
+        score: _.get(parseResult, `${i + 1}`, 1),
+      }));
+
+      results.push(...nodeWithScores);
+    }
+
+    if (this.serviceContext.callbackManager.onRetrieve) {
+      this.serviceContext.callbackManager.onRetrieve({
+        query,
+        nodes: results,
+        event: globalsHelper.createEvent({
+          parentEvent,
+          type: "retrieve",
+        }),
+      });
+    }
+
+    return results;
+  }
+
+  getServiceContext(): ServiceContext {
+    return this.serviceContext;
+  }
+}
diff --git a/packages/core/src/index/list/index.ts b/packages/core/src/index/list/index.ts
new file mode 100644
index 0000000000000000000000000000000000000000..f8d0b8d5eae44cf561fd8483fbabf6dc716260d0
--- /dev/null
+++ b/packages/core/src/index/list/index.ts
@@ -0,0 +1,5 @@
+export { ListIndex, ListRetrieverMode } from "./ListIndex";
+export {
+  ListIndexRetriever,
+  ListIndexLLMRetriever,
+} from "./ListIndexRetriever";
diff --git a/packages/core/src/index/list/utils.ts b/packages/core/src/index/list/utils.ts
new file mode 100644
index 0000000000000000000000000000000000000000..b7a1d3f8fddc0af2c587ef1db2c8ff0db01dbfce
--- /dev/null
+++ b/packages/core/src/index/list/utils.ts
@@ -0,0 +1,73 @@
+import { BaseNode, MetadataMode } from "../../Node";
+import _ from "lodash";
+
+export type NodeFormatterFunction = (summaryNodes: BaseNode[]) => string;
+export const defaultFormatNodeBatchFn: NodeFormatterFunction = (
+  summaryNodes: BaseNode[]
+): string => {
+  return summaryNodes
+    .map((node, idx) => {
+      return `
+Document ${idx + 1}:
+${node.getContent(MetadataMode.LLM)}
+      `.trim();
+    })
+    .join("\n\n");
+};
+
+// map from document number to its relevance score
+export type ChoiceSelectParseResult = { [docNumber: number]: number };
+export type ChoiceSelectParserFunction = (
+  answer: string,
+  numChoices: number,
+  raiseErr?: boolean
+) => ChoiceSelectParseResult;
+
+export const defaultParseChoiceSelectAnswerFn: ChoiceSelectParserFunction = (
+  answer: string,
+  numChoices: number,
+  raiseErr: boolean = false
+): ChoiceSelectParseResult => {
+  // split each line into the answer number and relevance score portions
+  const lineTokens: string[][] = answer
+    .split("\n")
+    .map((line: string) => {
+      let lineTokens = line.split(",");
+      if (lineTokens.length !== 2) {
+        if (raiseErr) {
+          throw new Error(
+            `Invalid answer line: ${line}. Answer line must be of the form: answer_num: <int>, answer_relevance: <float>`
+          );
+        } else {
+          return null;
+        }
+      }
+      return lineTokens;
+    })
+    .filter((lineTokens) => !_.isNil(lineTokens)) as string[][];
+
+  // parse the answer number and relevance score
+  return lineTokens.reduce(
+    (parseResult: ChoiceSelectParseResult, lineToken: string[]) => {
+      try {
+        let docNum = parseInt(lineToken[0].split(":")[1].trim());
+        let answerRelevance = parseFloat(lineToken[1].split(":")[1].trim());
+        if (docNum < 1 || docNum > numChoices) {
+          if (raiseErr) {
+            throw new Error(
+              `Invalid answer number: ${docNum}. Answer number must be between 1 and ${numChoices}`
+            );
+          }
+        } else {
+          parseResult[docNum] = answerRelevance;
+        }
+      } catch (e) {
+        if (raiseErr) {
+          throw e;
+        }
+      }
+      return parseResult;
+    },
+    {}
+  );
+};
diff --git a/packages/core/src/storage/docStore/KVDocumentStore.ts b/packages/core/src/storage/docStore/KVDocumentStore.ts
index 64b9780d762ce260226364747a5a39fd6a19057d..027672e6d276a3eac279a0aba37e684d905aab80 100644
--- a/packages/core/src/storage/docStore/KVDocumentStore.ts
+++ b/packages/core/src/storage/docStore/KVDocumentStore.ts
@@ -77,7 +77,7 @@ export class KVDocumentStore extends BaseDocumentStore {
     let json = await this.kvstore.get(docId, this.nodeCollection);
     if (_.isNil(json)) {
       if (raiseError) {
-        throw new Error(`doc_id ${docId} not found.`);
+        throw new Error(`docId ${docId} not found.`);
       } else {
         return;
       }
diff --git a/packages/core/src/storage/docStore/utils.ts b/packages/core/src/storage/docStore/utils.ts
index 8c80a3c875d7a3e14fbc62463cddcc9fd26c571b..a7329df67e14d10343234e49f6538c324b452641 100644
--- a/packages/core/src/storage/docStore/utils.ts
+++ b/packages/core/src/storage/docStore/utils.ts
@@ -23,12 +23,11 @@ export function jsonToDoc(docDict: Record<string, any>): BaseNode {
       hash: dataDict.hash,
     });
   } else if (docType === ObjectType.TEXT) {
-    const relationships = dataDict.relationships;
+    console.log({ dataDict });
     doc = new TextNode({
-      text: relationships.text,
-      id_: relationships.id_,
-      embedding: relationships.embedding,
-      hash: relationships.hash,
+      text: dataDict.text,
+      id_: dataDict.id_,
+      hash: dataDict.hash,
     });
   } else {
     throw new Error(`Unknown doc type: ${docType}`);
diff --git a/packages/core/src/tests/CallbackManager.test.ts b/packages/core/src/tests/CallbackManager.test.ts
index b1f2f3ce198e6fc8c62e092564ff3538bc3610ce..dfedc2c3984b920aa8c5e3798fb5dd66ff5c2871 100644
--- a/packages/core/src/tests/CallbackManager.test.ts
+++ b/packages/core/src/tests/CallbackManager.test.ts
@@ -8,6 +8,7 @@ import {
   RetrievalCallbackResponse,
   StreamCallbackResponse,
 } from "../callbacks/CallbackManager";
+import { ListIndex } from "../index/list";
 import { mockEmbeddingModel, mockLlmGeneration } from "./utility/mockOpenAI";
 
 // Mock the OpenAI getOpenAISession function during testing
@@ -18,7 +19,6 @@ jest.mock("../openai", () => {
 });
 
 describe("CallbackManager: onLLMStream and onRetrieve", () => {
-  let vectorStoreIndex: VectorStoreIndex;
   let serviceContext: ServiceContext;
   let streamCallbackData: StreamCallbackResponse[] = [];
   let retrieveCallbackData: RetrievalCallbackResponse[] = [];
@@ -49,12 +49,6 @@ describe("CallbackManager: onLLMStream and onRetrieve", () => {
       llm: languageModel,
       embedModel,
     });
-
-    vectorStoreIndex = await VectorStoreIndex.fromDocuments(
-      [document],
-      undefined,
-      serviceContext
-    );
   });
 
   beforeEach(() => {
@@ -67,15 +61,90 @@ describe("CallbackManager: onLLMStream and onRetrieve", () => {
   });
 
   test("For VectorStoreIndex w/ a SimpleResponseBuilder", async () => {
+    const vectorStoreIndex = await VectorStoreIndex.fromDocuments(
+      [document],
+      undefined,
+      serviceContext
+    );
     const queryEngine = vectorStoreIndex.asQueryEngine();
     const query = "What is the author's name?";
     const response = await queryEngine.aquery(query);
     expect(response.toString()).toBe("MOCK_TOKEN_1-MOCK_TOKEN_2");
     expect(streamCallbackData).toEqual([
       {
-        trace: {
+        event: {
+          id: expect.any(String),
+          parentId: expect.any(String),
+          type: "llmPredict",
+        },
+        index: 0,
+        token: {
+          id: "id",
+          object: "object",
+          created: 1,
+          model: "model",
+          choices: expect.any(Array),
+        },
+      },
+      {
+        event: {
+          id: expect.any(String),
+          parentId: expect.any(String),
+          type: "llmPredict",
+        },
+        index: 1,
+        token: {
+          id: "id",
+          object: "object",
+          created: 1,
+          model: "model",
+          choices: expect.any(Array),
+        },
+      },
+      {
+        event: {
+          id: expect.any(String),
+          parentId: expect.any(String),
+          type: "llmPredict",
+        },
+        index: 2,
+        isDone: true,
+      },
+    ]);
+    expect(retrieveCallbackData).toEqual([
+      {
+        query: query,
+        nodes: expect.any(Array),
+        event: {
+          id: expect.any(String),
+          parentId: expect.any(String),
+          type: "retrieve",
+        },
+      },
+    ]);
+    // both retrieval and streaming should have
+    // the same parent event
+    expect(streamCallbackData[0].event.parentId).toBe(
+      retrieveCallbackData[0].event.parentId
+    );
+  });
+
+  test("For ListIndex w/ a ListIndexRetriever", async () => {
+    const listIndex = await ListIndex.fromDocuments(
+      [document],
+      undefined,
+      serviceContext
+    );
+    const queryEngine = listIndex.asQueryEngine();
+    const query = "What is the author's name?";
+    const response = await queryEngine.aquery(query);
+    expect(response.toString()).toBe("MOCK_TOKEN_1-MOCK_TOKEN_2");
+    expect(streamCallbackData).toEqual([
+      {
+        event: {
           id: expect.any(String),
           parentId: expect.any(String),
+          type: "llmPredict",
         },
         index: 0,
         token: {
@@ -87,9 +156,10 @@ describe("CallbackManager: onLLMStream and onRetrieve", () => {
         },
       },
       {
-        trace: {
+        event: {
           id: expect.any(String),
           parentId: expect.any(String),
+          type: "llmPredict",
         },
         index: 1,
         token: {
@@ -101,9 +171,10 @@ describe("CallbackManager: onLLMStream and onRetrieve", () => {
         },
       },
       {
-        trace: {
+        event: {
           id: expect.any(String),
           parentId: expect.any(String),
+          type: "llmPredict",
         },
         index: 2,
         isDone: true,
@@ -113,16 +184,17 @@ describe("CallbackManager: onLLMStream and onRetrieve", () => {
       {
         query: query,
         nodes: expect.any(Array),
-        trace: {
+        event: {
           id: expect.any(String),
           parentId: expect.any(String),
+          type: "retrieve",
         },
       },
     ]);
     // both retrieval and streaming should have
-    // the same parent trace
+    // the same parent event
-    expect(streamCallbackData[0].trace.parentId).toBe(
-      retrieveCallbackData[0].trace.parentId
+    expect(streamCallbackData[0].event.parentId).toBe(
+      retrieveCallbackData[0].event.parentId
     );
   });
 });
diff --git a/packages/core/src/tests/utility/mockOpenAI.ts b/packages/core/src/tests/utility/mockOpenAI.ts
index 16bc3fcd7851a292f811595a222ece2f45e69045..67631a9acded857a2b483dce1826484977d78aa7 100644
--- a/packages/core/src/tests/utility/mockOpenAI.ts
+++ b/packages/core/src/tests/utility/mockOpenAI.ts
@@ -1,7 +1,7 @@
 import { OpenAIEmbedding } from "../../Embedding";
 import { globalsHelper } from "../../GlobalsHelper";
 import { BaseMessage, ChatOpenAI } from "../../LanguageModel";
-import { CallbackManager, Trace } from "../../callbacks/CallbackManager";
+import { CallbackManager, Event } from "../../callbacks/CallbackManager";
 
 export function mockLlmGeneration({
   languageModel,
@@ -13,15 +13,18 @@ export function mockLlmGeneration({
   jest
     .spyOn(languageModel, "agenerate")
     .mockImplementation(
-      async (messages: BaseMessage[], parentTrace?: Trace) => {
+      async (messages: BaseMessage[], parentEvent?: Event) => {
        const text = "MOCK_TOKEN_1-MOCK_TOKEN_2";
-        const trace = globalsHelper.createTrace({ parentTrace });
+        const event = globalsHelper.createEvent({
+          parentEvent,
+          type: "llmPredict",
+        });
         if (callbackManager?.onLLMStream) {
           const chunks = text.split("-");
           for (let i = 0; i < chunks.length; i++) {
             const chunk = chunks[i];
             callbackManager?.onLLMStream({
-              trace,
+              event,
               index: i,
               token: {
                 id: "id",
@@ -41,7 +44,7 @@ export function mockLlmGeneration({
           });
         }
         callbackManager?.onLLMStream({
-          trace,
+          event,
           index: chunks.length,
           isDone: true,
         });