From cdf1685f7c59da5cc002ae9a685d101e65684c62 Mon Sep 17 00:00:00 2001
From: Yi Ding <yi.s.ding@gmail.com>
Date: Mon, 10 Jul 2023 07:24:59 -0700
Subject: [PATCH] new llm abstraction from Simon

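Renames packages/core/src/LanguageModel.ts to LLM.ts and replaces the
BaseChatModel/BaseMessage/LLMResult types with a simpler abstraction:
an LLM interface exposing achat() and acomplete(), a ChatMessage type
that uses OpenAI-style roles ("user"/"assistant"/"system") instead of
"human"/"ai" message types, and a ChatResponse type that carries the
reply as a single message. The OpenAI class implements the new
interface, gains a table of known models and their context window
sizes, and its default temperature drops from 0.7 to 0.

A minimal sketch of the new surface (illustrative only; assumes
OPENAI_API_KEY is set in the environment and the call runs inside an
async function):

    import { OpenAI, ChatMessage } from "./LLM";

    const llm = new OpenAI("gpt-3.5-turbo");
    const history: ChatMessage[] = [
      { content: "What does this patch change?", role: "user" },
    ];
    const response = await llm.achat(history);
    console.log(response.message.content);
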
---
 packages/core/src/ChatEngine.ts               | 64 +++++++++++--------
 .../core/src/{LanguageModel.ts => LLM.ts}     | 65 ++++++++++++------
 packages/core/src/LLMPredictor.ts             | 12 ++++--
 packages/core/src/Prompt.ts                   |  6 +-
 packages/core/src/ServiceContext.ts           |  2 +-
 5 files changed, 86 insertions(+), 63 deletions(-)
 rename packages/core/src/{LanguageModel.ts => LLM.ts} (51%)

diff --git a/packages/core/src/ChatEngine.ts b/packages/core/src/ChatEngine.ts
index b3aebadce..add9dd7f6 100644
--- a/packages/core/src/ChatEngine.ts
+++ b/packages/core/src/ChatEngine.ts
@@ -1,9 +1,4 @@
-import {
-  BaseChatModel,
-  BaseMessage,
-  ChatOpenAI,
-  LLMResult,
-} from "./LanguageModel";
+import { LLM, ChatMessage, OpenAI, ChatResponse } from "./LLM";
 import { TextNode } from "./Node";
 import {
   SimplePrompt,
@@ -19,29 +14,32 @@ import { ServiceContext, serviceContextFromDefaults } from "./ServiceContext";
 interface ChatEngine {
   chatRepl(): void;
 
-  achat(message: string, chatHistory?: BaseMessage[]): Promise<Response>;
+  achat(message: string, chatHistory?: ChatMessage[]): Promise<Response>;
 
   reset(): void;
 }
 
 export class SimpleChatEngine implements ChatEngine {
-  chatHistory: BaseMessage[];
+  chatHistory: ChatMessage[];
-  llm: BaseChatModel;
+  llm: LLM;
 
   constructor(init?: Partial<SimpleChatEngine>) {
     this.chatHistory = init?.chatHistory ?? [];
-    this.llm = init?.llm ?? new ChatOpenAI();
+    this.llm = init?.llm ?? new OpenAI();
   }
 
   chatRepl() {
     throw new Error("Method not implemented.");
   }
 
-  async achat(message: string, chatHistory?: BaseMessage[]): Promise<Response> {
+  async achat(message: string, chatHistory?: ChatMessage[]): Promise<Response> {
     chatHistory = chatHistory ?? this.chatHistory;
-    chatHistory.push({ content: message, type: "human" });
+    chatHistory.push({ content: message, role: "user" });
-    const response = await this.llm.agenerate(chatHistory);
+    const response = await this.llm.achat(chatHistory);
-    chatHistory.push({ content: response.generations[0][0].text, type: "ai" });
+    chatHistory.push({
+      content: response.message.content,
+      role: "assistant",
+    });
     this.chatHistory = chatHistory;
-    return new Response(response.generations[0][0].text);
+    return new Response(response.message.content);
   }
@@ -53,13 +51,13 @@ export class SimpleChatEngine implements ChatEngine {
 
 export class CondenseQuestionChatEngine implements ChatEngine {
   queryEngine: BaseQueryEngine;
-  chatHistory: BaseMessage[];
+  chatHistory: ChatMessage[];
   serviceContext: ServiceContext;
   condenseMessagePrompt: SimplePrompt;
 
   constructor(init: {
     queryEngine: BaseQueryEngine;
-    chatHistory: BaseMessage[];
+    chatHistory: ChatMessage[];
     serviceContext?: ServiceContext;
     condenseMessagePrompt?: SimplePrompt;
   }) {
@@ -72,7 +70,7 @@ export class CondenseQuestionChatEngine implements ChatEngine {
   }
 
   private async acondenseQuestion(
-    chatHistory: BaseMessage[],
+    chatHistory: ChatMessage[],
     question: string
   ) {
     const chatHistoryStr = messagesToHistoryStr(chatHistory);
@@ -88,7 +86,7 @@ export class CondenseQuestionChatEngine implements ChatEngine {
 
   async achat(
     message: string,
-    chatHistory?: BaseMessage[] | undefined
+    chatHistory?: ChatMessage[] | undefined
   ): Promise<Response> {
     chatHistory = chatHistory ?? this.chatHistory;
 
@@ -99,8 +97,8 @@ export class CondenseQuestionChatEngine implements ChatEngine {
 
     const response = await this.queryEngine.aquery(condensedQuestion);
 
-    chatHistory.push({ content: message, type: "human" });
-    chatHistory.push({ content: response.response, type: "ai" });
+    chatHistory.push({ content: message, role: "user" });
+    chatHistory.push({ content: response.response, role: "assistant" });
 
     return response;
   }
@@ -117,15 +115,15 @@ export class CondenseQuestionChatEngine implements ChatEngine {
 export class ContextChatEngine implements ChatEngine {
   retriever: BaseRetriever;
-  chatModel: BaseChatModel;
+  chatModel: LLM;
-  chatHistory: BaseMessage[];
+  chatHistory: ChatMessage[];
 
   constructor(init: {
     retriever: BaseRetriever;
-    chatModel?: BaseChatModel;
+    chatModel?: LLM;
-    chatHistory?: BaseMessage[];
+    chatHistory?: ChatMessage[];
   }) {
     this.retriever = init.retriever;
-    this.chatModel = init.chatModel ?? new ChatOpenAI("gpt-3.5-turbo-16k");
+    this.chatModel = init.chatModel ?? new OpenAI("gpt-3.5-turbo-16k");
     this.chatHistory = init?.chatHistory ?? [];
   }
 
@@ -133,21 +131,21 @@ export class ContextChatEngine implements ChatEngine {
     throw new Error("Method not implemented.");
   }
 
-  async achat(message: string, chatHistory?: BaseMessage[] | undefined) {
+  async achat(message: string, chatHistory?: ChatMessage[] | undefined) {
     chatHistory = chatHistory ?? this.chatHistory;
 
     const sourceNodesWithScore = await this.retriever.aretrieve(message);
 
-    const systemMessage: BaseMessage = {
+    const systemMessage: ChatMessage = {
       content: contextSystemPrompt({
         context: sourceNodesWithScore
           .map((r) => (r.node as TextNode).text)
           .join("\n\n"),
       }),
-      type: "system",
+      role: "system",
     };
 
-    chatHistory.push({ content: message, type: "human" });
+    chatHistory.push({ content: message, role: "user" });
 
-    const response = await this.chatModel.agenerate([
+    const response = await this.chatModel.achat([
       systemMessage,
@@ -155,7 +153,7 @@ export class ContextChatEngine implements ChatEngine {
     ]);
-    const text = response.generations[0][0].text;
+    const text = response.message.content;
 
-    chatHistory.push({ content: text, type: "ai" });
+    chatHistory.push({ content: text, role: "assistant" });
 
     this.chatHistory = chatHistory;
 
diff --git a/packages/core/src/LanguageModel.ts b/packages/core/src/LLM.ts
similarity index 51%
rename from packages/core/src/LanguageModel.ts
rename to packages/core/src/LLM.ts
index 6e8a82ee1..ac94edaab 100644
--- a/packages/core/src/LanguageModel.ts
+++ b/packages/core/src/LLM.ts
@@ -8,49 +8,68 @@ import {
 
 export interface BaseLanguageModel {}
 
-type MessageType = "human" | "ai" | "system" | "generic" | "function";
+type MessageType = "user" | "assistant" | "system" | "generic" | "function";
 
-export interface BaseMessage {
+export interface ChatMessage {
   content: string;
-  type: MessageType;
+  role: MessageType;
 }
 
-interface Generation {
-  text: string;
-  generationInfo?: Record<string, any>;
+export interface ChatResponse {
+  message: ChatMessage;
+  raw?: Record<string, any>;
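+  // NOTE delta is intended for streamed responses (the incremental text).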
+  delta?: string;
 }
 
-export interface LLMResult {
-  generations: Generation[][]; // Each input can have more than one generations
-}
+// NOTE in case we need CompletionResponse to diverge from ChatResponse in the future
+export type CompletionResponse = ChatResponse;
 
-export interface BaseChatModel extends BaseLanguageModel {
-  agenerate(messages: BaseMessage[]): Promise<LLMResult>;
+export interface LLM {
+  achat(messages: ChatMessage[]): Promise<ChatResponse>;
+  acomplete(prompt: string): Promise<CompletionResponse>;
 }
 
-export class ChatOpenAI implements BaseChatModel {
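+// Known chat models mapped to their maximum context window size, in tokens.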
+const GPT4_MODELS = {
+  "gpt-4": 8192,
+  "gpt-4-32k": 32768,
+};
+
+const TURBO_MODELS = {
+  "gpt-3.5-turbo": 4096,
+  "gpt-3.5-turbo-16k": 16384,
+};
+
+const ALL_AVAILABLE_MODELS = {
+  ...GPT4_MODELS,
+  ...TURBO_MODELS,
+};
+
+export class OpenAI implements LLM {
   model: string;
-  temperature: number = 0.7;
-  openAIKey: string | null = null;
+  temperature: number = 0;
   requestTimeout: number | null = null;
   maxRetries: number = 6;
   n: number = 1;
   maxTokens?: number;
-
+  openAIKey: string | null = null;
   session: OpenAISession;
 
   constructor(model: string = "gpt-3.5-turbo") {
+    // NOTE default model is different from Python
     this.model = model;
-    this.session = getOpenAISession();
+    this.session = getOpenAISession(this.openAIKey);
   }
 
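+  // Maps our MessageType to the OpenAI SDK's ChatCompletionRequestMessageRoleEnum.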
   static mapMessageType(
     type: MessageType
   ): ChatCompletionRequestMessageRoleEnum {
     switch (type) {
-      case "human":
+      case "user":
         return "user";
-      case "ai":
+      case "assistant":
         return "assistant";
       case "system":
         return "system";
@@ -61,14 +77,20 @@ export class ChatOpenAI implements BaseChatModel {
     }
   }
 
-  async agenerate(messages: BaseMessage[]): Promise<LLMResult> {
+  // NOTE acomplete wraps achat: a completion-style prompt is sent as a
+  // single user message.
+  async acomplete(prompt: string): Promise<CompletionResponse> {
+    return this.achat([{ content: prompt, role: "user" }]);
+  }
+
+  async achat(messages: ChatMessage[]): Promise<ChatResponse> {
     const { data } = await this.session.openai.createChatCompletion({
       model: this.model,
       temperature: this.temperature,
       max_tokens: this.maxTokens,
       n: this.n,
       messages: messages.map((message) => ({
-        role: ChatOpenAI.mapMessageType(message.type),
+        role: OpenAI.mapMessageType(message.role),
         content: message.content,
       })),
     });
diff --git a/packages/core/src/LLMPredictor.ts b/packages/core/src/LLMPredictor.ts
index da9112311..cd14fe51d 100644
--- a/packages/core/src/LLMPredictor.ts
+++ b/packages/core/src/LLMPredictor.ts
@@ -1,4 +1,4 @@
-import { ChatOpenAI } from "./LanguageModel";
+import { OpenAI } from "./LLM";
 import { SimplePrompt } from "./Prompt";
 
 // TODO change this to LLM class
@@ -15,7 +15,7 @@ export interface BaseLLMPredictor {
 export class ChatGPTLLMPredictor implements BaseLLMPredictor {
   llm: string;
   retryOnThrottling: boolean;
-  languageModel: ChatOpenAI;
+  languageModel: OpenAI;
 
   constructor(
     llm: string = "gpt-3.5-turbo",
@@ -24,7 +24,7 @@ export class ChatGPTLLMPredictor implements BaseLLMPredictor {
     this.llm = llm;
     this.retryOnThrottling = retryOnThrottling;
 
-    this.languageModel = new ChatOpenAI(this.llm);
+    this.languageModel = new OpenAI(this.llm);
   }
 
   async getLlmMetadata() {
@@ -36,10 +36,10 @@ export class ChatGPTLLMPredictor implements BaseLLMPredictor {
     input?: Record<string, string>
   ): Promise<string> {
     if (typeof prompt === "string") {
-      const result = await this.languageModel.agenerate([
+      const result = await this.languageModel.achat([
         {
           content: prompt,
-          type: "human",
+          role: "user",
         },
       ]);
-      return result.generations[0][0].text;
+      return result.message.content;
diff --git a/packages/core/src/Prompt.ts b/packages/core/src/Prompt.ts
index 3cfb6fac4..6eb994281 100644
--- a/packages/core/src/Prompt.ts
+++ b/packages/core/src/Prompt.ts
@@ -1,4 +1,4 @@
-import { BaseMessage } from "./LanguageModel";
+import { ChatMessage } from "./LLM";
 import { SubQuestion } from "./QuestionGenerator";
 import { ToolMetadata } from "./Tool";
 
@@ -297,10 +297,10 @@ ${question}
 `;
 };
 
-export function messagesToHistoryStr(messages: BaseMessage[]) {
+export function messagesToHistoryStr(messages: ChatMessage[]) {
   return messages.reduce((acc, message) => {
     acc += acc ? "\n" : "";
-    if (message.type === "human") {
+    if (message.role === "user") {
       acc += `Human: ${message.content}`;
     } else {
       acc += `Assistant: ${message.content}`;
diff --git a/packages/core/src/ServiceContext.ts b/packages/core/src/ServiceContext.ts
index 9df16f9dc..4afa4fe2b 100644
--- a/packages/core/src/ServiceContext.ts
+++ b/packages/core/src/ServiceContext.ts
@@ -1,6 +1,6 @@
 import { BaseEmbedding, OpenAIEmbedding } from "./Embedding";
 import { BaseLLMPredictor, ChatGPTLLMPredictor } from "./LLMPredictor";
-import { BaseLanguageModel } from "./LanguageModel";
+import { BaseLanguageModel } from "./LLM";
 import { NodeParser, SimpleNodeParser } from "./NodeParser";
 import { PromptHelper } from "./PromptHelper";
 
-- 
GitLab