From a6db5dd29b31d713bf080bb6cd387dc7817e7e9c Mon Sep 17 00:00:00 2001
From: Rozstone <42225395+wststone@users.noreply.github.com>
Date: Fri, 8 Nov 2024 07:13:13 +0800
Subject: [PATCH] feat: add baseUrl and timeout option in cohere rerank (#1445)

Co-authored-by: Alex Yang <himself65@outlook.com>
---
 .changeset/cuddly-bugs-train.md               |  5 +++
 .../postprocessors/rerankers/CohereRerank.ts  | 38 ++++++++++++++-----
 2 files changed, 34 insertions(+), 9 deletions(-)
 create mode 100644 .changeset/cuddly-bugs-train.md

diff --git a/.changeset/cuddly-bugs-train.md b/.changeset/cuddly-bugs-train.md
new file mode 100644
index 000000000..7eec12bef
--- /dev/null
+++ b/.changeset/cuddly-bugs-train.md
@@ -0,0 +1,5 @@
+---
+"llamaindex": patch
+---
+
+feat: add baseUrl and timeout option in cohere rerank
diff --git a/packages/llamaindex/src/postprocessors/rerankers/CohereRerank.ts b/packages/llamaindex/src/postprocessors/rerankers/CohereRerank.ts
index c8310bad0..e9f2889ce 100644
--- a/packages/llamaindex/src/postprocessors/rerankers/CohereRerank.ts
+++ b/packages/llamaindex/src/postprocessors/rerankers/CohereRerank.ts
@@ -10,12 +10,16 @@ type CohereRerankOptions = {
   topN?: number;
   model?: string;
   apiKey: string | null;
+  baseUrl?: string;
+  timeout?: number;
 };
 
 export class CohereRerank implements BaseNodePostprocessor {
   topN: number = 2;
   model: string = "rerank-english-v2.0";
   apiKey: string | null = null;
+  baseUrl: string | undefined;
+  timeout: number | undefined;
 
   private client: CohereClient | null = null;
 
@@ -27,6 +31,8 @@ export class CohereRerank implements BaseNodePostprocessor {
     topN = 2,
     model = "rerank-english-v2.0",
     apiKey = null,
+    baseUrl,
+    timeout,
   }: CohereRerankOptions) {
     if (apiKey === null) {
       throw new Error("CohereRerank requires an API key");
@@ -35,10 +41,19 @@ export class CohereRerank implements BaseNodePostprocessor {
     this.topN = topN;
     this.model = model;
     this.apiKey = apiKey;
+    this.baseUrl = baseUrl;
+    this.timeout = timeout;
 
-    this.client = new CohereClient({
-      token: this.apiKey,
-    });
+    this.client = new CohereClient(
+      this.baseUrl
+        ? {
+            token: this.apiKey,
+            environment: this.baseUrl,
+          }
+        : {
+            token: this.apiKey,
+          },
+    );
   }
 
   /**
@@ -62,12 +77,17 @@ export class CohereRerank implements BaseNodePostprocessor {
       throw new Error("CohereRerank requires a query");
     }
 
-    const results = await this.client.rerank({
-      query: extractText(query),
-      model: this.model,
-      topN: this.topN,
-      documents: nodes.map((n) => n.node.getContent(MetadataMode.ALL)),
-    });
+    const results = await this.client.rerank(
+      {
+        query: extractText(query),
+        model: this.model,
+        topN: this.topN,
+        documents: nodes.map((n) => n.node.getContent(MetadataMode.ALL)),
+      },
+      this.timeout !== undefined
+        ? { timeoutInSeconds: this.timeout }
+        : undefined,
+    );
 
     const newNodes: NodeWithScore[] = [];
 
-- 
GitLab