Skip to content
Snippets Groups Projects
Commit f9d1a6e0 authored by Yi Ding's avatar Yi Ding
Browse files

add top P

parent b18e1228
Branches anthropic
No related tags found
No related merge requests found
---
"llamaindex": patch
---
Add Top P
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
"name": "simple", "name": "simple",
"version": "0.0.3", "version": "0.0.3",
"dependencies": { "dependencies": {
"llamaindex": "0.0.0-20230730023617" "llamaindex": "^0.0.0-20230730023617"
}, },
"devDependencies": { "devDependencies": {
"@types/node": "^18" "@types/node": "^18"
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
"private": true, "private": true,
"name": "simple", "name": "simple",
"dependencies": { "dependencies": {
"llamaindex": "0.0.0-20230730023617" "llamaindex": "latest"
}, },
"devDependencies": { "devDependencies": {
"@types/node": "^18" "@types/node": "^18"
......
...@@ -73,6 +73,7 @@ export class OpenAI implements LLM { ...@@ -73,6 +73,7 @@ export class OpenAI implements LLM {
// Per completion OpenAI params // Per completion OpenAI params
model: keyof typeof ALL_AVAILABLE_OPENAI_MODELS; model: keyof typeof ALL_AVAILABLE_OPENAI_MODELS;
temperature: number; temperature: number;
topP: number;
maxTokens?: number; maxTokens?: number;
// OpenAI session params // OpenAI session params
...@@ -86,6 +87,7 @@ export class OpenAI implements LLM { ...@@ -86,6 +87,7 @@ export class OpenAI implements LLM {
constructor(init?: Partial<OpenAI>) { constructor(init?: Partial<OpenAI>) {
this.model = init?.model ?? "gpt-3.5-turbo"; this.model = init?.model ?? "gpt-3.5-turbo";
this.temperature = init?.temperature ?? 0; this.temperature = init?.temperature ?? 0;
this.topP = init?.topP ?? 1;
this.maxTokens = init?.maxTokens ?? undefined; this.maxTokens = init?.maxTokens ?? undefined;
this.apiKey = init?.apiKey ?? undefined; this.apiKey = init?.apiKey ?? undefined;
...@@ -131,6 +133,7 @@ export class OpenAI implements LLM { ...@@ -131,6 +133,7 @@ export class OpenAI implements LLM {
role: this.mapMessageType(message.role), role: this.mapMessageType(message.role),
content: message.content, content: message.content,
})), })),
top_p: this.topP,
}; };
if (this.callbackManager?.onLLMStream) { if (this.callbackManager?.onLLMStream) {
...@@ -198,14 +201,16 @@ export class LlamaDeuce implements LLM { ...@@ -198,14 +201,16 @@ export class LlamaDeuce implements LLM {
model: keyof typeof ALL_AVAILABLE_LLAMADEUCE_MODELS; model: keyof typeof ALL_AVAILABLE_LLAMADEUCE_MODELS;
chatStrategy: DeuceChatStrategy; chatStrategy: DeuceChatStrategy;
temperature: number; temperature: number;
topP: number;
maxTokens?: number; maxTokens?: number;
replicateSession: ReplicateSession; replicateSession: ReplicateSession;
constructor(init?: Partial<LlamaDeuce>) { constructor(init?: Partial<LlamaDeuce>) {
this.model = init?.model ?? "Llama-2-70b-chat"; this.model = init?.model ?? "Llama-2-70b-chat";
this.chatStrategy = init?.chatStrategy ?? DeuceChatStrategy.META; this.chatStrategy = init?.chatStrategy ?? DeuceChatStrategy.META;
this.temperature = init?.temperature ?? 0; this.temperature = init?.temperature ?? 0.01; // minimum temperature is 0.01 for Replicate endpoint
this.maxTokens = init?.maxTokens ?? undefined; this.topP = init?.topP ?? 1;
this.maxTokens = init?.maxTokens ?? undefined; // By default this means it's 500 tokens according to Replicate docs
this.replicateSession = init?.replicateSession ?? new ReplicateSession(); this.replicateSession = init?.replicateSession ?? new ReplicateSession();
} }
...@@ -303,6 +308,10 @@ If a question does not make any sense, or is not factually coherent, explain why ...@@ -303,6 +308,10 @@ If a question does not make any sense, or is not factually coherent, explain why
const response = await this.replicateSession.replicate.run(api, { const response = await this.replicateSession.replicate.run(api, {
input: { input: {
prompt, prompt,
system_prompt: "", // We are already sending the system prompt so set system prompt to empty.
max_new_tokens: this.maxTokens,
temperature: this.temperature,
top_p: this.topP,
}, },
}); });
return { return {
...@@ -330,6 +339,7 @@ export class Anthropic implements LLM { ...@@ -330,6 +339,7 @@ export class Anthropic implements LLM {
// Per completion Anthropic params // Per completion Anthropic params
model: string; model: string;
temperature: number; temperature: number;
topP: number;
maxTokens?: number; maxTokens?: number;
// Anthropic session params // Anthropic session params
...@@ -343,6 +353,7 @@ export class Anthropic implements LLM { ...@@ -343,6 +353,7 @@ export class Anthropic implements LLM {
constructor(init?: Partial<Anthropic>) { constructor(init?: Partial<Anthropic>) {
this.model = init?.model ?? "claude-2"; this.model = init?.model ?? "claude-2";
this.temperature = init?.temperature ?? 0; this.temperature = init?.temperature ?? 0;
this.topP = init?.topP ?? 0.999; // Per Ben Mann
this.maxTokens = init?.maxTokens ?? undefined; this.maxTokens = init?.maxTokens ?? undefined;
this.apiKey = init?.apiKey ?? undefined; this.apiKey = init?.apiKey ?? undefined;
...@@ -383,6 +394,7 @@ export class Anthropic implements LLM { ...@@ -383,6 +394,7 @@ export class Anthropic implements LLM {
prompt: this.mapMessagesToPrompt(messages), prompt: this.mapMessagesToPrompt(messages),
max_tokens_to_sample: this.maxTokens ?? 100000, max_tokens_to_sample: this.maxTokens ?? 100000,
temperature: this.temperature, temperature: this.temperature,
top_p: this.topP,
}); });
return { return {
......
0% Loading or loading failed.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment