diff --git a/.changeset/tall-kids-prove.md b/.changeset/tall-kids-prove.md
new file mode 100644
index 0000000000000000000000000000000000000000..4ec59907521e79dd83782745d67f44cfa556a22e
--- /dev/null
+++ b/.changeset/tall-kids-prove.md
@@ -0,0 +1,6 @@
+---
+"@llamaindex/core": patch
+"llamaindex": patch
+---
+
+feat: asChatEngine function for index
diff --git a/examples/chat-engine/keyword-index.ts b/examples/chat-engine/keyword-index.ts
new file mode 100644
index 0000000000000000000000000000000000000000..8e933c086b7cff3b57b4cfe6b7c3e1361e2331f0
--- /dev/null
+++ b/examples/chat-engine/keyword-index.ts
@@ -0,0 +1,15 @@
+import { Document, KeywordTableIndex } from "llamaindex";
+import essay from "../essay";
+
+async function main() {
+  const document = new Document({ text: essay });
+  const index = await KeywordTableIndex.fromDocuments([document]);
+  const chatEngine = index.asChatEngine();
+
+  const response = await chatEngine.chat({
+    message: "What is Harsh Mistress?",
+  });
+  console.log(response.message.content);
+}
+
+main().catch(console.error);
diff --git a/examples/chat-engine/summary-index.ts b/examples/chat-engine/summary-index.ts
new file mode 100644
index 0000000000000000000000000000000000000000..034575adee40c248f06aff8ccd7e184a73a6a723
--- /dev/null
+++ b/examples/chat-engine/summary-index.ts
@@ -0,0 +1,17 @@
+import { Document, SummaryIndex, SummaryRetrieverMode } from "llamaindex";
+import essay from "../essay";
+
+async function main() {
+  const document = new Document({ text: essay });
+  const index = await SummaryIndex.fromDocuments([document]);
+  const chatEngine = index.asChatEngine({
+    mode: SummaryRetrieverMode.LLM,
+  });
+
+  const response = await chatEngine.chat({
+    message: "Summary about the author",
+  });
+  console.log(response.message.content);
+}
+
+main().catch(console.error);
diff --git a/examples/chat-engine/vector-store-index.ts b/examples/chat-engine/vector-store-index.ts
new file mode 100644
index 0000000000000000000000000000000000000000..011612668139587f69b5d0ad857e25f7f16a23f7
--- /dev/null
+++ b/examples/chat-engine/vector-store-index.ts
@@ -0,0 +1,15 @@
+import { Document, VectorStoreIndex } from "llamaindex";
+import essay from "../essay";
+
+async function main() {
+  const document = new Document({ text: essay });
+  const index = await VectorStoreIndex.fromDocuments([document]);
+  const chatEngine = index.asChatEngine({ similarityTopK: 5 });
+
+  const response = await chatEngine.chat({
+    message: "What did I work on in February 2021?",
+  });
+  console.log(response.message.content);
+}
+
+main().catch(console.error);
diff --git a/packages/core/src/chat-engine/context-chat-engine.ts b/packages/core/src/chat-engine/context-chat-engine.ts
index 7f260158d7f928415df862f6c305362ad0b46ae8..b1f1ff75ea0b896e3ca3fd0e894bc18fd3672f15 100644
--- a/packages/core/src/chat-engine/context-chat-engine.ts
+++ b/packages/core/src/chat-engine/context-chat-engine.ts
@@ -20,6 +20,16 @@ import type {
 import { DefaultContextGenerator } from "./default-context-generator";
 import type { ContextGenerator } from "./type";
 
+export type ContextChatEngineOptions = {
+  retriever: BaseRetriever;
+  chatModel?: LLM | undefined;
+  chatHistory?: ChatMessage[] | undefined;
+  contextSystemPrompt?: ContextSystemPrompt | undefined;
+  nodePostprocessors?: BaseNodePostprocessor[] | undefined;
+  systemPrompt?: string | undefined;
+  contextRole?: MessageType | undefined;
+};
+
 /**
  * ContextChatEngine uses the Index to get the appropriate context for each query.
  * The context is stored in the system prompt, and the chat history is chunk,
@@ -35,15 +45,7 @@ export class ContextChatEngine extends PromptMixin implements BaseChatEngine {
     return this.memory.getMessages();
   }
 
-  constructor(init: {
-    retriever: BaseRetriever;
-    chatModel?: LLM | undefined;
-    chatHistory?: ChatMessage[] | undefined;
-    contextSystemPrompt?: ContextSystemPrompt | undefined;
-    nodePostprocessors?: BaseNodePostprocessor[] | undefined;
-    systemPrompt?: string | undefined;
-    contextRole?: MessageType | undefined;
-  }) {
+  constructor(init: ContextChatEngineOptions) {
     super();
     this.chatModel = init.chatModel ?? Settings.llm;
     this.memory = new ChatMemoryBuffer({ chatHistory: init?.chatHistory });
diff --git a/packages/core/src/chat-engine/index.ts b/packages/core/src/chat-engine/index.ts
index f5af4dd4d8e99ca35c3318ac0e89befa47309f40..c52aa9b3aee2c16bf3179b80bd7e531e4233fdb2 100644
--- a/packages/core/src/chat-engine/index.ts
+++ b/packages/core/src/chat-engine/index.ts
@@ -4,6 +4,9 @@ export {
   type NonStreamingChatEngineParams,
   type StreamingChatEngineParams,
 } from "./base";
-export { ContextChatEngine } from "./context-chat-engine";
+export {
+  ContextChatEngine,
+  type ContextChatEngineOptions,
+} from "./context-chat-engine";
 export { DefaultContextGenerator } from "./default-context-generator";
 export { SimpleChatEngine } from "./simple-chat-engine";
diff --git a/packages/llamaindex/src/indices/BaseIndex.ts b/packages/llamaindex/src/indices/BaseIndex.ts
index e95fdf749c8c0fa18778fe5249968c81ad47da00..9e14deb2726046c3fc975bb77c6766b9e8020d8e 100644
--- a/packages/llamaindex/src/indices/BaseIndex.ts
+++ b/packages/llamaindex/src/indices/BaseIndex.ts
@@ -1,3 +1,7 @@
+import type {
+  BaseChatEngine,
+  ContextChatEngineOptions,
+} from "@llamaindex/core/chat-engine";
 import type { BaseQueryEngine } from "@llamaindex/core/query-engine";
 import type { BaseSynthesizer } from "@llamaindex/core/response-synthesizers";
 import type { BaseRetriever } from "@llamaindex/core/retriever";
@@ -53,6 +57,14 @@ export abstract class BaseIndex<T> {
     responseSynthesizer?: BaseSynthesizer;
   }): BaseQueryEngine;
 
+  /**
+   * Create a new chat engine from the index.
+   * @param options
+   */
+  abstract asChatEngine(
+    options?: Omit<ContextChatEngineOptions, "retriever">,
+  ): BaseChatEngine;
+
   /**
    * Insert a document into the index.
    * @param document
diff --git a/packages/llamaindex/src/indices/keyword/index.ts b/packages/llamaindex/src/indices/keyword/index.ts
index 7c17c6a1deb64d3c3ba8819c04915c72a55d33af..ec44118fcb4009b26c177c8de776e4290b58ff83 100644
--- a/packages/llamaindex/src/indices/keyword/index.ts
+++ b/packages/llamaindex/src/indices/keyword/index.ts
@@ -35,6 +35,11 @@ import { BaseRetriever } from "@llamaindex/core/retriever";
 import type { BaseDocumentStore } from "@llamaindex/core/storage/doc-store";
 import { extractText } from "@llamaindex/core/utils";
 import { llmFromSettingsOrContext } from "../../Settings.js";
+import {
+  ContextChatEngine,
+  type BaseChatEngine,
+  type ContextChatEngineOptions,
+} from "../../engines/chat/index.js";
 
 export interface KeywordIndexOptions {
   nodes?: BaseNode[];
@@ -152,6 +157,10 @@ const KeywordTableRetrieverMap = {
   [KeywordTableRetrieverMode.RAKE]: KeywordTableRAKERetriever,
 };
 
+export type KeywordTableIndexChatEngineOptions = {
+  retriever?: BaseRetriever;
+} & Omit<ContextChatEngineOptions, "retriever">;
+
 /**
  * The KeywordTableIndex, an index that extracts keywords from each Node and builds a mapping from each keyword to the corresponding Nodes of that keyword.
  */
@@ -251,6 +260,14 @@ export class KeywordTableIndex extends BaseIndex<KeywordTable> {
     );
   }
 
+  asChatEngine(options?: KeywordTableIndexChatEngineOptions): BaseChatEngine {
+    const { retriever, ...contextChatEngineOptions } = options ?? {};
+    return new ContextChatEngine({
+      retriever: retriever ?? this.asRetriever(),
+      ...contextChatEngineOptions,
+    });
+  }
+
   static async extractKeywords(
     text: string,
     serviceContext?: ServiceContext,
diff --git a/packages/llamaindex/src/indices/summary/index.ts b/packages/llamaindex/src/indices/summary/index.ts
index a3fa20444a371c903ca2f2ab572c266b941a0cc2..aa31576713d8d076e3bec2f77152cad747e4fad1 100644
--- a/packages/llamaindex/src/indices/summary/index.ts
+++ b/packages/llamaindex/src/indices/summary/index.ts
@@ -24,6 +24,11 @@ import {
   llmFromSettingsOrContext,
   nodeParserFromSettingsOrContext,
 } from "../../Settings.js";
+import type {
+  BaseChatEngine,
+  ContextChatEngineOptions,
+} from "../../engines/chat/index.js";
+import { ContextChatEngine } from "../../engines/chat/index.js";
 import { RetrieverQueryEngine } from "../../engines/query/index.js";
 import type { StorageContext } from "../../storage/StorageContext.js";
 import { storageContextFromDefaults } from "../../storage/StorageContext.js";
@@ -44,6 +49,11 @@ export enum SummaryRetrieverMode {
   LLM = "llm",
 }
 
+export type SummaryIndexChatEngineOptions = {
+  retriever?: BaseRetriever;
+  mode?: SummaryRetrieverMode;
+} & Omit<ContextChatEngineOptions, "retriever">;
+
 export interface SummaryIndexOptions {
   nodes?: BaseNode[] | undefined;
   indexStruct?: IndexList | undefined;
@@ -193,6 +203,16 @@ export class SummaryIndex extends BaseIndex<IndexList> {
     );
   }
 
+  asChatEngine(options?: SummaryIndexChatEngineOptions): BaseChatEngine {
+    const { retriever, mode, ...contextChatEngineOptions } = options ?? {};
+    return new ContextChatEngine({
+      retriever:
+        retriever ??
+        this.asRetriever({ mode: mode ?? SummaryRetrieverMode.DEFAULT }),
+      ...contextChatEngineOptions,
+    });
+  }
+
   static async buildIndexFromNodes(
     nodes: BaseNode[],
     docStore: BaseDocumentStore,
diff --git a/packages/llamaindex/src/indices/vectorStore/index.ts b/packages/llamaindex/src/indices/vectorStore/index.ts
index bdbafd62c4bf97e2aac2a2391144b024468f6b67..6b21f7ae6eca5f5043a0a61a954d3e7f0d80a731 100644
--- a/packages/llamaindex/src/indices/vectorStore/index.ts
+++ b/packages/llamaindex/src/indices/vectorStore/index.ts
@@ -1,3 +1,7 @@
+import {
+  ContextChatEngine,
+  type ContextChatEngineOptions,
+} from "@llamaindex/core/chat-engine";
 import { IndexDict, IndexStructType } from "@llamaindex/core/data-structs";
 import {
   DEFAULT_SIMILARITY_TOP_K,
@@ -59,6 +63,12 @@ export interface VectorIndexConstructorProps extends BaseIndexInit<IndexDict> {
   vectorStores?: VectorStoreByType | undefined;
 }
 
+export type VectorIndexChatEngineOptions = {
+  retriever?: BaseRetriever;
+  similarityTopK?: number;
+  preFilters?: MetadataFilters;
+} & Omit<ContextChatEngineOptions, "retriever">;
+
 /**
  * The VectorStoreIndex, an index that stores the nodes only according to their vector embeddings.
  */
@@ -309,6 +319,25 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> {
     );
   }
 
+  /**
+   * Convert the index to a chat engine.
+   * @param options The options for creating the chat engine
+   * @returns A ContextChatEngine that uses the index's retriever to get context for each query
+   */
+  asChatEngine(options: VectorIndexChatEngineOptions = {}) {
+    const {
+      retriever,
+      similarityTopK,
+      preFilters,
+      ...contextChatEngineOptions
+    } = options;
+    return new ContextChatEngine({
+      retriever:
+        retriever ?? this.asRetriever({ similarityTopK, filters: preFilters }),
+      ...contextChatEngineOptions,
+    });
+  }
+
   protected async insertNodesToStore(
     newIds: string[],
     nodes: BaseNode[],
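
Taken together, each asChatEngine implementation above reduces to the same shape: build (or accept) a retriever, then spread the remaining ContextChatEngineOptions into a ContextChatEngine. A minimal sketch of how the index-specific options (here similarityTopK) combine with the forwarded ones (here systemPrompt); the streaming chat({ ..., stream: true }) overload and the delta field on streamed chunks come from the existing BaseChatEngine/EngineResponse API rather than this diff, and the prompt and question are illustrative placeholders:

import { Document, VectorStoreIndex } from "llamaindex";

async function main() {
  const index = await VectorStoreIndex.fromDocuments([
    new Document({ text: "..." }), // placeholder document text
  ]);

  const chatEngine = index.asChatEngine({
    // Consumed by asChatEngine itself to configure the default retriever.
    similarityTopK: 3,
    // Part of ContextChatEngineOptions; spread into the ContextChatEngine constructor.
    systemPrompt: "Answer using only the retrieved context.",
  });

  // Streaming variant: iterate over response chunks instead of awaiting one message.
  const stream = await chatEngine.chat({
    message: "What is this document about?",
    stream: true,
  });
  for await (const chunk of stream) {
    process.stdout.write(chunk.delta);
  }
}

main().catch(console.error);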