From d924c63162ce4ec564664415008e42a2b9428c35 Mon Sep 17 00:00:00 2001
From: Thuc Pham <51660321+thucpn@users.noreply.github.com>
Date: Tue, 11 Feb 2025 12:57:15 +0700
Subject: [PATCH] feat: asChatEngine function for index (#1640)

---
Notes (this area sits below the `---` separator, so it is not part of the
commit message):

  * BaseIndex gains an abstract asChatEngine(options?) that returns a
    BaseChatEngine; its options are ContextChatEngineOptions without
    "retriever".
  * KeywordTableIndex, SummaryIndex and VectorStoreIndex each implement it
    by constructing a ContextChatEngine. A caller-supplied `retriever` wins;
    otherwise the index's own asRetriever(...) is used. SummaryIndex
    forwards `mode` (falling back to SummaryRetrieverMode.DEFAULT) and
    VectorStoreIndex forwards `similarityTopK` and `preFilters` (passed on
    as `filters`). Every remaining option is spread into the
    ContextChatEngine constructor unchanged.
  * The inline parameter type of the ContextChatEngine constructor is moved
    into the exported ContextChatEngineOptions type with the same fields,
    and re-exported from the chat-engine index.
  * Three example scripts are added under examples/chat-engine/.
  * NOTE(review): VectorStoreIndex.asChatEngine has no explicit return type
    annotation, while the keyword and summary implementations declare
    BaseChatEngine - consider aligning in a follow-up.
  * NOTE(review): line breaks and indentation in this copy were restored
    from a whitespace-collapsed source using the hunk line counts and a
    2-space layout; confirm context lines against the target tree before
    applying.

 .changeset/tall-kids-prove.md                |  6 ++++
 examples/chat-engine/keyword-index.ts        | 15 ++++++++++
 examples/chat-engine/summary-index.ts        | 17 +++++++++++
 examples/chat-engine/vector-store-index.ts   | 15 ++++++++++
 .../src/chat-engine/context-chat-engine.ts   | 20 +++++++------
 packages/core/src/chat-engine/index.ts       |  5 +++-
 packages/llamaindex/src/indices/BaseIndex.ts | 12 ++++++++
 .../llamaindex/src/indices/keyword/index.ts  | 17 +++++++++++
 .../llamaindex/src/indices/summary/index.ts  | 20 +++++++++++++
 .../src/indices/vectorStore/index.ts         | 29 +++++++++++++++++++
 10 files changed, 146 insertions(+), 10 deletions(-)
 create mode 100644 .changeset/tall-kids-prove.md
 create mode 100644 examples/chat-engine/keyword-index.ts
 create mode 100644 examples/chat-engine/summary-index.ts
 create mode 100644 examples/chat-engine/vector-store-index.ts

diff --git a/.changeset/tall-kids-prove.md b/.changeset/tall-kids-prove.md
new file mode 100644
index 000000000..4ec599075
--- /dev/null
+++ b/.changeset/tall-kids-prove.md
@@ -0,0 +1,6 @@
+---
+"@llamaindex/core": patch
+"llamaindex": patch
+---
+
+feat: asChatEngine function for index
diff --git a/examples/chat-engine/keyword-index.ts b/examples/chat-engine/keyword-index.ts
new file mode 100644
index 000000000..8e933c086
--- /dev/null
+++ b/examples/chat-engine/keyword-index.ts
@@ -0,0 +1,15 @@
+import { Document, KeywordTableIndex } from "llamaindex";
+import essay from "../essay";
+
+async function main() {
+  const document = new Document({ text: essay });
+  const index = await KeywordTableIndex.fromDocuments([document]);
+  const chatEngine = index.asChatEngine();
+
+  const response = await chatEngine.chat({
+    message: "What is Harsh Mistress?",
+  });
+  console.log(response.message.content);
+}
+
+main().catch(console.error);
diff --git a/examples/chat-engine/summary-index.ts b/examples/chat-engine/summary-index.ts
new file mode 100644
index 000000000..034575ade
--- /dev/null
+++ b/examples/chat-engine/summary-index.ts
@@ -0,0 +1,17 @@
+import { Document, SummaryIndex, SummaryRetrieverMode } from "llamaindex";
+import essay from "../essay";
+
+async function main() {
+  const document = new Document({ text: essay });
+  const index = await SummaryIndex.fromDocuments([document]);
+  const chatEngine = index.asChatEngine({
+    mode: SummaryRetrieverMode.LLM,
+  });
+
+  const response = await chatEngine.chat({
+    message: "Summary about the author",
+  });
+  console.log(response.message.content);
+}
+
+main().catch(console.error);
diff --git a/examples/chat-engine/vector-store-index.ts b/examples/chat-engine/vector-store-index.ts
new file mode 100644
index 000000000..011612668
--- /dev/null
+++ b/examples/chat-engine/vector-store-index.ts
@@ -0,0 +1,15 @@
+import { Document, VectorStoreIndex } from "llamaindex";
+import essay from "../essay";
+
+async function main() {
+  const document = new Document({ text: essay });
+  const index = await VectorStoreIndex.fromDocuments([document]);
+  const chatEngine = index.asChatEngine({ similarityTopK: 5 });
+
+  const response = await chatEngine.chat({
+    message: "What did I work on in February 2021?",
+  });
+  console.log(response.message.content);
+}
+
+main().catch(console.error);
diff --git a/packages/core/src/chat-engine/context-chat-engine.ts b/packages/core/src/chat-engine/context-chat-engine.ts
index 7f260158d..b1f1ff75e 100644
--- a/packages/core/src/chat-engine/context-chat-engine.ts
+++ b/packages/core/src/chat-engine/context-chat-engine.ts
@@ -20,6 +20,16 @@ import type {
 import { DefaultContextGenerator } from "./default-context-generator";
 import type { ContextGenerator } from "./type";
 
+export type ContextChatEngineOptions = {
+  retriever: BaseRetriever;
+  chatModel?: LLM | undefined;
+  chatHistory?: ChatMessage[] | undefined;
+  contextSystemPrompt?: ContextSystemPrompt | undefined;
+  nodePostprocessors?: BaseNodePostprocessor[] | undefined;
+  systemPrompt?: string | undefined;
+  contextRole?: MessageType | undefined;
+};
+
 /**
  * ContextChatEngine uses the Index to get the appropriate context for each query.
  * The context is stored in the system prompt, and the chat history is chunk,
@@ -35,15 +45,7 @@ export class ContextChatEngine extends PromptMixin implements BaseChatEngine {
     return this.memory.getMessages();
   }
 
-  constructor(init: {
-    retriever: BaseRetriever;
-    chatModel?: LLM | undefined;
-    chatHistory?: ChatMessage[] | undefined;
-    contextSystemPrompt?: ContextSystemPrompt | undefined;
-    nodePostprocessors?: BaseNodePostprocessor[] | undefined;
-    systemPrompt?: string | undefined;
-    contextRole?: MessageType | undefined;
-  }) {
+  constructor(init: ContextChatEngineOptions) {
     super();
     this.chatModel = init.chatModel ?? Settings.llm;
     this.memory = new ChatMemoryBuffer({ chatHistory: init?.chatHistory });
diff --git a/packages/core/src/chat-engine/index.ts b/packages/core/src/chat-engine/index.ts
index f5af4dd4d..c52aa9b3a 100644
--- a/packages/core/src/chat-engine/index.ts
+++ b/packages/core/src/chat-engine/index.ts
@@ -4,6 +4,9 @@ export {
   type NonStreamingChatEngineParams,
   type StreamingChatEngineParams,
 } from "./base";
-export { ContextChatEngine } from "./context-chat-engine";
+export {
+  ContextChatEngine,
+  type ContextChatEngineOptions,
+} from "./context-chat-engine";
 export { DefaultContextGenerator } from "./default-context-generator";
 export { SimpleChatEngine } from "./simple-chat-engine";
diff --git a/packages/llamaindex/src/indices/BaseIndex.ts b/packages/llamaindex/src/indices/BaseIndex.ts
index e95fdf749..9e14deb27 100644
--- a/packages/llamaindex/src/indices/BaseIndex.ts
+++ b/packages/llamaindex/src/indices/BaseIndex.ts
@@ -1,3 +1,7 @@
+import type {
+  BaseChatEngine,
+  ContextChatEngineOptions,
+} from "@llamaindex/core/chat-engine";
 import type { BaseQueryEngine } from "@llamaindex/core/query-engine";
 import type { BaseSynthesizer } from "@llamaindex/core/response-synthesizers";
 import type { BaseRetriever } from "@llamaindex/core/retriever";
@@ -53,6 +57,14 @@ export abstract class BaseIndex<T> {
     responseSynthesizer?: BaseSynthesizer;
   }): BaseQueryEngine;
 
+  /**
+   * Create a new chat engine from the index.
+   * @param options
+   */
+  abstract asChatEngine(
+    options?: Omit<ContextChatEngineOptions, "retriever">,
+  ): BaseChatEngine;
+
   /**
    * Insert a document into the index.
    * @param document
diff --git a/packages/llamaindex/src/indices/keyword/index.ts b/packages/llamaindex/src/indices/keyword/index.ts
index 7c17c6a1d..ec44118fc 100644
--- a/packages/llamaindex/src/indices/keyword/index.ts
+++ b/packages/llamaindex/src/indices/keyword/index.ts
@@ -35,6 +35,11 @@ import { BaseRetriever } from "@llamaindex/core/retriever";
 import type { BaseDocumentStore } from "@llamaindex/core/storage/doc-store";
 import { extractText } from "@llamaindex/core/utils";
 import { llmFromSettingsOrContext } from "../../Settings.js";
+import {
+  ContextChatEngine,
+  type BaseChatEngine,
+  type ContextChatEngineOptions,
+} from "../../engines/chat/index.js";
 
 export interface KeywordIndexOptions {
   nodes?: BaseNode[];
@@ -152,6 +157,10 @@ const KeywordTableRetrieverMap = {
   [KeywordTableRetrieverMode.RAKE]: KeywordTableRAKERetriever,
 };
 
+export type KeywordTableIndexChatEngineOptions = {
+  retriever?: BaseRetriever;
+} & Omit<ContextChatEngineOptions, "retriever">;
+
 /**
  * The KeywordTableIndex, an index that extracts keywords from each Node and builds a mapping from each keyword to the corresponding Nodes of that keyword.
  */
@@ -251,6 +260,14 @@ export class KeywordTableIndex extends BaseIndex<KeywordTable> {
     );
   }
 
+  asChatEngine(options?: KeywordTableIndexChatEngineOptions): BaseChatEngine {
+    const { retriever, ...contextChatEngineOptions } = options ?? {};
+    return new ContextChatEngine({
+      retriever: retriever ?? this.asRetriever(),
+      ...contextChatEngineOptions,
+    });
+  }
+
   static async extractKeywords(
     text: string,
     serviceContext?: ServiceContext,
diff --git a/packages/llamaindex/src/indices/summary/index.ts b/packages/llamaindex/src/indices/summary/index.ts
index a3fa20444..aa3157671 100644
--- a/packages/llamaindex/src/indices/summary/index.ts
+++ b/packages/llamaindex/src/indices/summary/index.ts
@@ -24,6 +24,11 @@ import {
   llmFromSettingsOrContext,
   nodeParserFromSettingsOrContext,
 } from "../../Settings.js";
+import type {
+  BaseChatEngine,
+  ContextChatEngineOptions,
+} from "../../engines/chat/index.js";
+import { ContextChatEngine } from "../../engines/chat/index.js";
 import { RetrieverQueryEngine } from "../../engines/query/index.js";
 import type { StorageContext } from "../../storage/StorageContext.js";
 import { storageContextFromDefaults } from "../../storage/StorageContext.js";
@@ -44,6 +49,11 @@ export enum SummaryRetrieverMode {
   LLM = "llm",
 }
 
+export type SummaryIndexChatEngineOptions = {
+  retriever?: BaseRetriever;
+  mode?: SummaryRetrieverMode;
+} & Omit<ContextChatEngineOptions, "retriever">;
+
 export interface SummaryIndexOptions {
   nodes?: BaseNode[] | undefined;
   indexStruct?: IndexList | undefined;
@@ -193,6 +203,16 @@ export class SummaryIndex extends BaseIndex<IndexList> {
     );
   }
 
+  asChatEngine(options?: SummaryIndexChatEngineOptions): BaseChatEngine {
+    const { retriever, mode, ...contextChatEngineOptions } = options ?? {};
+    return new ContextChatEngine({
+      retriever:
+        retriever ??
+        this.asRetriever({ mode: mode ?? SummaryRetrieverMode.DEFAULT }),
+      ...contextChatEngineOptions,
+    });
+  }
+
   static async buildIndexFromNodes(
     nodes: BaseNode[],
     docStore: BaseDocumentStore,
diff --git a/packages/llamaindex/src/indices/vectorStore/index.ts b/packages/llamaindex/src/indices/vectorStore/index.ts
index bdbafd62c..6b21f7ae6 100644
--- a/packages/llamaindex/src/indices/vectorStore/index.ts
+++ b/packages/llamaindex/src/indices/vectorStore/index.ts
@@ -1,3 +1,7 @@
+import {
+  ContextChatEngine,
+  type ContextChatEngineOptions,
+} from "@llamaindex/core/chat-engine";
 import { IndexDict, IndexStructType } from "@llamaindex/core/data-structs";
 import {
   DEFAULT_SIMILARITY_TOP_K,
@@ -59,6 +63,12 @@ export interface VectorIndexConstructorProps extends BaseIndexInit<IndexDict> {
   vectorStores?: VectorStoreByType | undefined;
 }
 
+export type VectorIndexChatEngineOptions = {
+  retriever?: BaseRetriever;
+  similarityTopK?: number;
+  preFilters?: MetadataFilters;
+} & Omit<ContextChatEngineOptions, "retriever">;
+
 /**
  * The VectorStoreIndex, an index that stores the nodes only according to their vector embeddings.
  */
@@ -309,6 +319,25 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> {
     );
   }
 
+  /**
+   * Convert the index to a chat engine.
+   * @param options The options for creating the chat engine
+   * @returns A ContextChatEngine that uses the index's retriever to get context for each query
+   */
+  asChatEngine(options: VectorIndexChatEngineOptions = {}) {
+    const {
+      retriever,
+      similarityTopK,
+      preFilters,
+      ...contextChatEngineOptions
+    } = options;
+    return new ContextChatEngine({
+      retriever:
+        retriever ?? this.asRetriever({ similarityTopK, filters: preFilters }),
+      ...contextChatEngineOptions,
+    });
+  }
+
   protected async insertNodesToStore(
     newIds: string[],
     nodes: BaseNode[],
-- 
GitLab