diff --git a/.changeset/cuddly-bugs-train.md b/.changeset/cuddly-bugs-train.md new file mode 100644 index 0000000000000000000000000000000000000000..7eec12bef345bdc4a6a0788a3b16517b035c172c --- /dev/null +++ b/.changeset/cuddly-bugs-train.md @@ -0,0 +1,5 @@ +--- +"llamaindex": patch +--- + +feat: add baseUrl and timeout option in cohere rerank diff --git a/packages/llamaindex/src/postprocessors/rerankers/CohereRerank.ts b/packages/llamaindex/src/postprocessors/rerankers/CohereRerank.ts index c8310bad0fe76e4e09ef2241a5348dc034be3798..e9f2889ce7b280925aac8447363c67b004e4d67d 100644 --- a/packages/llamaindex/src/postprocessors/rerankers/CohereRerank.ts +++ b/packages/llamaindex/src/postprocessors/rerankers/CohereRerank.ts @@ -10,12 +10,16 @@ type CohereRerankOptions = { topN?: number; model?: string; apiKey: string | null; + baseUrl?: string; + timeout?: number; }; export class CohereRerank implements BaseNodePostprocessor { topN: number = 2; model: string = "rerank-english-v2.0"; apiKey: string | null = null; + baseUrl: string | undefined; + timeout: number | undefined; private client: CohereClient | null = null; @@ -27,6 +31,8 @@ export class CohereRerank implements BaseNodePostprocessor { topN = 2, model = "rerank-english-v2.0", apiKey = null, + baseUrl, + timeout, }: CohereRerankOptions) { if (apiKey === null) { throw new Error("CohereRerank requires an API key"); @@ -35,10 +41,19 @@ export class CohereRerank implements BaseNodePostprocessor { this.topN = topN; this.model = model; this.apiKey = apiKey; + this.baseUrl = baseUrl; + this.timeout = timeout; - this.client = new CohereClient({ - token: this.apiKey, - }); + this.client = new CohereClient( + this.baseUrl + ? { + token: this.apiKey, + environment: this.baseUrl, + } + : { + token: this.apiKey, + }, + ); } /** @@ -62,12 +77,17 @@ export class CohereRerank implements BaseNodePostprocessor { throw new Error("CohereRerank requires a query"); } - const results = await this.client.rerank({ - query: extractText(query), - model: this.model, - topN: this.topN, - documents: nodes.map((n) => n.node.getContent(MetadataMode.ALL)), - }); + const results = await this.client.rerank( + { + query: extractText(query), + model: this.model, + topN: this.topN, + documents: nodes.map((n) => n.node.getContent(MetadataMode.ALL)), + }, + this.timeout !== undefined + ? { timeoutInSeconds: this.timeout } + : undefined, + ); const newNodes: NodeWithScore[] = [];