diff --git a/.changeset/lemon-schools-reflect.md b/.changeset/lemon-schools-reflect.md
new file mode 100644
index 0000000000000000000000000000000000000000..9862d7c2a49d2cdce877489c65565636590c8244
--- /dev/null
+++ b/.changeset/lemon-schools-reflect.md
@@ -0,0 +1,5 @@
+---
+"llamaindex": patch
+---
+
+Added Meta strategy for Llama2
diff --git a/apps/simple/llamadeuce.ts b/apps/simple/llamadeuce.ts
index d7b309481369318566507a0adf61ae90e2b5de46..373c0c0a2e6c5a14f60b2b53b610d8d0dca98d22 100644
--- a/apps/simple/llamadeuce.ts
+++ b/apps/simple/llamadeuce.ts
@@ -1,7 +1,7 @@
-import { LlamaDeuce } from "llamaindex";
+import { DeuceChatStrategy, LlamaDeuce } from "llamaindex";
 
 (async () => {
-  const deuce = new LlamaDeuce();
+  const deuce = new LlamaDeuce({ chatStrategy: DeuceChatStrategy.META });
   const result = await deuce.chat([{ content: "Hello, world!", role: "user" }]);
   console.log(result);
 })();
diff --git a/packages/core/src/llm/LLM.ts b/packages/core/src/llm/LLM.ts
index f88af987d0a49a6d30fd8779075305f9928343a9..c85c02b19307d0ee696c6e5c3f9fa5d3826bfa95 100644
--- a/packages/core/src/llm/LLM.ts
+++ b/packages/core/src/llm/LLM.ts
@@ -177,23 +177,56 @@ export const ALL_AVAILABLE_LLAMADEUCE_MODELS = {
   },
 };
 
+export enum DeuceChatStrategy {
+  A16Z = "a16z",
+  META = "meta",
+  METAWBOS = "metawbos",
+  // ^ METAWBOS is not exactly right, because SentencePiece inserts the BOS and EOS token IDs
+  // during tokenization; unfortunately, a string-only API can't represent them properly.
+}
+
 /**
  * Llama2 LLM implementation
  */
 export class LlamaDeuce implements LLM {
   model: keyof typeof ALL_AVAILABLE_LLAMADEUCE_MODELS;
+  chatStrategy: DeuceChatStrategy;
   temperature: number;
   maxTokens?: number;
   replicateSession: ReplicateSession;
 
   constructor(init?: Partial<LlamaDeuce>) {
     this.model = init?.model ?? "Llama-2-70b-chat";
+    this.chatStrategy = init?.chatStrategy ?? DeuceChatStrategy.META;
     this.temperature = init?.temperature ?? 0;
     this.maxTokens = init?.maxTokens ?? undefined;
    this.replicateSession = init?.replicateSession ?? new ReplicateSession();
   }
 
-  mapMessageType(messageType: MessageType): string {
+  mapMessagesToPrompt(messages: ChatMessage[]): string {
+    switch (this.chatStrategy) {
+      case DeuceChatStrategy.A16Z:
+        return this.mapMessagesToPromptA16Z(messages);
+      case DeuceChatStrategy.METAWBOS:
+        return this.mapMessagesToPromptMeta(messages, true);
+      case DeuceChatStrategy.META:
+      default:
+        return this.mapMessagesToPromptMeta(messages);
+    }
+  }
+
+  mapMessagesToPromptA16Z(messages: ChatMessage[]): string {
+    return (
+      messages.reduce((acc, message) => {
+        return (
+          (acc && `${acc}\n\n`) +
+          `${this.mapMessageTypeA16Z(message.role)}${message.content}`
+        );
+      }, "") + "\n\nAssistant:"
+    ); // Here we differ from A16Z by omitting the trailing space; trailing spaces at the end of a prompt generally hurt performance due to tokenization.
+  }
+
+  mapMessageTypeA16Z(messageType: MessageType): string {
     switch (messageType) {
       case "user":
         return "User: ";
@@ -206,26 +239,74 @@ export class LlamaDeuce implements LLM {
     }
   }
 
+  // Note: this mutates the messages array, folding the system message into the first user message.
+  mapMessagesToPromptMeta(messages: ChatMessage[], withBos = false): string {
+    const DEFAULT_SYSTEM_PROMPT = `You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
+
+If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.`;
+
+    const B_SYS = "<<SYS>>\n";
+    const E_SYS = "\n<</SYS>>\n\n";
+    const B_INST = "[INST]";
+    const E_INST = "[/INST]";
+    const BOS = "<s>";
+    const EOS = "</s>";
+
+    if (messages.length === 0) {
+      return "";
+    }
+
+    if (messages[0].role === "system") {
+      const systemMessage = messages.shift()!;
+
+      const systemStr = `${B_SYS}${systemMessage.content}${E_SYS}`;
+
+      // After the shift, the original second message is at index 0.
+      if (messages.length === 0 || messages[0].role !== "user") {
+        throw new Error(
+          "LlamaDeuce: if there is a system message, the second message must be a user message."
+        );
+      }
+
+      const userContent = messages[0].content;
+
+      messages[0].content = `${systemStr}${userContent}`;
+    } else {
+      messages[0].content = `${B_SYS}${DEFAULT_SYSTEM_PROMPT}${E_SYS}${messages[0].content}`;
+    }
+
+    return messages.reduce((acc, message, index) => {
+      if (index % 2 === 0) {
+        // Each user turn opens a new (BOS)[INST] ... [/INST] block.
+        return (
+          acc +
+          (withBos ? BOS : "") +
+          `${B_INST} ${message.content.trim()} ${E_INST}`
+        );
+      } else {
+        return `${acc} ${message.content.trim()} ` + (withBos ? EOS : ""); // Yes, the EOS comes after the space. This is not a mistake.
+      }
+    }, "");
+  }
+
   async chat(
     messages: ChatMessage[],
     _parentEvent?: Event
   ): Promise<ChatResponse> {
     const api = ALL_AVAILABLE_LLAMADEUCE_MODELS[this.model]
       .replicateApi as `${string}/${string}:${string}`;
+
+    const prompt = this.mapMessagesToPrompt(messages);
+
     const response = await this.replicateSession.replicate.run(api, {
       input: {
-        prompt:
-          messages.reduce((acc, message) => {
-            return (
-              (acc && `${acc}\n\n`) +
-              `${this.mapMessageType(message.role)}${message.content}`
-            );
-          }, "") + "\n\nAssistant:", // Here we're differing from A16Z by omitting the space. Generally spaces at the end of prompts decrease performance due to tokenization
+        prompt,
       },
     });
     return {
       message: {
-        content: (response as Array<string>).join(""), // We need to do this because replicate returns a list of strings (for streaming functionality)
+        content: (response as Array<string>).join(""),
+        // Replicate's run() returns the output as a list of strings (it exists to support streaming, which run() does not expose), so re-join them here.
         role: "assistant",
       },
     };
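
Not part of the patch: a minimal sketch of what the new Meta mapping produces, assuming the METAWBOS strategy and the mapMessagesToPrompt method added above; the message contents are illustrative only.

import { DeuceChatStrategy, LlamaDeuce } from "llamaindex";

// Build the prompt locally to inspect the Meta format; nothing is sent to Replicate.
const deuce = new LlamaDeuce({ chatStrategy: DeuceChatStrategy.METAWBOS });

// Note: mapMessagesToPrompt mutates the array (it folds the system message
// into the first user message), so pass a copy if the original is still needed.
const prompt = deuce.mapMessagesToPrompt([
  { role: "system", content: "Be terse." },
  { role: "user", content: "Hi" },
  { role: "assistant", content: "Hello!" },
  { role: "user", content: "How are you?" },
]);

console.log(prompt);
// Expected shape (<s>/</s> appear because METAWBOS is selected):
// <s>[INST] <<SYS>>
// Be terse.
// <</SYS>>
//
// Hi [/INST] Hello! </s><s>[INST] How are you? [/INST]

With the default META strategy the output is identical minus the <s>/</s> markers, in line with the caveat on METAWBOS in the enum: a string-only API such as Replicate's tokenizes them as literal text rather than as the BOS/EOS token IDs.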