From 29c6b62ba19190464f4eb88d729a1bcb9542f494 Mon Sep 17 00:00:00 2001
From: Elliot Kang <kkang2097@gmail.com>
Date: Thu, 28 Sep 2023 16:11:13 -0700
Subject: [PATCH] Updated LLM interface

- chat() and complete() now infer their return type from an optional
  "streaming" flag, replacing the separate stream_chat()/stream_complete()
  interface methods
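
Usage sketch for the new call pattern (illustrative: `llm` is any LLM
implementation, e.g. `new OpenAI()`, and `messages` is a ChatMessage[]):

    // Non-streaming: resolves to a ChatResponse
    const response = await llm.chat(messages);
    console.log(response.message.content);

    // Streaming: resolves to an AsyncGenerator<string, void, unknown>
    const stream = await llm.chat(messages, undefined, true);
    for await (const delta of stream) {
      process.stdout.write(delta);
    }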
---
 packages/core/src/llm/LLM.ts | 82 ++++++++++++++++++++++--------------
 1 file changed, 51 insertions(+), 31 deletions(-)

diff --git a/packages/core/src/llm/LLM.ts b/packages/core/src/llm/LLM.ts
index 54bfbae11..6cf9c2f62 100644
--- a/packages/core/src/llm/LLM.ts
+++ b/packages/core/src/llm/LLM.ts
@@ -48,27 +48,23 @@ export type CompletionResponse = ChatResponse;
  * Unified language model interface
  */
 export interface LLM {
+  // Whether the LLM supports streaming responses
+  hasStreaming: boolean;
   /**
    * Get a chat response from the LLM
    * @param messages
+   *
+   * The return type of chat() and complete() is determined by the optional
+   * "streaming" parameter: pass `true` to receive an AsyncGenerator of
+   * response deltas; otherwise a ChatResponse/CompletionResponse is returned.
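+   *
+   * @example
+   * // A usage sketch; `llm` is any LLM implementation and `messages` a ChatMessage[].
+   * const response = await llm.chat(messages); // ChatResponse
+   * const stream = await llm.chat(messages, undefined, true); // AsyncGenerator<string, void, unknown>
+   * for await (const delta of stream) console.log(delta);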
    */
-  chat(messages: ChatMessage[], parentEvent?: Event): Promise<ChatResponse>;
+  chat<
+    T extends boolean | undefined = undefined,
+    R = T extends true ? AsyncGenerator<string, void, unknown> : ChatResponse,
+  >(messages: ChatMessage[], parentEvent?: Event, streaming?: T): Promise<R>;
 
   /**
    * Get a prompt completion from the LLM
    * @param prompt the prompt to complete
    */
-  complete(prompt: string, parentEvent?: Event): Promise<CompletionResponse>;
-
-  stream_chat?(
-    messages: ChatMessage[],
-    parentEvent?: Event,
-  ): AsyncGenerator<string, void, unknown>;
-
-  stream_complete?(
-    query: string,
-    parentEvent?: Event,
-  ): AsyncGenerator<string, void, unknown>;
+  complete<
+    T extends boolean | undefined = undefined,
+    R = T extends true ? AsyncGenerator<string, void, unknown> : CompletionResponse,
+  >(prompt: string, parentEvent?: Event, streaming?: T): Promise<R>;
 }
 
 export const GPT4_MODELS = {
@@ -102,6 +98,7 @@ export class OpenAI implements LLM {
     Partial<OpenAILLM.Chat.CompletionCreateParams>,
     "max_tokens" | "messages" | "model" | "temperature" | "top_p" | "streaming"
   >;
+  hasStreaming: boolean;
 
   // OpenAI session params
   apiKey?: string = undefined;
@@ -129,6 +126,7 @@ export class OpenAI implements LLM {
     this.timeout = init?.timeout ?? 60 * 1000; // Default is 60 seconds
     this.additionalChatOptions = init?.additionalChatOptions;
     this.additionalSessionOptions = init?.additionalSessionOptions;
+    this.hasStreaming = init?.hasStreaming ?? true;
 
     if (init?.azure || shouldUseAzure()) {
       const azureConfig = getAzureConfigFromEnv({
@@ -186,10 +184,11 @@ export class OpenAI implements LLM {
     }
   }
 
-  async chat(
+  async chat<
+    T extends boolean | undefined = undefined,
+    R = T extends true ? AsyncGenerator<string, void, unknown> : ChatResponse,
+  >(
     messages: ChatMessage[],
     parentEvent?: Event,
-  ): Promise<ChatResponse> {
+    streaming?: T,
+  ): Promise<R> {
     const baseRequestParams: OpenAILLM.Chat.CompletionCreateParams = {
       model: this.model,
       temperature: this.temperature,
@@ -201,6 +200,13 @@ export class OpenAI implements LLM {
       top_p: this.topP,
       ...this.additionalChatOptions,
     };
+    // Streaming
+    if (streaming) {
+      if (!this.hasStreaming) {
+        throw new Error("No streaming support for this LLM.");
+      }
+      return this.stream_chat(messages, parentEvent) as R;
+    }
     // Non-streaming
     const response = await this.session.openai.chat.completions.create({
       ...baseRequestParams,
@@ -208,20 +214,21 @@ export class OpenAI implements LLM {
     });
 
     const content = response.choices[0].message?.content ?? "";
-    return { message: { content, role: response.choices[0].message.role } };
+    return { message: { content, role: response.choices[0].message.role } } as R;
   }
 
-  async complete(
+  async complete<
+    T extends boolean | undefined = undefined,
+    R = T extends true ? AsyncGenerator<string, void, unknown> : CompletionResponse,
+  >(
     prompt: string,
     parentEvent?: Event,
-  ): Promise<CompletionResponse> {
-    return this.chat([{ content: prompt, role: "user" }], parentEvent);
+    streaming?: T,
+  ): Promise<R> {
+    return this.chat([{ content: prompt, role: "user" }], parentEvent, streaming);
   }
 
   //We can wrap a stream in a generator to add some additional logging behavior
   //For future edits: syntax for generator type is <typeof Yield, typeof Return, typeof Accept>
   //"typeof Accept" refers to what types you'll accept when you manually call generator.next(<AcceptType>)
-  async *stream_chat(
+  protected async *stream_chat(
     messages: ChatMessage[],
     parentEvent?: Event,
   ): AsyncGenerator<string, void, unknown> {
@@ -280,7 +287,7 @@ export class OpenAI implements LLM {
   }
 
   //stream_complete doesn't need to be async because its child function is already async
-  stream_complete(
+  protected stream_complete(
     query: string,
     parentEvent?: Event,
   ): AsyncGenerator<string, void, unknown> {
@@ -348,6 +355,7 @@ export class LlamaDeuce implements LLM {
   topP: number;
   maxTokens?: number;
   replicateSession: ReplicateSession;
+  hasStreaming: boolean;
 
   constructor(init?: Partial<LlamaDeuce>) {
     this.model = init?.model ?? "Llama-2-70b-chat-4bit";
@@ -362,6 +370,7 @@ export class LlamaDeuce implements LLM {
       init?.maxTokens ??
       ALL_AVAILABLE_LLAMADEUCE_MODELS[this.model].contextWindow; // For Replicate, the default is 500 tokens which is too low.
     this.replicateSession = init?.replicateSession ?? new ReplicateSession();
+    this.hasStreaming = init?.hasStreaming ?? false;
   }
 
   mapMessagesToPrompt(messages: ChatMessage[]) {
@@ -468,10 +477,11 @@ If a question does not make any sense, or is not factually coherent, explain why
     };
   }
 
-  async chat(
+  async chat<
+    T extends boolean | undefined = undefined,
+    R = T extends true ? AsyncGenerator<string, void, unknown> : ChatResponse,
+  >(
     messages: ChatMessage[],
     _parentEvent?: Event,
-  ): Promise<ChatResponse> {
+    streaming?: T,
+  ): Promise<R> {
     const api = ALL_AVAILABLE_LLAMADEUCE_MODELS[this.model]
       .replicateApi as `${string}/${string}:${string}`;
 
@@ -492,6 +502,9 @@ If a question does not make any sense, or is not factually coherent, explain why
       replicateOptions.input.max_length = this.maxTokens;
     }
 
+    // TODO: Add streaming for this
+
+    // Non-streaming
     const response = await this.replicateSession.replicate.run(
       api,
       replicateOptions,
@@ -502,13 +515,14 @@ If a question does not make any sense, or is not factually coherent, explain why
         //^ We need to do this because Replicate returns a list of strings (for streaming functionality which is not exposed by the run function)
         role: "assistant",
       },
-    };
+    } as R;
   }
 
-  async complete(
+  async complete<
+    T extends boolean | undefined = undefined,
+    R = T extends true ? AsyncGenerator<string, void, unknown> : CompletionResponse,
+  >(
     prompt: string,
     parentEvent?: Event,
-  ): Promise<CompletionResponse> {
+    streaming?: T,
+  ): Promise<R> {
     return this.chat([{ content: prompt, role: "user" }], parentEvent);
   }
 }
@@ -517,6 +531,8 @@ If a question does not make any sense, or is not factually coherent, explain why
  * Anthropic LLM implementation
  */
 
+// TODO: Add streaming support for this LLM
 export class Anthropic implements LLM {
   // Per completion Anthropic params
   model: string;
@@ -529,6 +545,7 @@ export class Anthropic implements LLM {
   maxRetries: number;
   timeout?: number;
   session: AnthropicSession;
+  hasStreaming: boolean;
 
   callbackManager?: CallbackManager;
 
@@ -548,6 +565,7 @@ export class Anthropic implements LLM {
         maxRetries: this.maxRetries,
         timeout: this.timeout,
       });
+    this.hasStreaming = init?.hasStreaming ?? true;
 
     this.callbackManager = init?.callbackManager;
   }
@@ -567,10 +585,11 @@ export class Anthropic implements LLM {
     );
   }
 
-  async chat(
+  async chat<
+    T extends boolean | undefined = undefined,
+    R = T extends true ? AsyncGenerator<string, void, unknown> : ChatResponse,
+  >(
     messages: ChatMessage[],
     parentEvent?: Event | undefined,
-  ): Promise<ChatResponse> {
+    streaming?: T,
+  ): Promise<R> {
     const response = await this.session.anthropic.completions.create({
       model: this.model,
       prompt: this.mapMessagesToPrompt(messages),
@@ -583,12 +602,13 @@ export class Anthropic implements LLM {
       message: { content: response.completion.trimStart(), role: "assistant" },
       //^ We're trimming the start because Anthropic often starts with a space in the response
       // That space will be re-added when we generate the next prompt.
-    };
+    } as R;
   }
-  async complete(
+  async complete<
+    T extends boolean | undefined = undefined,
+    R = T extends true ? AsyncGenerator<string, void, unknown> : CompletionResponse,
+  >(
     prompt: string,
     parentEvent?: Event | undefined,
-  ): Promise<CompletionResponse> {
-    return this.chat([{ content: prompt, role: "user" }], parentEvent);
+    streaming?: T,
+  ): Promise<R> {
+    return this.chat([{ content: prompt, role: "user" }], parentEvent) as R;
   }
 }
-- 
GitLab