From e78e9f483206d441eeda662655d543f703b25dea Mon Sep 17 00:00:00 2001 From: Emanuel Ferreira <contatoferreirads@gmail.com> Date: Sat, 10 Feb 2024 12:07:14 -0300 Subject: [PATCH] feat(reranker): cohere reranker (#535) --- .changeset/olive-wolves-own.md | 5 + apps/docs/docs/modules/data_loader.mdx | 2 +- apps/docs/docs/modules/embedding.md | 2 +- apps/docs/docs/modules/node_parser.md | 2 +- .../node_postprocessors/_category_.yml | 2 + .../node_postprocessors/cohere_reranker.md | 71 +++++++++++ .../docs/modules/node_postprocessors/index.md | 110 ++++++++++++++++++ examples/rerankers/CohereReranker.ts | 55 +++++++++ packages/core/package.json | 1 + .../engines/chat/DefaultContextGenerator.ts | 21 +++- .../src/engines/query/RetrieverQueryEngine.ts | 18 ++- .../MetadataReplacementPostProcessor.ts | 2 +- .../postprocessors/SimilarityPostprocessor.ts | 2 +- packages/core/src/postprocessors/index.ts | 1 + .../postprocessors/rerankers/CohereRerank.ts | 82 +++++++++++++ .../src/postprocessors/rerankers/index.ts | 1 + packages/core/src/postprocessors/types.ts | 11 +- .../MetadataReplacementPostProcessor.test.ts | 8 +- pnpm-lock.yaml | 35 +++++- 19 files changed, 407 insertions(+), 24 deletions(-) create mode 100644 .changeset/olive-wolves-own.md create mode 100644 apps/docs/docs/modules/node_postprocessors/_category_.yml create mode 100644 apps/docs/docs/modules/node_postprocessors/cohere_reranker.md create mode 100644 apps/docs/docs/modules/node_postprocessors/index.md create mode 100644 examples/rerankers/CohereReranker.ts create mode 100644 packages/core/src/postprocessors/rerankers/CohereRerank.ts create mode 100644 packages/core/src/postprocessors/rerankers/index.ts diff --git a/.changeset/olive-wolves-own.md b/.changeset/olive-wolves-own.md new file mode 100644 index 000000000..63094ccfd --- /dev/null +++ b/.changeset/olive-wolves-own.md @@ -0,0 +1,5 @@ +--- +"llamaindex": patch +--- + +feat(reranker): cohere reranker diff --git a/apps/docs/docs/modules/data_loader.mdx b/apps/docs/docs/modules/data_loader.mdx index 8fa39dbf1..8392d2bc0 100644 --- a/apps/docs/docs/modules/data_loader.mdx +++ b/apps/docs/docs/modules/data_loader.mdx @@ -1,5 +1,5 @@ --- -sidebar_position: 3 +sidebar_position: 4 --- import CodeBlock from "@theme/CodeBlock"; diff --git a/apps/docs/docs/modules/embedding.md b/apps/docs/docs/modules/embedding.md index 8be55616a..ef8ddb4f0 100644 --- a/apps/docs/docs/modules/embedding.md +++ b/apps/docs/docs/modules/embedding.md @@ -1,5 +1,5 @@ --- -sidebar_position: 3 +sidebar_position: 4 --- # Embedding diff --git a/apps/docs/docs/modules/node_parser.md b/apps/docs/docs/modules/node_parser.md index c5679fafa..9bf3fc304 100644 --- a/apps/docs/docs/modules/node_parser.md +++ b/apps/docs/docs/modules/node_parser.md @@ -1,5 +1,5 @@ --- -sidebar_position: 3 +sidebar_position: 4 --- # NodeParser diff --git a/apps/docs/docs/modules/node_postprocessors/_category_.yml b/apps/docs/docs/modules/node_postprocessors/_category_.yml new file mode 100644 index 000000000..88891a4dc --- /dev/null +++ b/apps/docs/docs/modules/node_postprocessors/_category_.yml @@ -0,0 +1,2 @@ +label: "Node Postprocessors" +position: 3 diff --git a/apps/docs/docs/modules/node_postprocessors/cohere_reranker.md b/apps/docs/docs/modules/node_postprocessors/cohere_reranker.md new file mode 100644 index 000000000..717f8a2cd --- /dev/null +++ b/apps/docs/docs/modules/node_postprocessors/cohere_reranker.md @@ -0,0 +1,71 @@ +# Cohere Reranker + +The Cohere Reranker is a postprocessor that uses the Cohere API to rerank the results of a search query. + +## Setup + +Firstly, you will need to install the `llamaindex` package. + +```bash +pnpm install llamaindex +``` + +Now, you will need to sign up for an API key at [Cohere](https://cohere.ai/). Once you have your API key you can import the necessary modules and create a new instance of the `CohereRerank` class. + +```ts +import { + CohereRerank, + Document, + OpenAI, + VectorStoreIndex, + serviceContextFromDefaults, +} from "llamaindex"; +``` + +## Load and index documents + +For this example, we will use a single document. In a real-world scenario, you would have multiple documents to index. + +```ts +const document = new Document({ text: essay, id_: "essay" }); + +const serviceContext = serviceContextFromDefaults({ + llm: new OpenAI({ model: "gpt-3.5-turbo", temperature: 0.1 }), +}); + +const index = await VectorStoreIndex.fromDocuments([document], { + serviceContext, +}); +``` + +## Increase similarity topK to retrieve more results + +The default value for `similarityTopK` is 2. This means that only the most similar document will be returned. To retrieve more results, you can increase the value of `similarityTopK`. + +```ts +const retriever = index.asRetriever(); +retriever.similarityTopK = 5; +``` + +## Create a new instance of the CohereRerank class + +Then you can create a new instance of the `CohereRerank` class and pass in your API key and the number of results you want to return. + +```ts +const nodePostprocessor = new CohereRerank({ + apiKey: "<COHERE_API_KEY>", + topN: 4, +}); +``` + +## Create a query engine with the retriever and node postprocessor + +```ts +const queryEngine = index.asQueryEngine({ + retriever, + nodePostprocessors: [nodePostprocessor], +}); + +// log the response +const response = await queryEngine.query("Where did the author grown up?"); +``` diff --git a/apps/docs/docs/modules/node_postprocessors/index.md b/apps/docs/docs/modules/node_postprocessors/index.md new file mode 100644 index 000000000..ec5ae88e1 --- /dev/null +++ b/apps/docs/docs/modules/node_postprocessors/index.md @@ -0,0 +1,110 @@ +# Node Postprocessors + +## Concept + +Node postprocessors are a set of modules that take a set of nodes, and apply some kind of transformation or filtering before returning them. + +In LlamaIndex, node postprocessors are most commonly applied within a query engine, after the node retrieval step and before the response synthesis step. + +LlamaIndex offers several node postprocessors for immediate use, while also providing a simple API for adding your own custom postprocessors. + +## Usage Pattern + +An example of using a node postprocessors is below: + +```ts +import { + Node, + NodeWithScore, + SimilarityPostprocessor, + CohereRerank, +} from "llamaindex"; + +const nodes: NodeWithScore[] = [ + { + node: new TextNode({ text: "hello world" }), + score: 0.8, + }, + { + node: new TextNode({ text: "LlamaIndex is the best" }), + score: 0.6, + }, +]; + +// similarity postprocessor: filter nodes below 0.75 similarity score +const processor = new SimilarityPostprocessor({ + similarityCutoff: 0.7, +}); + +const filteredNodes = processor.postprocessNodes(nodes); + +// cohere rerank: rerank nodes given query using trained model +const reranker = new CohereRerank({ + apiKey: "<COHERE_API_KEY>", + topN: 2, +}); + +const rerankedNodes = await reranker.postprocessNodes(nodes, "<user_query>"); + +console.log(filteredNodes, rerankedNodes); +``` + +Now you can use the `filteredNodes` and `rerankedNodes` in your application. + +## Using Node Postprocessors in LlamaIndex + +Most commonly, node-postprocessors will be used in a query engine, where they are applied to the nodes returned from a retriever, and before the response synthesis step. + +### Using Node Postprocessors in a Query Engine + +```ts +import { Node, NodeWithScore, SimilarityPostprocessor, CohereRerank } from "llamaindex"; + +const nodes: NodeWithScore[] = [ + { + node: new TextNode({ text: "hello world" }), + score: 0.8, + }, + { + node: new TextNode({ text: "LlamaIndex is the best" }), + score: 0.6, + } +]; + +// cohere rerank: rerank nodes given query using trained model +const reranker = new CohereRerank({ + apiKey: "<COHERE_API_KEY>, + topN: 2, +}) + +const document = new Document({ text: "essay", id_: "essay" }); + +const serviceContext = serviceContextFromDefaults({ + llm: new OpenAI({ model: "gpt-3.5-turbo", temperature: 0.1 }), +}); + +const index = await VectorStoreIndex.fromDocuments([document], { + serviceContext, +}); + +const queryEngine = index.asQueryEngine({ + nodePostprocessors: [processor, reranker], +}); + +// all node post-processors will be applied during each query +const response = await queryEngine.query("<user_query>"); +``` + +### Using with retrieved nodes + +```ts +import { SimilarityPostprocessor } from "llamaindex"; + +nodes = await index.asRetriever().retrieve("test query str"); + +const processor = new SimilarityPostprocessor({ + similarityCutoff: 0.7, +}); + +const filteredNodes = processor.postprocessNodes(nodes); +``` diff --git a/examples/rerankers/CohereReranker.ts b/examples/rerankers/CohereReranker.ts new file mode 100644 index 000000000..4b34fa549 --- /dev/null +++ b/examples/rerankers/CohereReranker.ts @@ -0,0 +1,55 @@ +import { + CohereRerank, + Document, + OpenAI, + VectorStoreIndex, + serviceContextFromDefaults, +} from "llamaindex"; + +import essay from "../essay"; + +async function main() { + const document = new Document({ text: essay, id_: "essay" }); + + const serviceContext = serviceContextFromDefaults({ + llm: new OpenAI({ model: "gpt-3.5-turbo", temperature: 0.1 }), + }); + + const index = await VectorStoreIndex.fromDocuments([document], { + serviceContext, + }); + + const retriever = index.asRetriever(); + + retriever.similarityTopK = 5; + + const nodePostprocessor = new CohereRerank({ + apiKey: "<COHERE_API_KEY>", + topN: 5, + }); + + const queryEngine = index.asQueryEngine({ + retriever, + nodePostprocessors: [nodePostprocessor], + }); + + const baseQueryEngine = index.asQueryEngine({ + retriever, + }); + + const response = await queryEngine.query({ + query: "What did the author do growing up?", + }); + + // cohere response + console.log(response.response); + + const baseResponse = await baseQueryEngine.query({ + query: "What did the author do growing up?", + }); + + // response without cohere + console.log(baseResponse.response); +} + +main().catch(console.error); diff --git a/packages/core/package.json b/packages/core/package.json index 9e5719e6e..ff4032902 100644 --- a/packages/core/package.json +++ b/packages/core/package.json @@ -13,6 +13,7 @@ "@xenova/transformers": "^2.15.0", "assemblyai": "^4.2.2", "chromadb": "~1.7.3", + "cohere-ai": "^7.7.5", "file-type": "^18.7.0", "js-tiktoken": "^1.0.10", "lodash": "^4.17.21", diff --git a/packages/core/src/engines/chat/DefaultContextGenerator.ts b/packages/core/src/engines/chat/DefaultContextGenerator.ts index b5db4d27b..11566b85d 100644 --- a/packages/core/src/engines/chat/DefaultContextGenerator.ts +++ b/packages/core/src/engines/chat/DefaultContextGenerator.ts @@ -22,11 +22,17 @@ export class DefaultContextGenerator implements ContextGenerator { this.nodePostprocessors = init.nodePostprocessors || []; } - private applyNodePostprocessors(nodes: NodeWithScore[]) { - return this.nodePostprocessors.reduce( - (nodes, nodePostprocessor) => nodePostprocessor.postprocessNodes(nodes), - nodes, - ); + private async applyNodePostprocessors(nodes: NodeWithScore[], query: string) { + let nodesWithScore = nodes; + + for (const postprocessor of this.nodePostprocessors) { + nodesWithScore = await postprocessor.postprocessNodes( + nodesWithScore, + query, + ); + } + + return nodesWithScore; } async generate(message: string, parentEvent?: Event): Promise<Context> { @@ -42,7 +48,10 @@ export class DefaultContextGenerator implements ContextGenerator { parentEvent, ); - const nodes = this.applyNodePostprocessors(sourceNodesWithScore); + const nodes = await this.applyNodePostprocessors( + sourceNodesWithScore, + message, + ); return { message: { diff --git a/packages/core/src/engines/query/RetrieverQueryEngine.ts b/packages/core/src/engines/query/RetrieverQueryEngine.ts index 51470a387..cbec0f7e3 100644 --- a/packages/core/src/engines/query/RetrieverQueryEngine.ts +++ b/packages/core/src/engines/query/RetrieverQueryEngine.ts @@ -36,11 +36,17 @@ export class RetrieverQueryEngine implements BaseQueryEngine { this.nodePostprocessors = nodePostprocessors || []; } - private applyNodePostprocessors(nodes: NodeWithScore[]) { - return this.nodePostprocessors.reduce( - (nodes, nodePostprocessor) => nodePostprocessor.postprocessNodes(nodes), - nodes, - ); + private async applyNodePostprocessors(nodes: NodeWithScore[], query: string) { + let nodesWithScore = nodes; + + for (const postprocessor of this.nodePostprocessors) { + nodesWithScore = await postprocessor.postprocessNodes( + nodesWithScore, + query, + ); + } + + return nodesWithScore; } private async retrieve(query: string, parentEvent: Event) { @@ -50,7 +56,7 @@ export class RetrieverQueryEngine implements BaseQueryEngine { this.preFilters, ); - return this.applyNodePostprocessors(nodes); + return await this.applyNodePostprocessors(nodes, query); } query(params: QueryEngineParamsStreaming): Promise<AsyncIterable<Response>>; diff --git a/packages/core/src/postprocessors/MetadataReplacementPostProcessor.ts b/packages/core/src/postprocessors/MetadataReplacementPostProcessor.ts index d05d94a10..e53a6e435 100644 --- a/packages/core/src/postprocessors/MetadataReplacementPostProcessor.ts +++ b/packages/core/src/postprocessors/MetadataReplacementPostProcessor.ts @@ -8,7 +8,7 @@ export class MetadataReplacementPostProcessor implements BaseNodePostprocessor { this.targetMetadataKey = targetMetadataKey; } - postprocessNodes(nodes: NodeWithScore[]): NodeWithScore[] { + async postprocessNodes(nodes: NodeWithScore[]): Promise<NodeWithScore[]> { for (let n of nodes) { n.node.setContent( n.node.metadata[this.targetMetadataKey] ?? diff --git a/packages/core/src/postprocessors/SimilarityPostprocessor.ts b/packages/core/src/postprocessors/SimilarityPostprocessor.ts index 91674515e..413fa9229 100644 --- a/packages/core/src/postprocessors/SimilarityPostprocessor.ts +++ b/packages/core/src/postprocessors/SimilarityPostprocessor.ts @@ -8,7 +8,7 @@ export class SimilarityPostprocessor implements BaseNodePostprocessor { this.similarityCutoff = options?.similarityCutoff; } - postprocessNodes(nodes: NodeWithScore[]) { + async postprocessNodes(nodes: NodeWithScore[]) { if (this.similarityCutoff === undefined) return nodes; const cutoff = this.similarityCutoff || 0; diff --git a/packages/core/src/postprocessors/index.ts b/packages/core/src/postprocessors/index.ts index f79e4ced0..3e52ee40c 100644 --- a/packages/core/src/postprocessors/index.ts +++ b/packages/core/src/postprocessors/index.ts @@ -1,3 +1,4 @@ export * from "./MetadataReplacementPostProcessor"; export * from "./SimilarityPostprocessor"; +export * from "./rerankers"; export * from "./types"; diff --git a/packages/core/src/postprocessors/rerankers/CohereRerank.ts b/packages/core/src/postprocessors/rerankers/CohereRerank.ts new file mode 100644 index 000000000..44d4e9f7a --- /dev/null +++ b/packages/core/src/postprocessors/rerankers/CohereRerank.ts @@ -0,0 +1,82 @@ +import { CohereClient } from "cohere-ai"; + +import { MetadataMode, NodeWithScore } from "../../Node"; +import { BaseNodePostprocessor } from "../types"; + +type CohereRerankOptions = { + topN?: number; + model?: string; + apiKey: string | null; +}; + +export class CohereRerank implements BaseNodePostprocessor { + topN: number = 2; + model: string = "rerank-english-v2.0"; + apiKey: string | null = null; + + private client: CohereClient | null = null; + + /** + * Constructor for CohereRerank. + * @param topN Number of nodes to return. + */ + constructor({ + topN = 2, + model = "rerank-english-v2.0", + apiKey = null, + }: CohereRerankOptions) { + if (apiKey === null) { + throw new Error("CohereRerank requires an API key"); + } + + this.topN = topN; + this.model = model; + this.apiKey = apiKey; + + this.client = new CohereClient({ + token: this.apiKey, + }); + } + + /** + * Reranks the nodes using the Cohere API. + * @param nodes Array of nodes with scores. + * @param query Query string. + */ + async postprocessNodes( + nodes: NodeWithScore[], + query?: string, + ): Promise<NodeWithScore[]> { + if (this.client === null) { + throw new Error("CohereRerank client is null"); + } + + if (nodes.length === 0) { + return []; + } + + if (query === undefined) { + throw new Error("CohereRerank requires a query"); + } + + const results = await this.client.rerank({ + query, + model: this.model, + topN: this.topN, + documents: nodes.map((n) => n.node.getContent(MetadataMode.ALL)), + }); + + const newNodes: NodeWithScore[] = []; + + for (const result of results.results) { + const node = nodes[result.index]; + + newNodes.push({ + node: node.node, + score: result.relevanceScore, + }); + } + + return newNodes; + } +} diff --git a/packages/core/src/postprocessors/rerankers/index.ts b/packages/core/src/postprocessors/rerankers/index.ts new file mode 100644 index 000000000..8ef81bdb8 --- /dev/null +++ b/packages/core/src/postprocessors/rerankers/index.ts @@ -0,0 +1 @@ +export * from "./CohereRerank"; diff --git a/packages/core/src/postprocessors/types.ts b/packages/core/src/postprocessors/types.ts index 2d0c73e78..deedc0cee 100644 --- a/packages/core/src/postprocessors/types.ts +++ b/packages/core/src/postprocessors/types.ts @@ -1,5 +1,14 @@ import { NodeWithScore } from "../Node"; export interface BaseNodePostprocessor { - postprocessNodes: (nodes: NodeWithScore[]) => NodeWithScore[]; + /** + * Send message along with the class's current chat history to the LLM. + * This version returns a promise for asynchronous operation. + * @param nodes Array of nodes with scores. + * @param query Optional query string. + */ + postprocessNodes( + nodes: NodeWithScore[], + query?: string, + ): Promise<NodeWithScore[]>; } diff --git a/packages/core/src/tests/postprocessors/MetadataReplacementPostProcessor.test.ts b/packages/core/src/tests/postprocessors/MetadataReplacementPostProcessor.test.ts index f9a76845d..1e6f16c9d 100644 --- a/packages/core/src/tests/postprocessors/MetadataReplacementPostProcessor.test.ts +++ b/packages/core/src/tests/postprocessors/MetadataReplacementPostProcessor.test.ts @@ -18,15 +18,15 @@ describe("MetadataReplacementPostProcessor", () => { ]; }); - test("Replaces the content of each node with specified metadata key if it exists", () => { + test("Replaces the content of each node with specified metadata key if it exists", async () => { nodes[0].node.metadata = { targetKey: "NewContent" }; - const newNodes = postProcessor.postprocessNodes(nodes); + const newNodes = await postProcessor.postprocessNodes(nodes); // Check if node content was replaced correctly expect(newNodes[0].node.getContent(MetadataMode.NONE)).toBe("NewContent"); }); - test("Retains the original content of each node if no metadata key is found", () => { - const newNodes = postProcessor.postprocessNodes(nodes); + test("Retains the original content of each node if no metadata key is found", async () => { + const newNodes = await postProcessor.postprocessNodes(nodes); // Check if node content remained unchanged expect(newNodes[0].node.getContent(MetadataMode.NONE)).toBe("OldContent"); }); diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 8aa514fc7..e27fc0c2e 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -194,7 +194,10 @@ importers: version: 4.2.2 chromadb: specifier: ~1.7.3 - version: 1.7.3(openai@4.26.1) + version: 1.7.3(cohere-ai@7.7.5)(openai@4.26.1) + cohere-ai: + specifier: ^7.7.5 + version: 7.7.5 file-type: specifier: ^18.7.0 version: 18.7.0 @@ -5966,7 +5969,7 @@ packages: engines: {node: '>=10'} dev: true - /chromadb@1.7.3(openai@4.26.1): + /chromadb@1.7.3(cohere-ai@7.7.5)(openai@4.26.1): resolution: {integrity: sha512-3GgvQjpqgk5C89x5EuTDaXKbfrdqYDJ5UVyLQ3ZmwxnpetNc+HhRDGjkvXa5KSvpQ3lmKoyDoqnN4tZepfFkbw==} engines: {node: '>=14.17.0'} peerDependencies: @@ -5982,6 +5985,7 @@ packages: optional: true dependencies: cliui: 8.0.1 + cohere-ai: 7.7.5 isomorphic-fetch: 3.0.0 openai: 4.26.1 transitivePeerDependencies: @@ -6130,6 +6134,18 @@ packages: rfdc: 1.3.1 dev: false + /cohere-ai@7.7.5: + resolution: {integrity: sha512-uKh4TzHpY/8nwuYprLKSz0vSKclL4zb8il/YNxfFuWgk9/Nhuw4ugzRuT+CE2f4miQG5dPtPfwXCYUJB60mfwQ==} + dependencies: + form-data: 4.0.0 + js-base64: 3.7.2 + node-fetch: 2.7.0(encoding@0.1.13) + qs: 6.11.2 + url-join: 4.0.1 + transitivePeerDependencies: + - encoding + dev: false + /collapse-white-space@2.1.0: resolution: {integrity: sha512-loKTxY1zCOuG4j9f6EPnuyyYkf58RnhhWTvRoZEokgB+WbdXehfjFviyOVYkqzEWz1Q5kRiZdBYS5SwxbQYwzw==} @@ -10276,6 +10292,10 @@ packages: '@sideway/formula': 3.0.1 '@sideway/pinpoint': 2.0.0 + /js-base64@3.7.2: + resolution: {integrity: sha512-NnRs6dsyqUXejqk/yv2aiXlAvOs56sLkX6nUdeaNezI5LFFLlsZjOThmwnrcwh5ZZRwZlCMnVAY3CvhIhoVEKQ==} + dev: false + /js-tiktoken@1.0.10: resolution: {integrity: sha512-ZoSxbGjvGyMT13x6ACo9ebhDha/0FHdKA+OsQcMOWcm1Zs7r90Rhk5lhERLzji+3rA7EKpXCgwXcM5fF3DMpdA==} dependencies: @@ -13311,6 +13331,13 @@ packages: dependencies: side-channel: 1.0.4 + /qs@6.11.2: + resolution: {integrity: sha512-tDNIz22aBzCDxLtVH++VnTfzxlfeK5CbqohpSqpJgj1Wg/cQbStNAz3NuqCs5vV+pjBsK4x4pN9HlVh7rcYRiA==} + engines: {node: '>=0.6'} + dependencies: + side-channel: 1.0.4 + dev: false + /queue-microtask@1.2.3: resolution: {integrity: sha512-NuaNSa6flKT5JaSYQzJok04JzTL1CA6aGhv5rfLW3PgqA+M2ChpZQnAC8h8i4ZFkBS8X5RqkDBHA7r4hej3K9A==} @@ -15824,6 +15851,10 @@ packages: dependencies: punycode: 2.3.1 + /url-join@4.0.1: + resolution: {integrity: sha512-jk1+QP6ZJqyOiuEI9AEWQfju/nB2Pw466kbA0LEZljHwKeMgd9WrAEgEGxjPDD2+TNbbb37rTyhEfrCXfuKXnA==} + dev: false + /url-loader@4.1.1(file-loader@6.2.0)(webpack@5.90.0): resolution: {integrity: sha512-3BTV812+AVHHOJQO8O5MkWgZ5aosP7GnROJwvzLS9hWDj00lZ6Z0wNak423Lp9PBZN05N+Jk/N5Si8jRAlGyWA==} engines: {node: '>= 10.13.0'} -- GitLab