Commit e78e9f48 authored by Emanuel Ferreira, committed by GitHub

feat(reranker): cohere reranker (#535)

parent 383933ad
Showing 407 additions and 24 deletions
---
"llamaindex": patch
---
feat(reranker): cohere reranker
---
-sidebar_position: 3
+sidebar_position: 4
---
import CodeBlock from "@theme/CodeBlock";
......
---
-sidebar_position: 3
+sidebar_position: 4
---
# Embedding
......
---
-sidebar_position: 3
+sidebar_position: 4
---
# NodeParser
......
label: "Node Postprocessors"
position: 3
# Cohere Reranker
The Cohere Reranker is a postprocessor that uses the Cohere API to rerank the results of a search query.
## Setup
First, install the `llamaindex` package.
```bash
pnpm install llamaindex
```
Next, sign up for an API key at [Cohere](https://cohere.ai/). Once you have your API key, import the modules needed to build an index and create a `CohereRerank` instance.
```ts
import {
  CohereRerank,
  Document,
  OpenAI,
  VectorStoreIndex,
  serviceContextFromDefaults,
} from "llamaindex";
```
## Load and index documents
For this example, we will use a single document whose text comes from an `essay` string (any text you want to index). In a real-world scenario, you would have multiple documents to index.
```ts
const document = new Document({ text: essay, id_: "essay" });
const serviceContext = serviceContextFromDefaults({
  llm: new OpenAI({ model: "gpt-3.5-turbo", temperature: 0.1 }),
});
const index = await VectorStoreIndex.fromDocuments([document], {
  serviceContext,
});
```
## Increase similarity topK to retrieve more results
The default value for `similarityTopK` is 2, which means that only the two most similar nodes are returned. To give the reranker more candidates to choose from, increase the value of `similarityTopK`.
```ts
const retriever = index.asRetriever();
retriever.similarityTopK = 5;
```
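To make the effect of `similarityTopK` concrete, here is a minimal sketch of top-K selection over scored candidates. The `Scored` type and the `takeTopK` helper are hypothetical illustrations and not part of the `llamaindex` API; the real retriever derives its scores from embeddings.

```ts
// Hypothetical stand-in for a retrieved node with a similarity score.
type Scored = { id: string; score: number };

// Return the k highest-scoring candidates, best first, without mutating the input.
function takeTopK(candidates: Scored[], k: number): Scored[] {
  return [...candidates]
    .sort((a, b) => b.score - a.score)
    .slice(0, Math.max(0, k));
}

const candidates: Scored[] = [
  { id: "a", score: 0.42 },
  { id: "b", score: 0.91 },
  { id: "c", score: 0.77 },
  { id: "d", score: 0.15 },
  { id: "e", score: 0.63 },
];

// With k = 2 only the two best candidates are kept.
const top2 = takeTopK(candidates, 2);
// With k = 5 every candidate is passed on to the next step.
const top5 = takeTopK(candidates, 5);
console.log(top2.map((c) => c.id), top5.length);
```

A larger K costs more downstream work, but it gives a reranker a wider pool from which to pick its final `topN`.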
## Create a new instance of the CohereRerank class
Then you can create a new instance of the `CohereRerank` class and pass in your API key and the number of results you want to return.
```ts
const nodePostprocessor = new CohereRerank({
  apiKey: "<COHERE_API_KEY>",
  topN: 4,
});
```
## Create a query engine with the retriever and node postprocessor
```ts
const queryEngine = index.asQueryEngine({
  retriever,
  nodePostprocessors: [nodePostprocessor],
});

const response = await queryEngine.query({
  query: "Where did the author grow up?",
});
// log the response
console.log(response.response);
```
# Node Postprocessors
## Concept
Node postprocessors are a set of modules that take a set of nodes, and apply some kind of transformation or filtering before returning them.
In LlamaIndex, node postprocessors are most commonly applied within a query engine, after the node retrieval step and before the response synthesis step.
LlamaIndex offers several node postprocessors for immediate use, while also providing a simple API for adding your own custom postprocessors.
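As a rough illustration of what a custom postprocessor could look like, the sketch below uses simplified local types rather than the library's own. `SimpleNode`, `SimpleNodeWithScore`, `SimplePostprocessor`, `keepMatching`, and `KeywordFilterPostprocessor` are all hypothetical names; only the overall shape (nodes in, optional query, promise of nodes out) is taken from the interface introduced in this commit.

```ts
// Hypothetical, simplified stand-ins for the library's node types.
type SimpleNode = { text: string };
type SimpleNodeWithScore = { node: SimpleNode; score: number };

// Same shape as the postprocessor contract: nodes in, optional query, promise out.
interface SimplePostprocessor {
  postprocessNodes(
    nodes: SimpleNodeWithScore[],
    query?: string,
  ): Promise<SimpleNodeWithScore[]>;
}

// Pure helper: keep nodes whose text mentions at least one word of the query.
function keepMatching(
  nodes: SimpleNodeWithScore[],
  query: string,
): SimpleNodeWithScore[] {
  const words = query
    .toLowerCase()
    .split(/\s+/)
    .filter((w) => w.length > 0);
  if (words.length === 0) return nodes;
  return nodes.filter((n) =>
    words.some((w) => n.node.text.toLowerCase().includes(w)),
  );
}

class KeywordFilterPostprocessor implements SimplePostprocessor {
  async postprocessNodes(
    nodes: SimpleNodeWithScore[],
    query?: string,
  ): Promise<SimpleNodeWithScore[]> {
    // Without a query there is nothing to match against, so pass nodes through.
    return query === undefined ? nodes : keepMatching(nodes, query);
  }
}

const sample: SimpleNodeWithScore[] = [
  { node: { text: "hello world" }, score: 0.8 },
  { node: { text: "LlamaIndex is the best" }, score: 0.6 },
];
const kept = keepMatching(sample, "Hello there");
console.log(kept.map((n) => n.node.text));
```

Keeping the filtering logic in a pure function and wrapping it in an `async` method makes the behavior easy to test without involving a query engine.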
## Usage Pattern
An example of using node postprocessors is shown below:
```ts
import {
  TextNode,
  NodeWithScore,
  SimilarityPostprocessor,
  CohereRerank,
} from "llamaindex";

const nodes: NodeWithScore[] = [
  {
    node: new TextNode({ text: "hello world" }),
    score: 0.8,
  },
  {
    node: new TextNode({ text: "LlamaIndex is the best" }),
    score: 0.6,
  },
];

// similarity postprocessor: filter nodes below 0.7 similarity score
const processor = new SimilarityPostprocessor({
  similarityCutoff: 0.7,
});
// postprocessNodes is async, so await the result
const filteredNodes = await processor.postprocessNodes(nodes);

// cohere rerank: rerank nodes given query using trained model
const reranker = new CohereRerank({
  apiKey: "<COHERE_API_KEY>",
  topN: 2,
});
const rerankedNodes = await reranker.postprocessNodes(nodes, "<user_query>");

console.log(filteredNodes, rerankedNodes);
```
Now you can use the `filteredNodes` and `rerankedNodes` in your application.
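To illustrate what a similarity cutoff of 0.7 does to the two example nodes, the following is a minimal sketch using plain objects. `ScoredText` and `applyCutoff` are hypothetical names, and the sketch assumes that nodes scoring at or above the cutoff are kept; it is not the library's implementation.

```ts
// Hypothetical stand-in for a scored node; the score may be missing.
type ScoredText = { text: string; score?: number };

// Assumed behavior: no cutoff means no filtering; otherwise keep scores >= cutoff.
function applyCutoff(nodes: ScoredText[], cutoff?: number): ScoredText[] {
  if (cutoff === undefined) return nodes;
  return nodes.filter((n) => n.score !== undefined && n.score >= cutoff);
}

const exampleNodes: ScoredText[] = [
  { text: "hello world", score: 0.8 },
  { text: "LlamaIndex is the best", score: 0.6 },
];

// Only the node scored 0.8 passes a cutoff of 0.7.
const keptNodes = applyCutoff(exampleNodes, 0.7);
console.log(keptNodes.map((n) => n.text));
```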
## Using Node Postprocessors in LlamaIndex
Most commonly, node postprocessors are used in a query engine, where they are applied to the nodes returned by the retriever, before the response synthesis step.
### Using Node Postprocessors in a Query Engine
```ts
import {
  CohereRerank,
  Document,
  OpenAI,
  SimilarityPostprocessor,
  VectorStoreIndex,
  serviceContextFromDefaults,
} from "llamaindex";

// similarity postprocessor: filter nodes below 0.7 similarity score
const processor = new SimilarityPostprocessor({
  similarityCutoff: 0.7,
});

// cohere rerank: rerank nodes given query using trained model
const reranker = new CohereRerank({
  apiKey: "<COHERE_API_KEY>",
  topN: 2,
});

const document = new Document({ text: "essay", id_: "essay" });
const serviceContext = serviceContextFromDefaults({
  llm: new OpenAI({ model: "gpt-3.5-turbo", temperature: 0.1 }),
});
const index = await VectorStoreIndex.fromDocuments([document], {
  serviceContext,
});

const queryEngine = index.asQueryEngine({
  nodePostprocessors: [processor, reranker],
});

// all node post-processors will be applied during each query
const response = await queryEngine.query({ query: "<user_query>" });
```
### Using with retrieved nodes
```ts
import { SimilarityPostprocessor } from "llamaindex";
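
// `index` is assumed to be a VectorStoreIndex built as in the previous example
const nodes = await index.asRetriever().retrieve("test query str");
const processor = new SimilarityPostprocessor({
  similarityCutoff: 0.7,
});
// postprocessNodes is async, so await the result
const filteredNodes = await processor.postprocessNodes(nodes);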
```
import {
  CohereRerank,
  Document,
  OpenAI,
  VectorStoreIndex,
  serviceContextFromDefaults,
} from "llamaindex";
import essay from "../essay";

async function main() {
  const document = new Document({ text: essay, id_: "essay" });
  const serviceContext = serviceContextFromDefaults({
    llm: new OpenAI({ model: "gpt-3.5-turbo", temperature: 0.1 }),
  });
  const index = await VectorStoreIndex.fromDocuments([document], {
    serviceContext,
  });
  const retriever = index.asRetriever();
  retriever.similarityTopK = 5;
  const nodePostprocessor = new CohereRerank({
    apiKey: "<COHERE_API_KEY>",
    topN: 5,
  });
  const queryEngine = index.asQueryEngine({
    retriever,
    nodePostprocessors: [nodePostprocessor],
  });
  const baseQueryEngine = index.asQueryEngine({
    retriever,
  });
  const response = await queryEngine.query({
    query: "What did the author do growing up?",
  });
  // cohere response
  console.log(response.response);
  const baseResponse = await baseQueryEngine.query({
    query: "What did the author do growing up?",
  });
  // response without cohere
  console.log(baseResponse.response);
}

main().catch(console.error);
......@@ -13,6 +13,7 @@
"@xenova/transformers": "^2.15.0",
"assemblyai": "^4.2.2",
"chromadb": "~1.7.3",
"cohere-ai": "^7.7.5",
"file-type": "^18.7.0",
"js-tiktoken": "^1.0.10",
"lodash": "^4.17.21",
......
......@@ -22,11 +22,17 @@ export class DefaultContextGenerator implements ContextGenerator {
     this.nodePostprocessors = init.nodePostprocessors || [];
   }
-  private applyNodePostprocessors(nodes: NodeWithScore[]) {
-    return this.nodePostprocessors.reduce(
-      (nodes, nodePostprocessor) => nodePostprocessor.postprocessNodes(nodes),
-      nodes,
-    );
+  private async applyNodePostprocessors(nodes: NodeWithScore[], query: string) {
+    let nodesWithScore = nodes;
+    for (const postprocessor of this.nodePostprocessors) {
+      nodesWithScore = await postprocessor.postprocessNodes(
+        nodesWithScore,
+        query,
+      );
+    }
+    return nodesWithScore;
   }
   async generate(message: string, parentEvent?: Event): Promise<Context> {
......@@ -42,7 +48,10 @@ export class DefaultContextGenerator implements ContextGenerator {
       parentEvent,
     );
-    const nodes = this.applyNodePostprocessors(sourceNodesWithScore);
+    const nodes = await this.applyNodePostprocessors(
+      sourceNodesWithScore,
+      message,
+    );
     return {
       message: {
......
......@@ -36,11 +36,17 @@ export class RetrieverQueryEngine implements BaseQueryEngine {
     this.nodePostprocessors = nodePostprocessors || [];
   }
-  private applyNodePostprocessors(nodes: NodeWithScore[]) {
-    return this.nodePostprocessors.reduce(
-      (nodes, nodePostprocessor) => nodePostprocessor.postprocessNodes(nodes),
-      nodes,
-    );
+  private async applyNodePostprocessors(nodes: NodeWithScore[], query: string) {
+    let nodesWithScore = nodes;
+    for (const postprocessor of this.nodePostprocessors) {
+      nodesWithScore = await postprocessor.postprocessNodes(
+        nodesWithScore,
+        query,
+      );
+    }
+    return nodesWithScore;
   }
   private async retrieve(query: string, parentEvent: Event) {
......@@ -50,7 +56,7 @@ export class RetrieverQueryEngine implements BaseQueryEngine {
       this.preFilters,
     );
-    return this.applyNodePostprocessors(nodes);
+    return await this.applyNodePostprocessors(nodes, query);
   }
   query(params: QueryEngineParamsStreaming): Promise<AsyncIterable<Response>>;
......
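The hunks above replace a synchronous `reduce` with a loop that awaits each postprocessor in turn. The sketch below shows the same pattern in isolation, so that each stage receives the resolved output of the previous one. `Item`, `Stage`, `runStages`, and the two sample stages are hypothetical names used only for illustration.

```ts
// Hypothetical stand-ins for scored nodes and async postprocessing stages.
type Item = { text: string; score: number };
type Stage = (items: Item[], query: string) => Promise<Item[]>;

// Apply stages one after another, awaiting each before starting the next.
async function runStages(
  items: Item[],
  query: string,
  stages: Stage[],
): Promise<Item[]> {
  let current = items;
  for (const stage of stages) {
    // Awaiting here ensures the next stage sees an array, not a pending promise.
    current = await stage(current, query);
  }
  return current;
}

// Stage 1: drop items scoring below 0.5 (ignores the query).
const dropLowScores: Stage = async (items) =>
  items.filter((i) => i.score >= 0.5);

// Stage 2: raise the score of items whose text contains the query.
const boostQueryMatches: Stage = async (items, query) =>
  items.map((i) => (i.text.includes(query) ? { ...i, score: i.score + 0.1 } : i));

const pipelineResult = runStages(
  [
    { text: "llama facts", score: 0.6 },
    { text: "unrelated", score: 0.2 },
  ],
  "llama",
  [dropLowScores, boostQueryMatches],
);
pipelineResult.then((items) => console.log(items));
```

A plain `reduce` over async stages would hand a promise to the second stage instead of an array, which is why a sequential `for ... of` with `await` is the simpler fit here.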
......@@ -8,7 +8,7 @@ export class MetadataReplacementPostProcessor implements BaseNodePostprocessor {
     this.targetMetadataKey = targetMetadataKey;
   }
-  postprocessNodes(nodes: NodeWithScore[]): NodeWithScore[] {
+  async postprocessNodes(nodes: NodeWithScore[]): Promise<NodeWithScore[]> {
     for (let n of nodes) {
       n.node.setContent(
         n.node.metadata[this.targetMetadataKey] ??
......
......@@ -8,7 +8,7 @@ export class SimilarityPostprocessor implements BaseNodePostprocessor {
     this.similarityCutoff = options?.similarityCutoff;
   }
-  postprocessNodes(nodes: NodeWithScore[]) {
+  async postprocessNodes(nodes: NodeWithScore[]) {
     if (this.similarityCutoff === undefined) return nodes;
     const cutoff = this.similarityCutoff || 0;
......
export * from "./MetadataReplacementPostProcessor";
export * from "./SimilarityPostprocessor";
export * from "./rerankers";
export * from "./types";
import { CohereClient } from "cohere-ai";
import { MetadataMode, NodeWithScore } from "../../Node";
import { BaseNodePostprocessor } from "../types";

type CohereRerankOptions = {
  topN?: number;
  model?: string;
  apiKey: string | null;
};

export class CohereRerank implements BaseNodePostprocessor {
  topN: number = 2;
  model: string = "rerank-english-v2.0";
  apiKey: string | null = null;
  private client: CohereClient | null = null;

  /**
   * Constructor for CohereRerank.
   * @param topN Number of nodes to return.
   */
  constructor({
    topN = 2,
    model = "rerank-english-v2.0",
    apiKey = null,
  }: CohereRerankOptions) {
    if (apiKey === null) {
      throw new Error("CohereRerank requires an API key");
    }
    this.topN = topN;
    this.model = model;
    this.apiKey = apiKey;
    this.client = new CohereClient({
      token: this.apiKey,
    });
  }

  /**
   * Reranks the nodes using the Cohere API.
   * @param nodes Array of nodes with scores.
   * @param query Query string.
   */
  async postprocessNodes(
    nodes: NodeWithScore[],
    query?: string,
  ): Promise<NodeWithScore[]> {
    if (this.client === null) {
      throw new Error("CohereRerank client is null");
    }
    if (nodes.length === 0) {
      return [];
    }
    if (query === undefined) {
      throw new Error("CohereRerank requires a query");
    }
    const results = await this.client.rerank({
      query,
      model: this.model,
      topN: this.topN,
      documents: nodes.map((n) => n.node.getContent(MetadataMode.ALL)),
    });
    const newNodes: NodeWithScore[] = [];
    for (const result of results.results) {
      const node = nodes[result.index];
      newNodes.push({
        node: node.node,
        score: result.relevanceScore,
      });
    }
    return newNodes;
  }
}
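The last step of `postprocessNodes` maps each rerank result back to its original node by index and swaps in the new relevance score. The sketch below isolates that mapping with a hard-coded, made-up result list in place of a real API call. `Doc`, `RerankResult`, and `applyRerank` are hypothetical names; the `index` and `relevanceScore` fields mirror how the class above reads the response.

```ts
// Hypothetical stand-in for a retrieved node with its retrieval score.
type Doc = { text: string; score: number };
// Result shape assumed from the class above: input position plus a relevance score.
type RerankResult = { index: number; relevanceScore: number };

// Rebuild the node list in reranked order, replacing scores with relevance scores.
function applyRerank(nodes: Doc[], results: RerankResult[]): Doc[] {
  const reranked: Doc[] = [];
  for (const r of results) {
    const original = nodes[r.index];
    if (original === undefined) continue; // skip out-of-range indices defensively
    reranked.push({ text: original.text, score: r.relevanceScore });
  }
  return reranked;
}

const retrieved: Doc[] = [
  { text: "first", score: 0.9 },
  { text: "second", score: 0.8 },
  { text: "third", score: 0.7 },
];

// Made-up reranker output for topN = 2: "third" is judged most relevant.
const fakeResults: RerankResult[] = [
  { index: 2, relevanceScore: 0.98 },
  { index: 0, relevanceScore: 0.41 },
];

const rerankedDocs = applyRerank(retrieved, fakeResults);
console.log(rerankedDocs);
```

Note that the output length is bounded by the number of results returned, so `topN` effectively caps how many nodes reach response synthesis.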
export * from "./CohereRerank";
 import { NodeWithScore } from "../Node";
 export interface BaseNodePostprocessor {
-  postprocessNodes: (nodes: NodeWithScore[]) => NodeWithScore[];
+  /**
+   * Postprocesses the given nodes, optionally using the query.
+   * This version returns a promise for asynchronous operation.
+   * @param nodes Array of nodes with scores.
+   * @param query Optional query string.
+   */
+  postprocessNodes(
+    nodes: NodeWithScore[],
+    query?: string,
+  ): Promise<NodeWithScore[]>;
 }
......@@ -18,15 +18,15 @@ describe("MetadataReplacementPostProcessor", () => {
     ];
   });
-  test("Replaces the content of each node with specified metadata key if it exists", () => {
+  test("Replaces the content of each node with specified metadata key if it exists", async () => {
     nodes[0].node.metadata = { targetKey: "NewContent" };
-    const newNodes = postProcessor.postprocessNodes(nodes);
+    const newNodes = await postProcessor.postprocessNodes(nodes);
     // Check if node content was replaced correctly
     expect(newNodes[0].node.getContent(MetadataMode.NONE)).toBe("NewContent");
   });
-  test("Retains the original content of each node if no metadata key is found", () => {
-    const newNodes = postProcessor.postprocessNodes(nodes);
+  test("Retains the original content of each node if no metadata key is found", async () => {
+    const newNodes = await postProcessor.postprocessNodes(nodes);
     // Check if node content remained unchanged
     expect(newNodes[0].node.getContent(MetadataMode.NONE)).toBe("OldContent");
   });
......
......@@ -194,7 +194,10 @@ importers:
version: 4.2.2
chromadb:
specifier: ~1.7.3
-version: 1.7.3(openai@4.26.1)
+version: 1.7.3(cohere-ai@7.7.5)(openai@4.26.1)
+cohere-ai:
+specifier: ^7.7.5
+version: 7.7.5
file-type:
specifier: ^18.7.0
version: 18.7.0
......@@ -5966,7 +5969,7 @@ packages:
engines: {node: '>=10'}
dev: true
 
-/chromadb@1.7.3(openai@4.26.1):
+/chromadb@1.7.3(cohere-ai@7.7.5)(openai@4.26.1):
resolution: {integrity: sha512-3GgvQjpqgk5C89x5EuTDaXKbfrdqYDJ5UVyLQ3ZmwxnpetNc+HhRDGjkvXa5KSvpQ3lmKoyDoqnN4tZepfFkbw==}
engines: {node: '>=14.17.0'}
peerDependencies:
......@@ -5982,6 +5985,7 @@ packages:
optional: true
dependencies:
cliui: 8.0.1
cohere-ai: 7.7.5
isomorphic-fetch: 3.0.0
openai: 4.26.1
transitivePeerDependencies:
......@@ -6130,6 +6134,18 @@ packages:
rfdc: 1.3.1
dev: false
 
/cohere-ai@7.7.5:
resolution: {integrity: sha512-uKh4TzHpY/8nwuYprLKSz0vSKclL4zb8il/YNxfFuWgk9/Nhuw4ugzRuT+CE2f4miQG5dPtPfwXCYUJB60mfwQ==}
dependencies:
form-data: 4.0.0
js-base64: 3.7.2
node-fetch: 2.7.0(encoding@0.1.13)
qs: 6.11.2
url-join: 4.0.1
transitivePeerDependencies:
- encoding
dev: false
/collapse-white-space@2.1.0:
resolution: {integrity: sha512-loKTxY1zCOuG4j9f6EPnuyyYkf58RnhhWTvRoZEokgB+WbdXehfjFviyOVYkqzEWz1Q5kRiZdBYS5SwxbQYwzw==}
 
......@@ -10276,6 +10292,10 @@ packages:
'@sideway/formula': 3.0.1
'@sideway/pinpoint': 2.0.0
 
/js-base64@3.7.2:
resolution: {integrity: sha512-NnRs6dsyqUXejqk/yv2aiXlAvOs56sLkX6nUdeaNezI5LFFLlsZjOThmwnrcwh5ZZRwZlCMnVAY3CvhIhoVEKQ==}
dev: false
/js-tiktoken@1.0.10:
resolution: {integrity: sha512-ZoSxbGjvGyMT13x6ACo9ebhDha/0FHdKA+OsQcMOWcm1Zs7r90Rhk5lhERLzji+3rA7EKpXCgwXcM5fF3DMpdA==}
dependencies:
......@@ -13311,6 +13331,13 @@ packages:
dependencies:
side-channel: 1.0.4
 
/qs@6.11.2:
resolution: {integrity: sha512-tDNIz22aBzCDxLtVH++VnTfzxlfeK5CbqohpSqpJgj1Wg/cQbStNAz3NuqCs5vV+pjBsK4x4pN9HlVh7rcYRiA==}
engines: {node: '>=0.6'}
dependencies:
side-channel: 1.0.4
dev: false
/queue-microtask@1.2.3:
resolution: {integrity: sha512-NuaNSa6flKT5JaSYQzJok04JzTL1CA6aGhv5rfLW3PgqA+M2ChpZQnAC8h8i4ZFkBS8X5RqkDBHA7r4hej3K9A==}
 
......@@ -15824,6 +15851,10 @@ packages:
dependencies:
punycode: 2.3.1
 
/url-join@4.0.1:
resolution: {integrity: sha512-jk1+QP6ZJqyOiuEI9AEWQfju/nB2Pw466kbA0LEZljHwKeMgd9WrAEgEGxjPDD2+TNbbb37rTyhEfrCXfuKXnA==}
dev: false
/url-loader@4.1.1(file-loader@6.2.0)(webpack@5.90.0):
resolution: {integrity: sha512-3BTV812+AVHHOJQO8O5MkWgZ5aosP7GnROJwvzLS9hWDj00lZ6Z0wNak423Lp9PBZN05N+Jk/N5Si8jRAlGyWA==}
engines: {node: '>= 10.13.0'}
......