From f9d1a6e013ebf574c7da090d38d15243c20b894a Mon Sep 17 00:00:00 2001
From: Yi Ding <yi.s.ding@gmail.com>
Date: Mon, 31 Jul 2023 00:27:09 -0700
Subject: [PATCH] add top P

---
 .changeset/spotty-planets-whisper.md |  5 +++++
 examples/package-lock.json           |  2 +-
 examples/package.json                |  2 +-
 packages/core/src/llm/LLM.ts         | 16 ++++++++++++++--
 4 files changed, 21 insertions(+), 4 deletions(-)
 create mode 100644 .changeset/spotty-planets-whisper.md

diff --git a/.changeset/spotty-planets-whisper.md b/.changeset/spotty-planets-whisper.md
new file mode 100644
index 000000000..36bd7533d
--- /dev/null
+++ b/.changeset/spotty-planets-whisper.md
@@ -0,0 +1,5 @@
+---
+"llamaindex": patch
+---
+
+Add Top P
diff --git a/examples/package-lock.json b/examples/package-lock.json
index fb3651c3d..a633d7595 100644
--- a/examples/package-lock.json
+++ b/examples/package-lock.json
@@ -8,7 +8,7 @@
       "name": "simple",
       "version": "0.0.3",
       "dependencies": {
-        "llamaindex": "0.0.0-20230730023617"
+        "llamaindex": "^0.0.0-20230730023617"
       },
       "devDependencies": {
         "@types/node": "^18"
diff --git a/examples/package.json b/examples/package.json
index ef3b8b7b3..dd9a702c7 100644
--- a/examples/package.json
+++ b/examples/package.json
@@ -3,7 +3,7 @@
   "private": true,
   "name": "simple",
   "dependencies": {
-    "llamaindex": "0.0.0-20230730023617"
+    "llamaindex": "latest"
   },
   "devDependencies": {
     "@types/node": "^18"
diff --git a/packages/core/src/llm/LLM.ts b/packages/core/src/llm/LLM.ts
index d69453ee5..dbd239d6a 100644
--- a/packages/core/src/llm/LLM.ts
+++ b/packages/core/src/llm/LLM.ts
@@ -73,6 +73,7 @@ export class OpenAI implements LLM {
   // Per completion OpenAI params
   model: keyof typeof ALL_AVAILABLE_OPENAI_MODELS;
   temperature: number;
+  topP: number;
   maxTokens?: number;

   // OpenAI session params
@@ -86,6 +87,7 @@ export class OpenAI implements LLM {
   constructor(init?: Partial<OpenAI>) {
     this.model = init?.model ?? "gpt-3.5-turbo";
     this.temperature = init?.temperature ?? 0;
+    this.topP = init?.topP ?? 1;
     this.maxTokens = init?.maxTokens ?? undefined;

     this.apiKey = init?.apiKey ?? undefined;
@@ -131,6 +133,7 @@ export class OpenAI implements LLM {
         role: this.mapMessageType(message.role),
         content: message.content,
       })),
+      top_p: this.topP,
     };

     if (this.callbackManager?.onLLMStream) {
@@ -198,14 +201,16 @@ export class LlamaDeuce implements LLM {
   model: keyof typeof ALL_AVAILABLE_LLAMADEUCE_MODELS;
   chatStrategy: DeuceChatStrategy;
   temperature: number;
+  topP: number;
   maxTokens?: number;
   replicateSession: ReplicateSession;

   constructor(init?: Partial<LlamaDeuce>) {
     this.model = init?.model ?? "Llama-2-70b-chat";
     this.chatStrategy = init?.chatStrategy ?? DeuceChatStrategy.META;
-    this.temperature = init?.temperature ?? 0;
-    this.maxTokens = init?.maxTokens ?? undefined;
+    this.temperature = init?.temperature ?? 0.01; // minimum temperature is 0.01 for Replicate endpoint
+    this.topP = init?.topP ?? 1;
+    this.maxTokens = init?.maxTokens ?? undefined; // By default this means it's 500 tokens according to Replicate docs
     this.replicateSession = init?.replicateSession ?? new ReplicateSession();
   }

@@ -303,6 +308,10 @@ If a question does not make any sense, or is not factually coherent, explain why
     const response = await this.replicateSession.replicate.run(api, {
       input: {
         prompt,
+        system_prompt: "", // We are already sending the system prompt so set system prompt to empty.
+        max_new_tokens: this.maxTokens,
+        temperature: this.temperature,
+        top_p: this.topP,
       },
     });
     return {
@@ -330,6 +339,7 @@ export class Anthropic implements LLM {
   // Per completion Anthropic params
   model: string;
   temperature: number;
+  topP: number;
   maxTokens?: number;

   // Anthropic session params
@@ -343,6 +353,7 @@ export class Anthropic implements LLM {
   constructor(init?: Partial<Anthropic>) {
     this.model = init?.model ?? "claude-2";
     this.temperature = init?.temperature ?? 0;
+    this.topP = init?.topP ?? 0.999; // Per Ben Mann
     this.maxTokens = init?.maxTokens ?? undefined;

     this.apiKey = init?.apiKey ?? undefined;
@@ -383,6 +394,7 @@ export class Anthropic implements LLM {
       prompt: this.mapMessagesToPrompt(messages),
       max_tokens_to_sample: this.maxTokens ?? 100000,
       temperature: this.temperature,
+      top_p: this.topP,
     });
     return {
--
GitLab
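
Usage note (not part of the patch): a minimal sketch of how the new topP option would be passed when constructing these classes. It assumes the "llamaindex" entry point exports OpenAI, LlamaDeuce, and Anthropic, and that chat() accepts an array of { role, content } messages and returns a response with a message field, as the diff above suggests; treat the exact names and response shape as illustrative rather than authoritative.

import { Anthropic, LlamaDeuce, OpenAI } from "llamaindex";

async function main() {
  // OpenAI: topP defaults to 1 and is forwarded as top_p on each completion request.
  const openai = new OpenAI({ model: "gpt-3.5-turbo", temperature: 0, topP: 0.9 });

  // Replicate-hosted Llama 2: temperature now defaults to 0.01 (Replicate's minimum); topP defaults to 1.
  const llama = new LlamaDeuce({ model: "Llama-2-70b-chat", topP: 0.9 });

  // Anthropic: topP defaults to 0.999 per the patch and is forwarded as top_p.
  const anthropic = new Anthropic({ model: "claude-2", topP: 0.9 });

  const messages = [{ role: "user" as const, content: "Say hello in one short sentence." }];
  const response = await openai.chat(messages);
  console.log(response.message.content);
}

main().catch(console.error);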