diff --git a/packages/core/src/llm/LLM.ts b/packages/core/src/llm/LLM.ts
index 54bfbae11e7927a05412f938ccc83759726a35db..6cf9c2f6293b43579483c568633777115422728d 100644
--- a/packages/core/src/llm/LLM.ts
+++ b/packages/core/src/llm/LLM.ts
@@ -48,27 +48,23 @@ export type CompletionResponse = ChatResponse;
  * Unified language model interface
  */
 export interface LLM {
+  //Whether an LLM has streaming support
+  hasStreaming: boolean;
   /**
    * Get a chat response from the LLM
    * @param messages
+   *
+   * The return type of chat() and complete() is determined by the "streaming" parameter: when it is set to true, an AsyncGenerator is returned instead of a ChatResponse.
    */
-  chat(messages: ChatMessage[], parentEvent?: Event): Promise<ChatResponse>;
+  chat<T extends boolean | undefined = undefined, R = T extends true ? AsyncGenerator<string, void, unknown> : ChatResponse>
+  (messages: ChatMessage[], parentEvent?: Event, streaming?: T): Promise<R>;
 
   /**
    * Get a prompt completion from the LLM
    * @param prompt the prompt to complete
    */
-  complete(prompt: string, parentEvent?: Event): Promise<CompletionResponse>;
-
-  stream_chat?(
-    messages: ChatMessage[],
-    parentEvent?: Event,
-  ): AsyncGenerator<string, void, unknown>;
-
-  stream_complete?(
-    query: string,
-    parentEvent?: Event,
-  ): AsyncGenerator<string, void, unknown>;
+  complete<T extends boolean | undefined = undefined, R = T extends true ? AsyncGenerator<string, void, unknown> : ChatResponse>
+  (prompt: string, parentEvent?: Event, streaming?: T): Promise<R>;
 }
 
 export const GPT4_MODELS = {
@@ -102,6 +98,7 @@ export class OpenAI implements LLM {
     Partial<OpenAILLM.Chat.CompletionCreateParams>,
     "max_tokens" | "messages" | "model" | "temperature" | "top_p" | "streaming"
   >;
+  hasStreaming: boolean;
 
   // OpenAI session params
   apiKey?: string = undefined;
@@ -129,6 +126,7 @@ export class OpenAI implements LLM {
     this.timeout = init?.timeout ?? 60 * 1000; // Default is 60 seconds
     this.additionalChatOptions = init?.additionalChatOptions;
     this.additionalSessionOptions = init?.additionalSessionOptions;
+    this.hasStreaming = init?.hasStreaming ?? true;
 
     if (init?.azure || shouldUseAzure()) {
       const azureConfig = getAzureConfigFromEnv({
@@ -186,10 +184,11 @@ export class OpenAI implements LLM {
     }
   }
 
-  async chat(
+  async chat<T extends boolean | undefined = undefined, R = T extends true ? AsyncGenerator<string, void, unknown> : ChatResponse>(
     messages: ChatMessage[],
     parentEvent?: Event,
-  ): Promise<ChatResponse> {
+    streaming?: T,
+  ): Promise<R> {
     const baseRequestParams: OpenAILLM.Chat.CompletionCreateParams = {
       model: this.model,
       temperature: this.temperature,
@@ -201,6 +200,13 @@ export class OpenAI implements LLM {
       top_p: this.topP,
       ...this.additionalChatOptions,
     };
+    // Streaming
+    if (streaming) {
+      if (!this.hasStreaming) {
+        throw Error("No streaming support for this LLM.");
+      }
+      return this.stream_chat(messages, parentEvent) as R;
+    }
     // Non-streaming
     const response = await this.session.openai.chat.completions.create({
       ...baseRequestParams,
@@ -208,20 +214,21 @@ export class OpenAI implements LLM {
     });
 
     const content = response.choices[0].message?.content ?? "";
-    return { message: { content, role: response.choices[0].message.role } };
+    return { message: { content, role: response.choices[0].message.role } } as R;
   }
 
-  async complete(
+  async complete<T extends boolean | undefined = undefined, R = T extends true ? AsyncGenerator<string, void, unknown> : ChatResponse>(
     prompt: string,
     parentEvent?: Event,
-  ): Promise<CompletionResponse> {
-    return this.chat([{ content: prompt, role: "user" }], parentEvent);
+    streaming?: T,
+  ): Promise<R> {
+    return this.chat([{ content: prompt, role: "user" }], parentEvent, streaming);
   }
 
   //We can wrap a stream in a generator to add some additional logging behavior
   //For future edits: syntax for generator type is <typeof Yield, typeof Return, typeof Accept>
   //"typeof Accept" refers to what types you'll accept when you manually call generator.next(<AcceptType>)
-  async *stream_chat(
+  protected async *stream_chat(
     messages: ChatMessage[],
     parentEvent?: Event,
   ): AsyncGenerator<string, void, unknown> {
@@ -280,7 +287,7 @@ export class OpenAI implements LLM {
   }
 
   //Stream_complete doesn't need to be async because it's child function is already async
-  stream_complete(
+  protected stream_complete(
     query: string,
     parentEvent?: Event,
   ): AsyncGenerator<string, void, unknown> {
@@ -348,6 +355,7 @@ export class LlamaDeuce implements LLM {
   topP: number;
   maxTokens?: number;
   replicateSession: ReplicateSession;
+  hasStreaming: boolean;
 
   constructor(init?: Partial<LlamaDeuce>) {
     this.model = init?.model ?? "Llama-2-70b-chat-4bit";
@@ -362,6 +370,7 @@ export class LlamaDeuce implements LLM {
       init?.maxTokens ??
      ALL_AVAILABLE_LLAMADEUCE_MODELS[this.model].contextWindow; // For Replicate, the default is 500 tokens which is too low.
     this.replicateSession = init?.replicateSession ?? new ReplicateSession();
+    this.hasStreaming = init?.hasStreaming ?? false;
   }
 
   mapMessagesToPrompt(messages: ChatMessage[]) {
@@ -468,10 +477,11 @@ If a question does not make any sense, or is not factually coherent, explain why
     };
   }
 
-  async chat(
+  async chat<T extends boolean | undefined = undefined, R = T extends true ? AsyncGenerator<string, void, unknown> : ChatResponse>(
     messages: ChatMessage[],
     _parentEvent?: Event,
-  ): Promise<ChatResponse> {
+    streaming?: T,
+  ): Promise<R> {
     const api = ALL_AVAILABLE_LLAMADEUCE_MODELS[this.model]
       .replicateApi as `${string}/${string}:${string}`;
 
@@ -492,6 +502,9 @@ If a question does not make any sense, or is not factually coherent, explain why
       replicateOptions.input.max_length = this.maxTokens;
     }
 
+    //TODO: Add streaming for this
+
+    //Non-streaming
     const response = await this.replicateSession.replicate.run(
       api,
       replicateOptions,
@@ -502,13 +515,14 @@ If a question does not make any sense, or is not factually coherent, explain why
         //^ We need to do this because Replicate returns a list of strings (for streaming functionality which is not exposed by the run function)
         role: "assistant",
       },
-    };
+    } as R;
   }
 
-  async complete(
+  async complete<T extends boolean | undefined = undefined, R = T extends true ? AsyncGenerator<string, void, unknown> : ChatResponse>(
     prompt: string,
     parentEvent?: Event,
-  ): Promise<CompletionResponse> {
+    streaming?: T,
+  ): Promise<R> {
     return this.chat([{ content: prompt, role: "user" }], parentEvent);
   }
 }
@@ -517,6 +531,8 @@ If a question does not make any sense, or is not factually coherent, explain why
  * Anthropic LLM implementation
  */
+
+//TODO: Add streaming for this
 export class Anthropic implements LLM {
   // Per completion Anthropic params
   model: string;
@@ -529,6 +545,7 @@ export class Anthropic implements LLM {
   maxRetries: number;
   timeout?: number;
   session: AnthropicSession;
+  hasStreaming: boolean;
 
   callbackManager?: CallbackManager;
 
@@ -548,6 +565,7 @@ export class Anthropic implements LLM {
       maxRetries: this.maxRetries,
       timeout: this.timeout,
     });
+    this.hasStreaming = init?.hasStreaming ?? true;
 
     this.callbackManager = init?.callbackManager;
   }
@@ -567,10 +585,11 @@ export class Anthropic implements LLM {
     );
   }
 
-  async chat(
+  async chat<T extends boolean | undefined = undefined, R = T extends true ? AsyncGenerator<string, void, unknown> : ChatResponse>(
     messages: ChatMessage[],
     parentEvent?: Event | undefined,
-  ): Promise<ChatResponse> {
+    streaming?: T,
+  ): Promise<R> {
     const response = await this.session.anthropic.completions.create({
       model: this.model,
       prompt: this.mapMessagesToPrompt(messages),
@@ -583,12 +602,13 @@ export class Anthropic implements LLM {
       message: { content: response.completion.trimStart(), role: "assistant" },
       //^ We're trimming the start because Anthropic often starts with a space in the response
       // That space will be re-added when we generate the next prompt.
-    };
+    } as R;
   }
 
-  async complete(
+  async complete<T extends boolean | undefined = undefined, R = T extends true ? AsyncGenerator<string, void, unknown> : ChatResponse>(
     prompt: string,
     parentEvent?: Event | undefined,
-  ): Promise<CompletionResponse> {
-    return this.chat([{ content: prompt, role: "user" }], parentEvent);
+    streaming?: T,
+  ): Promise<R> {
+    return this.chat([{ content: prompt, role: "user" }], parentEvent) as R;
   }
 }
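Usage sketch (not part of the diff): the snippet below shows how a caller would pick the return type through the new streaming parameter on the unified chat() method. It assumes OpenAI and ChatMessage are exported from the package entry point ("llamaindex"), that OPENAI_API_KEY is set in the environment, and the model name and prompt are illustrative only.

// Hypothetical consumer of the new chat<T, R>(messages, parentEvent?, streaming?) signature.
import { OpenAI, type ChatMessage } from "llamaindex";

async function main() {
  const llm = new OpenAI({ model: "gpt-3.5-turbo" });
  const messages: ChatMessage[] = [{ content: "Tell me a joke.", role: "user" }];

  // streaming omitted: T defaults to undefined, so chat() resolves to a ChatResponse.
  const response = await llm.chat(messages);
  console.log(response.message.content);

  // streaming: true selects the AsyncGenerator<string, void, unknown> branch of the
  // conditional return type; parentEvent is passed explicitly as undefined to reach
  // the third parameter.
  const stream = await llm.chat(messages, undefined, true);
  for await (const chunk of stream) {
    process.stdout.write(chunk);
  }
}

main();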