From 728b35e774d0c2742dcf132e0e25cfc9863837a0 Mon Sep 17 00:00:00 2001 From: Marcus Schiesser <mail@marcusschiesser.de> Date: Mon, 15 Apr 2024 12:35:15 +0800 Subject: [PATCH] chore: remove `LLM.ts` (#720) --- packages/core/src/llm/LLM.ts | 373 -------------------------- packages/core/src/llm/index.ts | 8 +- packages/core/src/llm/portkey.ts | 110 +++++++- packages/core/src/llm/replicate_ai.ts | 279 ++++++++++++++++++- 4 files changed, 387 insertions(+), 383 deletions(-) delete mode 100644 packages/core/src/llm/LLM.ts diff --git a/packages/core/src/llm/LLM.ts b/packages/core/src/llm/LLM.ts deleted file mode 100644 index dc37d3718..000000000 --- a/packages/core/src/llm/LLM.ts +++ /dev/null @@ -1,373 +0,0 @@ -import { type StreamCallbackResponse } from "../callbacks/CallbackManager.js"; - -import type { LLMOptions } from "portkey-ai"; -import { getCallbackManager } from "../internal/settings/CallbackManager.js"; -import { BaseLLM } from "./base.js"; -import type { PortkeySession } from "./portkey.js"; -import { getPortkeySession } from "./portkey.js"; -import { ReplicateSession } from "./replicate_ai.js"; -import type { - ChatMessage, - ChatResponse, - ChatResponseChunk, - LLMChatParamsNonStreaming, - LLMChatParamsStreaming, - LLMMetadata, - MessageType, -} from "./types.js"; -import { extractText, wrapLLMEvent } from "./utils.js"; - -export const ALL_AVAILABLE_LLAMADEUCE_MODELS = { - "Llama-2-70b-chat-old": { - contextWindow: 4096, - replicateApi: - "replicate/llama70b-v2-chat:e951f18578850b652510200860fc4ea62b3b16fac280f83ff32282f87bbd2e48", - //^ Previous 70b model. This is also actually 4 bit, although not exllama. - }, - "Llama-2-70b-chat-4bit": { - contextWindow: 4096, - replicateApi: - "meta/llama-2-70b-chat:02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3", - //^ Model is based off of exllama 4bit. - }, - "Llama-2-13b-chat-old": { - contextWindow: 4096, - replicateApi: - "a16z-infra/llama13b-v2-chat:df7690f1994d94e96ad9d568eac121aecf50684a0b0963b25a41cc40061269e5", - }, - //^ Last known good 13b non-quantized model. In future versions they add the SYS and INST tags themselves - "Llama-2-13b-chat-4bit": { - contextWindow: 4096, - replicateApi: - "meta/llama-2-13b-chat:f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d", - }, - "Llama-2-7b-chat-old": { - contextWindow: 4096, - replicateApi: - "a16z-infra/llama7b-v2-chat:4f0a4744c7295c024a1de15e1a63c880d3da035fa1f49bfd344fe076074c8eea", - //^ Last (somewhat) known good 7b non-quantized model. In future versions they add the SYS and INST - // tags themselves - // https://github.com/replicate/cog-llama-template/commit/fa5ce83912cf82fc2b9c01a4e9dc9bff6f2ef137 - // Problem is that they fix the max_new_tokens issue in the same commit. :-( - }, - "Llama-2-7b-chat-4bit": { - contextWindow: 4096, - replicateApi: - "meta/llama-2-7b-chat:13c3cdee13ee059ab779f0291d29054dab00a47dad8261375654de5540165fb0", - }, -}; - -export enum DeuceChatStrategy { - A16Z = "a16z", - META = "meta", - METAWBOS = "metawbos", - //^ This is not exactly right because SentencePiece puts the BOS and EOS token IDs in after tokenization - // Unfortunately any string only API won't support these properly. - REPLICATE4BIT = "replicate4bit", - //^ To satisfy Replicate's 4 bit models' requirements where they also insert some INST tags - REPLICATE4BITWNEWLINES = "replicate4bitwnewlines", - //^ Replicate's documentation recommends using newlines: https://replicate.com/blog/how-to-prompt-llama -} - -/** - * Llama2 LLM implementation - */ -export class LlamaDeuce extends BaseLLM { - model: keyof typeof ALL_AVAILABLE_LLAMADEUCE_MODELS; - chatStrategy: DeuceChatStrategy; - temperature: number; - topP: number; - maxTokens?: number; - replicateSession: ReplicateSession; - - constructor(init?: Partial<LlamaDeuce>) { - super(); - this.model = init?.model ?? "Llama-2-70b-chat-4bit"; - this.chatStrategy = - init?.chatStrategy ?? - (this.model.endsWith("4bit") - ? DeuceChatStrategy.REPLICATE4BITWNEWLINES // With the newer Replicate models they do the system message themselves. - : DeuceChatStrategy.METAWBOS); // With BOS and EOS seems to work best, although they all have problems past a certain point - this.temperature = init?.temperature ?? 0.1; // minimum temperature is 0.01 for Replicate endpoint - this.topP = init?.topP ?? 1; - this.maxTokens = - init?.maxTokens ?? - ALL_AVAILABLE_LLAMADEUCE_MODELS[this.model].contextWindow; // For Replicate, the default is 500 tokens which is too low. - this.replicateSession = init?.replicateSession ?? new ReplicateSession(); - } - - get metadata() { - return { - model: this.model, - temperature: this.temperature, - topP: this.topP, - maxTokens: this.maxTokens, - contextWindow: ALL_AVAILABLE_LLAMADEUCE_MODELS[this.model].contextWindow, - tokenizer: undefined, - }; - } - - mapMessagesToPrompt(messages: ChatMessage[]) { - if (this.chatStrategy === DeuceChatStrategy.A16Z) { - return this.mapMessagesToPromptA16Z(messages); - } else if (this.chatStrategy === DeuceChatStrategy.META) { - return this.mapMessagesToPromptMeta(messages); - } else if (this.chatStrategy === DeuceChatStrategy.METAWBOS) { - return this.mapMessagesToPromptMeta(messages, { withBos: true }); - } else if (this.chatStrategy === DeuceChatStrategy.REPLICATE4BIT) { - return this.mapMessagesToPromptMeta(messages, { - replicate4Bit: true, - withNewlines: true, - }); - } else if (this.chatStrategy === DeuceChatStrategy.REPLICATE4BITWNEWLINES) { - return this.mapMessagesToPromptMeta(messages, { - replicate4Bit: true, - withNewlines: true, - }); - } else { - return this.mapMessagesToPromptMeta(messages); - } - } - - mapMessagesToPromptA16Z(messages: ChatMessage[]) { - return { - prompt: - messages.reduce((acc, message) => { - return ( - (acc && `${acc}\n\n`) + - `${this.mapMessageTypeA16Z(message.role)}${message.content}` - ); - }, "") + "\n\nAssistant:", - //^ Here we're differing from A16Z by omitting the space. Generally spaces at the end of prompts decrease performance due to tokenization - systemPrompt: undefined, - }; - } - - mapMessageTypeA16Z(messageType: MessageType): string { - switch (messageType) { - case "user": - return "User: "; - case "assistant": - return "Assistant: "; - case "system": - return ""; - default: - throw new Error("Unsupported LlamaDeuce message type"); - } - } - - mapMessagesToPromptMeta( - messages: ChatMessage[], - opts?: { - withBos?: boolean; - replicate4Bit?: boolean; - withNewlines?: boolean; - }, - ) { - const { - withBos = false, - replicate4Bit = false, - withNewlines = false, - } = opts ?? {}; - const DEFAULT_SYSTEM_PROMPT = `You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. - -If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.`; - - const B_SYS = "<<SYS>>\n"; - const E_SYS = "\n<</SYS>>\n\n"; - const B_INST = "[INST]"; - const E_INST = "[/INST]"; - const BOS = "<s>"; - const EOS = "</s>"; - - if (messages.length === 0) { - return { prompt: "", systemPrompt: undefined }; - } - - messages = [...messages]; // so we can use shift without mutating the original array - - let systemPrompt = undefined; - if (messages[0].role === "system") { - const systemMessage = messages.shift()!; - - if (replicate4Bit) { - systemPrompt = systemMessage.content; - } else { - const systemStr = `${B_SYS}${systemMessage.content}${E_SYS}`; - - // TS Bug: https://github.com/microsoft/TypeScript/issues/9998 - // @ts-ignore - if (messages[0].role !== "user") { - throw new Error( - "LlamaDeuce: if there is a system message, the second message must be a user message.", - ); - } - - const userContent = messages[0].content; - - messages[0].content = `${systemStr}${userContent}`; - } - } else { - if (!replicate4Bit) { - messages[0].content = `${B_SYS}${DEFAULT_SYSTEM_PROMPT}${E_SYS}${messages[0].content}`; - } - } - - return { - prompt: messages.reduce((acc, message, index) => { - const content = extractText(message.content); - if (index % 2 === 0) { - return ( - `${acc}${withBos ? BOS : ""}${B_INST} ${content.trim()} ${E_INST}` + - (withNewlines ? "\n" : "") - ); - } else { - return ( - `${acc} ${content.trim()}` + - (withNewlines ? "\n" : " ") + - (withBos ? EOS : "") - ); // Yes, the EOS comes after the space. This is not a mistake. - } - }, ""), - systemPrompt, - }; - } - - chat( - params: LLMChatParamsStreaming, - ): Promise<AsyncIterable<ChatResponseChunk>>; - chat(params: LLMChatParamsNonStreaming): Promise<ChatResponse>; - @wrapLLMEvent - async chat( - params: LLMChatParamsNonStreaming | LLMChatParamsStreaming, - ): Promise<ChatResponse | AsyncIterable<ChatResponseChunk>> { - const { messages, stream } = params; - const api = ALL_AVAILABLE_LLAMADEUCE_MODELS[this.model] - .replicateApi as `${string}/${string}:${string}`; - - const { prompt, systemPrompt } = this.mapMessagesToPrompt(messages); - - const replicateOptions: any = { - input: { - prompt, - system_prompt: systemPrompt, - temperature: this.temperature, - top_p: this.topP, - }, - }; - - if (this.model.endsWith("4bit")) { - replicateOptions.input.max_new_tokens = this.maxTokens; - } else { - replicateOptions.input.max_length = this.maxTokens; - } - - //TODO: Add streaming for this - if (stream) { - throw new Error("Streaming not supported for LlamaDeuce"); - } - - //Non-streaming - const response = await this.replicateSession.replicate.run( - api, - replicateOptions, - ); - return { - raw: response, - message: { - content: (response as Array<string>).join("").trimStart(), - //^ We need to do this because Replicate returns a list of strings (for streaming functionality which is not exposed by the run function) - role: "assistant", - }, - }; - } -} - -export class Portkey extends BaseLLM { - apiKey?: string = undefined; - baseURL?: string = undefined; - mode?: string = undefined; - llms?: [LLMOptions] | null = undefined; - session: PortkeySession; - - constructor(init?: Partial<Portkey>) { - super(); - this.apiKey = init?.apiKey; - this.baseURL = init?.baseURL; - this.mode = init?.mode; - this.llms = init?.llms; - this.session = getPortkeySession({ - apiKey: this.apiKey, - baseURL: this.baseURL, - llms: this.llms, - mode: this.mode, - }); - } - - get metadata(): LLMMetadata { - throw new Error("metadata not implemented for Portkey"); - } - - chat( - params: LLMChatParamsStreaming, - ): Promise<AsyncIterable<ChatResponseChunk>>; - chat(params: LLMChatParamsNonStreaming): Promise<ChatResponse>; - @wrapLLMEvent - async chat( - params: LLMChatParamsNonStreaming | LLMChatParamsStreaming, - ): Promise<ChatResponse | AsyncIterable<ChatResponseChunk>> { - const { messages, stream, additionalChatOptions } = params; - if (stream) { - return this.streamChat(messages, additionalChatOptions); - } else { - const bodyParams = additionalChatOptions || {}; - const response = await this.session.portkey.chatCompletions.create({ - messages: messages.map((message) => ({ - content: extractText(message.content), - role: message.role, - })), - ...bodyParams, - }); - - const content = response.choices[0].message?.content ?? ""; - const role = response.choices[0].message?.role || "assistant"; - return { raw: response, message: { content, role: role as MessageType } }; - } - } - - async *streamChat( - messages: ChatMessage[], - params?: Record<string, any>, - ): AsyncIterable<ChatResponseChunk> { - const chunkStream = await this.session.portkey.chatCompletions.create({ - messages: messages.map((message) => ({ - content: extractText(message.content), - role: message.role, - })), - ...params, - stream: true, - }); - - //Indices - let idx_counter: number = 0; - for await (const part of chunkStream) { - //Increment - part.choices[0].index = idx_counter; - const is_done: boolean = - part.choices[0].finish_reason === "stop" ? true : false; - //onLLMStream Callback - - const stream_callback: StreamCallbackResponse = { - index: idx_counter, - isDone: is_done, - // token: part, - }; - getCallbackManager().dispatchEvent("stream", stream_callback); - - idx_counter++; - - yield { raw: part, delta: part.choices[0].delta?.content ?? "" }; - } - return; - } -} diff --git a/packages/core/src/llm/index.ts b/packages/core/src/llm/index.ts index 2ab4c37d7..097b8afdd 100644 --- a/packages/core/src/llm/index.ts +++ b/packages/core/src/llm/index.ts @@ -1,4 +1,3 @@ -export * from "./LLM.js"; export { Anthropic } from "./anthropic.js"; export { FireworksLLM } from "./fireworks.js"; export { Groq } from "./groq.js"; @@ -9,5 +8,12 @@ export { } from "./mistral.js"; export { Ollama } from "./ollama.js"; export * from "./open_ai.js"; +export { Portkey } from "./portkey.js"; +export * from "./replicate_ai.js"; +// Note: The type aliases for replicate are to simplify usage for Llama 2 (we're using replicate for Llama 2 support) +export { + ReplicateChatStrategy as DeuceChatStrategy, + ReplicateLLM as LlamaDeuce, +} from "./replicate_ai.js"; export { TogetherLLM } from "./together.js"; export * from "./types.js"; diff --git a/packages/core/src/llm/portkey.ts b/packages/core/src/llm/portkey.ts index d0d90253e..bb3f47b12 100644 --- a/packages/core/src/llm/portkey.ts +++ b/packages/core/src/llm/portkey.ts @@ -1,7 +1,20 @@ import { getEnv } from "@llamaindex/env"; import _ from "lodash"; import type { LLMOptions } from "portkey-ai"; -import { Portkey } from "portkey-ai"; +import { Portkey as OrigPortKey } from "portkey-ai"; +import { type StreamCallbackResponse } from "../callbacks/CallbackManager.js"; +import { getCallbackManager } from "../internal/settings/CallbackManager.js"; +import { BaseLLM } from "./base.js"; +import type { + ChatMessage, + ChatResponse, + ChatResponseChunk, + LLMChatParamsNonStreaming, + LLMChatParamsStreaming, + LLMMetadata, + MessageType, +} from "./types.js"; +import { extractText, wrapLLMEvent } from "./utils.js"; interface PortkeyOptions { apiKey?: string; @@ -11,7 +24,7 @@ interface PortkeyOptions { } export class PortkeySession { - portkey: Portkey; + portkey: OrigPortKey; constructor(options: PortkeyOptions = {}) { if (!options.apiKey) { @@ -22,13 +35,13 @@ export class PortkeySession { options.baseURL = getEnv("PORTKEY_BASE_URL") ?? "https://api.portkey.ai"; } - this.portkey = new Portkey({}); + this.portkey = new OrigPortKey({}); this.portkey.llms = [{}]; if (!options.apiKey) { throw new Error("Set Portkey ApiKey in PORTKEY_API_KEY env variable"); } - this.portkey = new Portkey(options); + this.portkey = new OrigPortKey(options); } } @@ -54,3 +67,92 @@ export function getPortkeySession(options: PortkeyOptions = {}) { } return session; } + +export class Portkey extends BaseLLM { + apiKey?: string = undefined; + baseURL?: string = undefined; + mode?: string = undefined; + llms?: [LLMOptions] | null = undefined; + session: PortkeySession; + + constructor(init?: Partial<Portkey>) { + super(); + this.apiKey = init?.apiKey; + this.baseURL = init?.baseURL; + this.mode = init?.mode; + this.llms = init?.llms; + this.session = getPortkeySession({ + apiKey: this.apiKey, + baseURL: this.baseURL, + llms: this.llms, + mode: this.mode, + }); + } + + get metadata(): LLMMetadata { + throw new Error("metadata not implemented for Portkey"); + } + + chat( + params: LLMChatParamsStreaming, + ): Promise<AsyncIterable<ChatResponseChunk>>; + chat(params: LLMChatParamsNonStreaming): Promise<ChatResponse>; + @wrapLLMEvent + async chat( + params: LLMChatParamsNonStreaming | LLMChatParamsStreaming, + ): Promise<ChatResponse | AsyncIterable<ChatResponseChunk>> { + const { messages, stream, additionalChatOptions } = params; + if (stream) { + return this.streamChat(messages, additionalChatOptions); + } else { + const bodyParams = additionalChatOptions || {}; + const response = await this.session.portkey.chatCompletions.create({ + messages: messages.map((message) => ({ + content: extractText(message.content), + role: message.role, + })), + ...bodyParams, + }); + + const content = response.choices[0].message?.content ?? ""; + const role = response.choices[0].message?.role || "assistant"; + return { raw: response, message: { content, role: role as MessageType } }; + } + } + + async *streamChat( + messages: ChatMessage[], + params?: Record<string, any>, + ): AsyncIterable<ChatResponseChunk> { + const chunkStream = await this.session.portkey.chatCompletions.create({ + messages: messages.map((message) => ({ + content: extractText(message.content), + role: message.role, + })), + ...params, + stream: true, + }); + + //Indices + let idx_counter: number = 0; + for await (const part of chunkStream) { + //Increment + part.choices[0].index = idx_counter; + const is_done: boolean = + part.choices[0].finish_reason === "stop" ? true : false; + //onLLMStream Callback + + const stream_callback: StreamCallbackResponse = { + index: idx_counter, + isDone: is_done, + // token: part, + }; + getCallbackManager().dispatchEvent("stream", stream_callback); + + idx_counter++; + + yield { raw: part, delta: part.choices[0].delta?.content ?? "" }; + } + return; + } +} diff --git a/packages/core/src/llm/replicate_ai.ts b/packages/core/src/llm/replicate_ai.ts index 5c3d51f56..b369a1bf3 100644 --- a/packages/core/src/llm/replicate_ai.ts +++ b/packages/core/src/llm/replicate_ai.ts @@ -1,5 +1,15 @@ import { getEnv } from "@llamaindex/env"; import Replicate from "replicate"; +import { BaseLLM } from "./base.js"; +import type { + ChatMessage, + ChatResponse, + ChatResponseChunk, + LLMChatParamsNonStreaming, + LLMChatParamsStreaming, + MessageType, +} from "./types.js"; +import { extractText, wrapLLMEvent } from "./utils.js"; export class ReplicateSession { replicateKey: string | null = null; @@ -20,12 +30,271 @@ export class ReplicateSession { } } -let defaultReplicateSession: ReplicateSession | null = null; +export const ALL_AVAILABLE_REPLICATE_MODELS = { + // TODO: add more models from replicate + "Llama-2-70b-chat-old": { + contextWindow: 4096, + replicateApi: + "replicate/llama70b-v2-chat:e951f18578850b652510200860fc4ea62b3b16fac280f83ff32282f87bbd2e48", + //^ Previous 70b model. This is also actually 4 bit, although not exllama. + }, + "Llama-2-70b-chat-4bit": { + contextWindow: 4096, + replicateApi: + "meta/llama-2-70b-chat:02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3", + //^ Model is based off of exllama 4bit. + }, + "Llama-2-13b-chat-old": { + contextWindow: 4096, + replicateApi: + "a16z-infra/llama13b-v2-chat:df7690f1994d94e96ad9d568eac121aecf50684a0b0963b25a41cc40061269e5", + }, + //^ Last known good 13b non-quantized model. In future versions they add the SYS and INST tags themselves + "Llama-2-13b-chat-4bit": { + contextWindow: 4096, + replicateApi: + "meta/llama-2-13b-chat:f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d", + }, + "Llama-2-7b-chat-old": { + contextWindow: 4096, + replicateApi: + "a16z-infra/llama7b-v2-chat:4f0a4744c7295c024a1de15e1a63c880d3da035fa1f49bfd344fe076074c8eea", + //^ Last (somewhat) known good 7b non-quantized model. In future versions they add the SYS and INST + // tags themselves + // https://github.com/replicate/cog-llama-template/commit/fa5ce83912cf82fc2b9c01a4e9dc9bff6f2ef137 + // Problem is that they fix the max_new_tokens issue in the same commit. :-( + }, + "Llama-2-7b-chat-4bit": { + contextWindow: 4096, + replicateApi: + "meta/llama-2-7b-chat:13c3cdee13ee059ab779f0291d29054dab00a47dad8261375654de5540165fb0", + }, +}; -export function getReplicateSession(replicateKey: string | null = null) { - if (!defaultReplicateSession) { - defaultReplicateSession = new ReplicateSession(replicateKey); +export enum ReplicateChatStrategy { + A16Z = "a16z", + META = "meta", + METAWBOS = "metawbos", + //^ This is not exactly right because SentencePiece puts the BOS and EOS token IDs in after tokenization + // Unfortunately any string only API won't support these properly. + REPLICATE4BIT = "replicate4bit", + //^ To satisfy Replicate's 4 bit models' requirements where they also insert some INST tags + REPLICATE4BITWNEWLINES = "replicate4bitwnewlines", + //^ Replicate's documentation recommends using newlines: https://replicate.com/blog/how-to-prompt-llama +} + +/** + * Replicate LLM implementation used + */ +export class ReplicateLLM extends BaseLLM { + model: keyof typeof ALL_AVAILABLE_REPLICATE_MODELS; + chatStrategy: ReplicateChatStrategy; + temperature: number; + topP: number; + maxTokens?: number; + replicateSession: ReplicateSession; + + constructor(init?: Partial<ReplicateLLM>) { + super(); + this.model = init?.model ?? "Llama-2-70b-chat-4bit"; + this.chatStrategy = + init?.chatStrategy ?? + (this.model.endsWith("4bit") + ? ReplicateChatStrategy.REPLICATE4BITWNEWLINES // With the newer Replicate models they do the system message themselves. + : ReplicateChatStrategy.METAWBOS); // With BOS and EOS seems to work best, although they all have problems past a certain point + this.temperature = init?.temperature ?? 0.1; // minimum temperature is 0.01 for Replicate endpoint + this.topP = init?.topP ?? 1; + this.maxTokens = + init?.maxTokens ?? + ALL_AVAILABLE_REPLICATE_MODELS[this.model].contextWindow; // For Replicate, the default is 500 tokens which is too low. + this.replicateSession = init?.replicateSession ?? new ReplicateSession(); + } + + get metadata() { + return { + model: this.model, + temperature: this.temperature, + topP: this.topP, + maxTokens: this.maxTokens, + contextWindow: ALL_AVAILABLE_REPLICATE_MODELS[this.model].contextWindow, + tokenizer: undefined, + }; + } + + mapMessagesToPrompt(messages: ChatMessage[]) { + if (this.chatStrategy === ReplicateChatStrategy.A16Z) { + return this.mapMessagesToPromptA16Z(messages); + } else if (this.chatStrategy === ReplicateChatStrategy.META) { + return this.mapMessagesToPromptMeta(messages); + } else if (this.chatStrategy === ReplicateChatStrategy.METAWBOS) { + return this.mapMessagesToPromptMeta(messages, { withBos: true }); + } else if (this.chatStrategy === ReplicateChatStrategy.REPLICATE4BIT) { + return this.mapMessagesToPromptMeta(messages, { + replicate4Bit: true, + withNewlines: true, + }); + } else if ( + this.chatStrategy === ReplicateChatStrategy.REPLICATE4BITWNEWLINES + ) { + return this.mapMessagesToPromptMeta(messages, { + replicate4Bit: true, + withNewlines: true, + }); + } else { + return this.mapMessagesToPromptMeta(messages); + } + } + + mapMessagesToPromptA16Z(messages: ChatMessage[]) { + return { + prompt: + messages.reduce((acc, message) => { + return ( + (acc && `${acc}\n\n`) + + `${this.mapMessageTypeA16Z(message.role)}${message.content}` + ); + }, "") + "\n\nAssistant:", + //^ Here we're differing from A16Z by omitting the space. Generally spaces at the end of prompts decrease performance due to tokenization + systemPrompt: undefined, + }; + } + + mapMessageTypeA16Z(messageType: MessageType): string { + switch (messageType) { + case "user": + return "User: "; + case "assistant": + return "Assistant: "; + case "system": + return ""; + default: + throw new Error("Unsupported ReplicateLLM message type"); + } + } + + mapMessagesToPromptMeta( + messages: ChatMessage[], + opts?: { + withBos?: boolean; + replicate4Bit?: boolean; + withNewlines?: boolean; + }, + ) { + const { + withBos = false, + replicate4Bit = false, + withNewlines = false, + } = opts ?? {}; + const DEFAULT_SYSTEM_PROMPT = `You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. + +If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.`; + + const B_SYS = "<<SYS>>\n"; + const E_SYS = "\n<</SYS>>\n\n"; + const B_INST = "[INST]"; + const E_INST = "[/INST]"; + const BOS = "<s>"; + const EOS = "</s>"; + + if (messages.length === 0) { + return { prompt: "", systemPrompt: undefined }; + } + + messages = [...messages]; // so we can use shift without mutating the original array + + let systemPrompt = undefined; + if (messages[0].role === "system") { + const systemMessage = messages.shift()!; + + if (replicate4Bit) { + systemPrompt = systemMessage.content; + } else { + const systemStr = `${B_SYS}${systemMessage.content}${E_SYS}`; + + // TS Bug: https://github.com/microsoft/TypeScript/issues/9998 + // @ts-ignore + if (messages[0].role !== "user") { + throw new Error( + "ReplicateLLM: if there is a system message, the second message must be a user message.", + ); + } + + const userContent = messages[0].content; + + messages[0].content = `${systemStr}${userContent}`; + } + } else { + if (!replicate4Bit) { + messages[0].content = `${B_SYS}${DEFAULT_SYSTEM_PROMPT}${E_SYS}${messages[0].content}`; + } + } + + return { + prompt: messages.reduce((acc, message, index) => { + const content = extractText(message.content); + if (index % 2 === 0) { + return ( + `${acc}${withBos ? BOS : ""}${B_INST} ${content.trim()} ${E_INST}` + + (withNewlines ? "\n" : "") + ); + } else { + return ( + `${acc} ${content.trim()}` + + (withNewlines ? "\n" : " ") + + (withBos ? EOS : "") + ); // Yes, the EOS comes after the space. This is not a mistake. + } + }, ""), + systemPrompt, + }; } - return defaultReplicateSession; + chat( + params: LLMChatParamsStreaming, + ): Promise<AsyncIterable<ChatResponseChunk>>; + chat(params: LLMChatParamsNonStreaming): Promise<ChatResponse>; + @wrapLLMEvent + async chat( + params: LLMChatParamsNonStreaming | LLMChatParamsStreaming, + ): Promise<ChatResponse | AsyncIterable<ChatResponseChunk>> { + const { messages, stream } = params; + const api = ALL_AVAILABLE_REPLICATE_MODELS[this.model] + .replicateApi as `${string}/${string}:${string}`; + + const { prompt, systemPrompt } = this.mapMessagesToPrompt(messages); + + const replicateOptions: any = { + input: { + prompt, + system_prompt: systemPrompt, + temperature: this.temperature, + top_p: this.topP, + }, + }; + + if (this.model.endsWith("4bit")) { + replicateOptions.input.max_new_tokens = this.maxTokens; + } else { + replicateOptions.input.max_length = this.maxTokens; + } + + //TODO: Add streaming for this + if (stream) { + throw new Error("Streaming not supported for ReplicateLLM"); + } + + //Non-streaming + const response = await this.replicateSession.replicate.run( + api, + replicateOptions, + ); + return { + raw: response, + message: { + content: (response as Array<string>).join("").trimStart(), + //^ We need to do this because Replicate returns a list of strings (for streaming functionality which is not exposed by the run function) + role: "assistant", + }, + }; + } } -- GitLab