diff --git a/.changeset/olive-wolves-own.md b/.changeset/olive-wolves-own.md new file mode 100644 index 0000000000000000000000000000000000000000..63094ccfda09d984125ed4c2cbeb506da129c6ad --- /dev/null +++ b/.changeset/olive-wolves-own.md @@ -0,0 +1,5 @@ +--- +"llamaindex": patch +--- + +feat(reranker): cohere reranker diff --git a/apps/docs/docs/modules/data_loader.mdx b/apps/docs/docs/modules/data_loader.mdx index 8fa39dbf12c4cf5aa5df05cb361a674e194bfb51..8392d2bc0e482b278d363429a57e5ba2f418a4af 100644 --- a/apps/docs/docs/modules/data_loader.mdx +++ b/apps/docs/docs/modules/data_loader.mdx @@ -1,5 +1,5 @@ --- -sidebar_position: 3 +sidebar_position: 4 --- import CodeBlock from "@theme/CodeBlock"; diff --git a/apps/docs/docs/modules/embedding.md b/apps/docs/docs/modules/embedding.md index 8be55616ace4e45c6d36718fec9d0a9e7878791c..ef8ddb4f081c3755b3c370091ac2ee80f4e93ad2 100644 --- a/apps/docs/docs/modules/embedding.md +++ b/apps/docs/docs/modules/embedding.md @@ -1,5 +1,5 @@ --- -sidebar_position: 3 +sidebar_position: 4 --- # Embedding diff --git a/apps/docs/docs/modules/node_parser.md b/apps/docs/docs/modules/node_parser.md index c5679fafa208f4953b97d25eb17e4a18263667b7..9bf3fc304e68e40a01952fd583e661c5e6027a8d 100644 --- a/apps/docs/docs/modules/node_parser.md +++ b/apps/docs/docs/modules/node_parser.md @@ -1,5 +1,5 @@ --- -sidebar_position: 3 +sidebar_position: 4 --- # NodeParser diff --git a/apps/docs/docs/modules/node_postprocessors/_category_.yml b/apps/docs/docs/modules/node_postprocessors/_category_.yml new file mode 100644 index 0000000000000000000000000000000000000000..88891a4dc6811583f4349e242e6758a524f74288 --- /dev/null +++ b/apps/docs/docs/modules/node_postprocessors/_category_.yml @@ -0,0 +1,2 @@ +label: "Node Postprocessors" +position: 3 diff --git a/apps/docs/docs/modules/node_postprocessors/cohere_reranker.md b/apps/docs/docs/modules/node_postprocessors/cohere_reranker.md new file mode 100644 index 
0000000000000000000000000000000000000000..717f8a2cdbc71e08341f100b2268830789468e40 --- /dev/null +++ b/apps/docs/docs/modules/node_postprocessors/cohere_reranker.md @@ -0,0 +1,71 @@ +# Cohere Reranker + +The Cohere Reranker is a postprocessor that uses the Cohere API to rerank the results of a search query. + +## Setup + +Firstly, you will need to install the `llamaindex` package. + +```bash +pnpm install llamaindex +``` + +Now, you will need to sign up for an API key at [Cohere](https://cohere.ai/). Once you have your API key you can import the necessary modules and create a new instance of the `CohereRerank` class. + +```ts +import { + CohereRerank, + Document, + OpenAI, + VectorStoreIndex, + serviceContextFromDefaults, +} from "llamaindex"; +``` + +## Load and index documents + +For this example, we will use a single document. In a real-world scenario, you would have multiple documents to index. + +```ts +const document = new Document({ text: essay, id_: "essay" }); + +const serviceContext = serviceContextFromDefaults({ + llm: new OpenAI({ model: "gpt-3.5-turbo", temperature: 0.1 }), +}); + +const index = await VectorStoreIndex.fromDocuments([document], { + serviceContext, +}); +``` + +## Increase similarity topK to retrieve more results + +The default value for `similarityTopK` is 2. This means that only the two most similar results will be returned. To retrieve more results, you can increase the value of `similarityTopK`. + +```ts +const retriever = index.asRetriever(); +retriever.similarityTopK = 5; +``` + +## Create a new instance of the CohereRerank class + +Then you can create a new instance of the `CohereRerank` class and pass in your API key and the number of results you want to return. 
+ +```ts +const nodePostprocessor = new CohereRerank({ + apiKey: "<COHERE_API_KEY>", + topN: 4, +}); +``` + +## Create a query engine with the retriever and node postprocessor + +```ts +const queryEngine = index.asQueryEngine({ + retriever, + nodePostprocessors: [nodePostprocessor], +}); + +// log the response +const response = await queryEngine.query({ query: "Where did the author grow up?" }); +``` diff --git a/apps/docs/docs/modules/node_postprocessors/index.md b/apps/docs/docs/modules/node_postprocessors/index.md new file mode 100644 index 0000000000000000000000000000000000000000..ec5ae88e1f07764cab14143fe3f7f20ad093e9b6 --- /dev/null +++ b/apps/docs/docs/modules/node_postprocessors/index.md @@ -0,0 +1,110 @@ +# Node Postprocessors + +## Concept + +Node postprocessors are a set of modules that take a set of nodes, and apply some kind of transformation or filtering before returning them. + +In LlamaIndex, node postprocessors are most commonly applied within a query engine, after the node retrieval step and before the response synthesis step. + +LlamaIndex offers several node postprocessors for immediate use, while also providing a simple API for adding your own custom postprocessors. 
+ +## Usage Pattern + +An example of using node postprocessors is below: + +```ts +import { + TextNode, + NodeWithScore, + SimilarityPostprocessor, + CohereRerank, +} from "llamaindex"; + +const nodes: NodeWithScore[] = [ + { + node: new TextNode({ text: "hello world" }), + score: 0.8, + }, + { + node: new TextNode({ text: "LlamaIndex is the best" }), + score: 0.6, + }, +]; + +// similarity postprocessor: filter nodes below 0.7 similarity score +const processor = new SimilarityPostprocessor({ + similarityCutoff: 0.7, +}); + +const filteredNodes = await processor.postprocessNodes(nodes); + +// cohere rerank: rerank nodes given query using trained model +const reranker = new CohereRerank({ + apiKey: "<COHERE_API_KEY>", + topN: 2, +}); + +const rerankedNodes = await reranker.postprocessNodes(nodes, "<user_query>"); + +console.log(filteredNodes, rerankedNodes); +``` + +Now you can use the `filteredNodes` and `rerankedNodes` in your application. + +## Using Node Postprocessors in LlamaIndex + +Most commonly, node postprocessors will be used in a query engine, where they are applied to the nodes returned from a retriever, and before the response synthesis step. 
+ +### Using Node Postprocessors in a Query Engine + +```ts +import { TextNode, NodeWithScore, SimilarityPostprocessor, CohereRerank } from "llamaindex"; + +const nodes: NodeWithScore[] = [ + { + node: new TextNode({ text: "hello world" }), + score: 0.8, + }, + { + node: new TextNode({ text: "LlamaIndex is the best" }), + score: 0.6, + } +]; + +// cohere rerank: rerank nodes given query using trained model +const reranker = new CohereRerank({ + apiKey: "<COHERE_API_KEY>", + topN: 2, +}); + +const document = new Document({ text: "essay", id_: "essay" }); + +const serviceContext = serviceContextFromDefaults({ + llm: new OpenAI({ model: "gpt-3.5-turbo", temperature: 0.1 }), +}); + +const index = await VectorStoreIndex.fromDocuments([document], { + serviceContext, +}); + +const queryEngine = index.asQueryEngine({ + nodePostprocessors: [processor, reranker], +}); + +// all node post-processors will be applied during each query +const response = await queryEngine.query({ query: "<user_query>" }); +``` + +### Using with retrieved nodes + +```ts +import { SimilarityPostprocessor } from "llamaindex"; + +const nodes = await index.asRetriever().retrieve("test query str"); + +const processor = new SimilarityPostprocessor({ + similarityCutoff: 0.7, +}); + +const filteredNodes = await processor.postprocessNodes(nodes); +``` diff --git a/examples/rerankers/CohereReranker.ts b/examples/rerankers/CohereReranker.ts new file mode 100644 index 0000000000000000000000000000000000000000..4b34fa549e577defde713aed834cff15656989cb --- /dev/null +++ b/examples/rerankers/CohereReranker.ts @@ -0,0 +1,55 @@ +import { + CohereRerank, + Document, + OpenAI, + VectorStoreIndex, + serviceContextFromDefaults, +} from "llamaindex"; + +import essay from "../essay"; + +async function main() { + const document = new Document({ text: essay, id_: "essay" }); + + const serviceContext = serviceContextFromDefaults({ + llm: new OpenAI({ model: "gpt-3.5-turbo", temperature: 0.1 }), + }); + + const index = await 
VectorStoreIndex.fromDocuments([document], { + serviceContext, + }); + + const retriever = index.asRetriever(); + + retriever.similarityTopK = 5; + + const nodePostprocessor = new CohereRerank({ + apiKey: "<COHERE_API_KEY>", + topN: 5, + }); + + const queryEngine = index.asQueryEngine({ + retriever, + nodePostprocessors: [nodePostprocessor], + }); + + const baseQueryEngine = index.asQueryEngine({ + retriever, + }); + + const response = await queryEngine.query({ + query: "What did the author do growing up?", + }); + + // cohere response + console.log(response.response); + + const baseResponse = await baseQueryEngine.query({ + query: "What did the author do growing up?", + }); + + // response without cohere + console.log(baseResponse.response); +} + +main().catch(console.error); diff --git a/packages/core/package.json b/packages/core/package.json index 9e5719e6e456f6c12df562ec727fa4089a0362bf..ff403290252f787f9be6215ded66f51db041c855 100644 --- a/packages/core/package.json +++ b/packages/core/package.json @@ -13,6 +13,7 @@ "@xenova/transformers": "^2.15.0", "assemblyai": "^4.2.2", "chromadb": "~1.7.3", + "cohere-ai": "^7.7.5", "file-type": "^18.7.0", "js-tiktoken": "^1.0.10", "lodash": "^4.17.21", diff --git a/packages/core/src/engines/chat/DefaultContextGenerator.ts b/packages/core/src/engines/chat/DefaultContextGenerator.ts index b5db4d27b64696cff91faf65c79d0cba64fd468c..11566b85dd6b896f0a4c6c40e7efd97aab1212a8 100644 --- a/packages/core/src/engines/chat/DefaultContextGenerator.ts +++ b/packages/core/src/engines/chat/DefaultContextGenerator.ts @@ -22,11 +22,17 @@ export class DefaultContextGenerator implements ContextGenerator { this.nodePostprocessors = init.nodePostprocessors || []; } - private applyNodePostprocessors(nodes: NodeWithScore[]) { - return this.nodePostprocessors.reduce( - (nodes, nodePostprocessor) => nodePostprocessor.postprocessNodes(nodes), - nodes, - ); + private async applyNodePostprocessors(nodes: NodeWithScore[], query: string) { + let 
nodesWithScore = nodes; + + for (const postprocessor of this.nodePostprocessors) { + nodesWithScore = await postprocessor.postprocessNodes( + nodesWithScore, + query, + ); + } + + return nodesWithScore; } async generate(message: string, parentEvent?: Event): Promise<Context> { @@ -42,7 +48,10 @@ export class DefaultContextGenerator implements ContextGenerator { parentEvent, ); - const nodes = this.applyNodePostprocessors(sourceNodesWithScore); + const nodes = await this.applyNodePostprocessors( + sourceNodesWithScore, + message, + ); return { message: { diff --git a/packages/core/src/engines/query/RetrieverQueryEngine.ts b/packages/core/src/engines/query/RetrieverQueryEngine.ts index 51470a387e4a1a64afb3f49b0012042387358b19..cbec0f7e3d7cae19ecfca811e173419f79eb5c9e 100644 --- a/packages/core/src/engines/query/RetrieverQueryEngine.ts +++ b/packages/core/src/engines/query/RetrieverQueryEngine.ts @@ -36,11 +36,17 @@ export class RetrieverQueryEngine implements BaseQueryEngine { this.nodePostprocessors = nodePostprocessors || []; } - private applyNodePostprocessors(nodes: NodeWithScore[]) { - return this.nodePostprocessors.reduce( - (nodes, nodePostprocessor) => nodePostprocessor.postprocessNodes(nodes), - nodes, - ); + private async applyNodePostprocessors(nodes: NodeWithScore[], query: string) { + let nodesWithScore = nodes; + + for (const postprocessor of this.nodePostprocessors) { + nodesWithScore = await postprocessor.postprocessNodes( + nodesWithScore, + query, + ); + } + + return nodesWithScore; } private async retrieve(query: string, parentEvent: Event) { @@ -50,7 +56,7 @@ export class RetrieverQueryEngine implements BaseQueryEngine { this.preFilters, ); - return this.applyNodePostprocessors(nodes); + return await this.applyNodePostprocessors(nodes, query); } query(params: QueryEngineParamsStreaming): Promise<AsyncIterable<Response>>; diff --git a/packages/core/src/postprocessors/MetadataReplacementPostProcessor.ts 
b/packages/core/src/postprocessors/MetadataReplacementPostProcessor.ts index d05d94a108f66f5a4ab388a45614e99e9ed5f523..e53a6e43593f854c9f2f9d3783276b1ca9535f63 100644 --- a/packages/core/src/postprocessors/MetadataReplacementPostProcessor.ts +++ b/packages/core/src/postprocessors/MetadataReplacementPostProcessor.ts @@ -8,7 +8,7 @@ export class MetadataReplacementPostProcessor implements BaseNodePostprocessor { this.targetMetadataKey = targetMetadataKey; } - postprocessNodes(nodes: NodeWithScore[]): NodeWithScore[] { + async postprocessNodes(nodes: NodeWithScore[]): Promise<NodeWithScore[]> { for (let n of nodes) { n.node.setContent( n.node.metadata[this.targetMetadataKey] ?? diff --git a/packages/core/src/postprocessors/SimilarityPostprocessor.ts b/packages/core/src/postprocessors/SimilarityPostprocessor.ts index 91674515ed7cbe67eeb23ba3c579af51a57f3fa9..413fa922952a431be338cc0a45b19543347e4db4 100644 --- a/packages/core/src/postprocessors/SimilarityPostprocessor.ts +++ b/packages/core/src/postprocessors/SimilarityPostprocessor.ts @@ -8,7 +8,7 @@ export class SimilarityPostprocessor implements BaseNodePostprocessor { this.similarityCutoff = options?.similarityCutoff; } - postprocessNodes(nodes: NodeWithScore[]) { + async postprocessNodes(nodes: NodeWithScore[]) { if (this.similarityCutoff === undefined) return nodes; const cutoff = this.similarityCutoff || 0; diff --git a/packages/core/src/postprocessors/index.ts b/packages/core/src/postprocessors/index.ts index f79e4ced0526bd211536df82353e40f00eaa10a9..3e52ee40c1ca93b19e545a95ec251008e96d414d 100644 --- a/packages/core/src/postprocessors/index.ts +++ b/packages/core/src/postprocessors/index.ts @@ -1,3 +1,4 @@ export * from "./MetadataReplacementPostProcessor"; export * from "./SimilarityPostprocessor"; +export * from "./rerankers"; export * from "./types"; diff --git a/packages/core/src/postprocessors/rerankers/CohereRerank.ts b/packages/core/src/postprocessors/rerankers/CohereRerank.ts new file mode 100644 index 
0000000000000000000000000000000000000000..44d4e9f7a2d29f1fb82cb18c51975fcb48112240 --- /dev/null +++ b/packages/core/src/postprocessors/rerankers/CohereRerank.ts @@ -0,0 +1,82 @@ +import { CohereClient } from "cohere-ai"; + +import { MetadataMode, NodeWithScore } from "../../Node"; +import { BaseNodePostprocessor } from "../types"; + +type CohereRerankOptions = { + topN?: number; + model?: string; + apiKey: string | null; +}; + +export class CohereRerank implements BaseNodePostprocessor { + topN: number = 2; + model: string = "rerank-english-v2.0"; + apiKey: string | null = null; + + private client: CohereClient | null = null; + + /** + * Constructor for CohereRerank. + * @param topN Number of nodes to return. + */ + constructor({ + topN = 2, + model = "rerank-english-v2.0", + apiKey = null, + }: CohereRerankOptions) { + if (apiKey === null) { + throw new Error("CohereRerank requires an API key"); + } + + this.topN = topN; + this.model = model; + this.apiKey = apiKey; + + this.client = new CohereClient({ + token: this.apiKey, + }); + } + + /** + * Reranks the nodes using the Cohere API. + * @param nodes Array of nodes with scores. + * @param query Query string. 
+ */ + async postprocessNodes( + nodes: NodeWithScore[], + query?: string, + ): Promise<NodeWithScore[]> { + if (this.client === null) { + throw new Error("CohereRerank client is null"); + } + + if (nodes.length === 0) { + return []; + } + + if (query === undefined) { + throw new Error("CohereRerank requires a query"); + } + + const results = await this.client.rerank({ + query, + model: this.model, + topN: this.topN, + documents: nodes.map((n) => n.node.getContent(MetadataMode.ALL)), + }); + + const newNodes: NodeWithScore[] = []; + + for (const result of results.results) { + const node = nodes[result.index]; + + newNodes.push({ + node: node.node, + score: result.relevanceScore, + }); + } + + return newNodes; + } +} diff --git a/packages/core/src/postprocessors/rerankers/index.ts b/packages/core/src/postprocessors/rerankers/index.ts new file mode 100644 index 0000000000000000000000000000000000000000..8ef81bdb87e098e191e8c1e8cbedcd7a7822ddff --- /dev/null +++ b/packages/core/src/postprocessors/rerankers/index.ts @@ -0,0 +1 @@ +export * from "./CohereRerank"; diff --git a/packages/core/src/postprocessors/types.ts b/packages/core/src/postprocessors/types.ts index 2d0c73e78431a939646f57d30038c8a835b915f9..deedc0ceeedff5f82dd850cb1ea90115c591e935 100644 --- a/packages/core/src/postprocessors/types.ts +++ b/packages/core/src/postprocessors/types.ts @@ -1,5 +1,14 @@ import { NodeWithScore } from "../Node"; export interface BaseNodePostprocessor { - postprocessNodes: (nodes: NodeWithScore[]) => NodeWithScore[]; + /** + * Post-processes the given nodes (e.g. filtering or reranking) and returns the resulting nodes. + * This version returns a promise for asynchronous operation. + * @param nodes Array of nodes with scores. + * @param query Optional query string. 
+ */ + postprocessNodes( + nodes: NodeWithScore[], + query?: string, + ): Promise<NodeWithScore[]>; } diff --git a/packages/core/src/tests/postprocessors/MetadataReplacementPostProcessor.test.ts b/packages/core/src/tests/postprocessors/MetadataReplacementPostProcessor.test.ts index f9a76845db7e0bbb1e8314df11c5f85b4bc1f16d..1e6f16c9d647648eb48f12c4b74185d533f6adf5 100644 --- a/packages/core/src/tests/postprocessors/MetadataReplacementPostProcessor.test.ts +++ b/packages/core/src/tests/postprocessors/MetadataReplacementPostProcessor.test.ts @@ -18,15 +18,15 @@ describe("MetadataReplacementPostProcessor", () => { ]; }); - test("Replaces the content of each node with specified metadata key if it exists", () => { + test("Replaces the content of each node with specified metadata key if it exists", async () => { nodes[0].node.metadata = { targetKey: "NewContent" }; - const newNodes = postProcessor.postprocessNodes(nodes); + const newNodes = await postProcessor.postprocessNodes(nodes); // Check if node content was replaced correctly expect(newNodes[0].node.getContent(MetadataMode.NONE)).toBe("NewContent"); }); - test("Retains the original content of each node if no metadata key is found", () => { - const newNodes = postProcessor.postprocessNodes(nodes); + test("Retains the original content of each node if no metadata key is found", async () => { + const newNodes = await postProcessor.postprocessNodes(nodes); // Check if node content remained unchanged expect(newNodes[0].node.getContent(MetadataMode.NONE)).toBe("OldContent"); }); diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 8aa514fc79a022e44a072042429258587414fc09..e27fc0c2e22f85e2270dd07e72463e5a09881b8c 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -194,7 +194,10 @@ importers: version: 4.2.2 chromadb: specifier: ~1.7.3 - version: 1.7.3(openai@4.26.1) + version: 1.7.3(cohere-ai@7.7.5)(openai@4.26.1) + cohere-ai: + specifier: ^7.7.5 + version: 7.7.5 file-type: specifier: ^18.7.0 version: 18.7.0 @@ -5966,7 
+5969,7 @@ packages: engines: {node: '>=10'} dev: true - /chromadb@1.7.3(openai@4.26.1): + /chromadb@1.7.3(cohere-ai@7.7.5)(openai@4.26.1): resolution: {integrity: sha512-3GgvQjpqgk5C89x5EuTDaXKbfrdqYDJ5UVyLQ3ZmwxnpetNc+HhRDGjkvXa5KSvpQ3lmKoyDoqnN4tZepfFkbw==} engines: {node: '>=14.17.0'} peerDependencies: @@ -5982,6 +5985,7 @@ packages: optional: true dependencies: cliui: 8.0.1 + cohere-ai: 7.7.5 isomorphic-fetch: 3.0.0 openai: 4.26.1 transitivePeerDependencies: @@ -6130,6 +6134,18 @@ packages: rfdc: 1.3.1 dev: false + /cohere-ai@7.7.5: + resolution: {integrity: sha512-uKh4TzHpY/8nwuYprLKSz0vSKclL4zb8il/YNxfFuWgk9/Nhuw4ugzRuT+CE2f4miQG5dPtPfwXCYUJB60mfwQ==} + dependencies: + form-data: 4.0.0 + js-base64: 3.7.2 + node-fetch: 2.7.0(encoding@0.1.13) + qs: 6.11.2 + url-join: 4.0.1 + transitivePeerDependencies: + - encoding + dev: false + /collapse-white-space@2.1.0: resolution: {integrity: sha512-loKTxY1zCOuG4j9f6EPnuyyYkf58RnhhWTvRoZEokgB+WbdXehfjFviyOVYkqzEWz1Q5kRiZdBYS5SwxbQYwzw==} @@ -10276,6 +10292,10 @@ packages: '@sideway/formula': 3.0.1 '@sideway/pinpoint': 2.0.0 + /js-base64@3.7.2: + resolution: {integrity: sha512-NnRs6dsyqUXejqk/yv2aiXlAvOs56sLkX6nUdeaNezI5LFFLlsZjOThmwnrcwh5ZZRwZlCMnVAY3CvhIhoVEKQ==} + dev: false + /js-tiktoken@1.0.10: resolution: {integrity: sha512-ZoSxbGjvGyMT13x6ACo9ebhDha/0FHdKA+OsQcMOWcm1Zs7r90Rhk5lhERLzji+3rA7EKpXCgwXcM5fF3DMpdA==} dependencies: @@ -13311,6 +13331,13 @@ packages: dependencies: side-channel: 1.0.4 + /qs@6.11.2: + resolution: {integrity: sha512-tDNIz22aBzCDxLtVH++VnTfzxlfeK5CbqohpSqpJgj1Wg/cQbStNAz3NuqCs5vV+pjBsK4x4pN9HlVh7rcYRiA==} + engines: {node: '>=0.6'} + dependencies: + side-channel: 1.0.4 + dev: false + /queue-microtask@1.2.3: resolution: {integrity: sha512-NuaNSa6flKT5JaSYQzJok04JzTL1CA6aGhv5rfLW3PgqA+M2ChpZQnAC8h8i4ZFkBS8X5RqkDBHA7r4hej3K9A==} @@ -15824,6 +15851,10 @@ packages: dependencies: punycode: 2.3.1 + /url-join@4.0.1: + resolution: {integrity: 
sha512-jk1+QP6ZJqyOiuEI9AEWQfju/nB2Pw466kbA0LEZljHwKeMgd9WrAEgEGxjPDD2+TNbbb37rTyhEfrCXfuKXnA==} + dev: false + /url-loader@4.1.1(file-loader@6.2.0)(webpack@5.90.0): resolution: {integrity: sha512-3BTV812+AVHHOJQO8O5MkWgZ5aosP7GnROJwvzLS9hWDj00lZ6Z0wNak423Lp9PBZN05N+Jk/N5Si8jRAlGyWA==} engines: {node: '>= 10.13.0'}