From a6db5dd29b31d713bf080bb6cd387dc7817e7e9c Mon Sep 17 00:00:00 2001 From: Rozstone <42225395+wststone@users.noreply.github.com> Date: Fri, 8 Nov 2024 07:13:13 +0800 Subject: [PATCH] feat: add baseUrl and timeout option in cohere rerank (#1445) Co-authored-by: Alex Yang <himself65@outlook.com> --- .changeset/cuddly-bugs-train.md | 5 +++ .../postprocessors/rerankers/CohereRerank.ts | 38 ++++++++++++++----- 2 files changed, 34 insertions(+), 9 deletions(-) create mode 100644 .changeset/cuddly-bugs-train.md diff --git a/.changeset/cuddly-bugs-train.md b/.changeset/cuddly-bugs-train.md new file mode 100644 index 000000000..7eec12bef --- /dev/null +++ b/.changeset/cuddly-bugs-train.md @@ -0,0 +1,5 @@ +--- +"llamaindex": patch +--- + +feat: add baseUrl and timeout option in cohere rerank diff --git a/packages/llamaindex/src/postprocessors/rerankers/CohereRerank.ts b/packages/llamaindex/src/postprocessors/rerankers/CohereRerank.ts index c8310bad0..e9f2889ce 100644 --- a/packages/llamaindex/src/postprocessors/rerankers/CohereRerank.ts +++ b/packages/llamaindex/src/postprocessors/rerankers/CohereRerank.ts @@ -10,12 +10,16 @@ type CohereRerankOptions = { topN?: number; model?: string; apiKey: string | null; + baseUrl?: string; + timeout?: number; }; export class CohereRerank implements BaseNodePostprocessor { topN: number = 2; model: string = "rerank-english-v2.0"; apiKey: string | null = null; + baseUrl: string | undefined; + timeout: number | undefined; private client: CohereClient | null = null; @@ -27,6 +31,8 @@ export class CohereRerank implements BaseNodePostprocessor { topN = 2, model = "rerank-english-v2.0", apiKey = null, + baseUrl, + timeout, }: CohereRerankOptions) { if (apiKey === null) { throw new Error("CohereRerank requires an API key"); @@ -35,10 +41,19 @@ export class CohereRerank implements BaseNodePostprocessor { this.topN = topN; this.model = model; this.apiKey = apiKey; + this.baseUrl = baseUrl; + this.timeout = timeout; - this.client = new CohereClient({ - token: this.apiKey, - }); + this.client = new CohereClient( + this.baseUrl + ? { + token: this.apiKey, + environment: this.baseUrl, + } + : { + token: this.apiKey, + }, + ); } /** @@ -62,12 +77,17 @@ export class CohereRerank implements BaseNodePostprocessor { throw new Error("CohereRerank requires a query"); } - const results = await this.client.rerank({ - query: extractText(query), - model: this.model, - topN: this.topN, - documents: nodes.map((n) => n.node.getContent(MetadataMode.ALL)), - }); + const results = await this.client.rerank( + { + query: extractText(query), + model: this.model, + topN: this.topN, + documents: nodes.map((n) => n.node.getContent(MetadataMode.ALL)), + }, + this.timeout !== undefined + ? { timeoutInSeconds: this.timeout } + : undefined, + ); const newNodes: NodeWithScore[] = []; -- GitLab