From cdf1685f7c59da5cc002ae9a685d101e65684c62 Mon Sep 17 00:00:00 2001 From: Yi Ding <yi.s.ding@gmail.com> Date: Mon, 10 Jul 2023 07:24:59 -0700 Subject: [PATCH] new llm abstraction from Simon --- packages/core/src/ChatEngine.ts | 50 ++++++++-------- .../core/src/{LanguageModel.ts => LLM.ts} | 58 ++++++++++++------- packages/core/src/LLMPredictor.ts | 10 ++-- packages/core/src/Prompt.ts | 6 +- packages/core/src/ServiceContext.ts | 2 +- 5 files changed, 71 insertions(+), 55 deletions(-) rename packages/core/src/{LanguageModel.ts => LLM.ts} (51%) diff --git a/packages/core/src/ChatEngine.ts b/packages/core/src/ChatEngine.ts index b3aebadce..add9dd7f6 100644 --- a/packages/core/src/ChatEngine.ts +++ b/packages/core/src/ChatEngine.ts @@ -1,9 +1,4 @@ -import { - BaseChatModel, - BaseMessage, - ChatOpenAI, - LLMResult, -} from "./LanguageModel"; +import { BaseChatModel, ChatMessage, OpenAI, ChatResponse } from "./LLM"; import { TextNode } from "./Node"; import { SimplePrompt, @@ -19,29 +14,32 @@ import { ServiceContext, serviceContextFromDefaults } from "./ServiceContext"; interface ChatEngine { chatRepl(): void; - achat(message: string, chatHistory?: BaseMessage[]): Promise<Response>; + achat(message: string, chatHistory?: ChatMessage[]): Promise<Response>; reset(): void; } export class SimpleChatEngine implements ChatEngine { - chatHistory: BaseMessage[]; + chatHistory: ChatMessage[]; llm: BaseChatModel; constructor(init?: Partial<SimpleChatEngine>) { this.chatHistory = init?.chatHistory ?? []; - this.llm = init?.llm ?? new ChatOpenAI(); + this.llm = init?.llm ?? new OpenAI(); } chatRepl() { throw new Error("Method not implemented."); } - async achat(message: string, chatHistory?: BaseMessage[]): Promise<Response> { + async achat(message: string, chatHistory?: ChatMessage[]): Promise<Response> { chatHistory = chatHistory ?? this.chatHistory; - chatHistory.push({ content: message, type: "human" }); + chatHistory.push({ content: message, role: "user" }); const response = await this.llm.agenerate(chatHistory); - chatHistory.push({ content: response.generations[0][0].text, type: "ai" }); + chatHistory.push({ + content: response.generations[0][0].text, + role: "assistant", + }); this.chatHistory = chatHistory; return new Response(response.generations[0][0].text); } @@ -53,13 +51,13 @@ export class SimpleChatEngine implements ChatEngine { export class CondenseQuestionChatEngine implements ChatEngine { queryEngine: BaseQueryEngine; - chatHistory: BaseMessage[]; + chatHistory: ChatMessage[]; serviceContext: ServiceContext; condenseMessagePrompt: SimplePrompt; constructor(init: { queryEngine: BaseQueryEngine; - chatHistory: BaseMessage[]; + chatHistory: ChatMessage[]; serviceContext?: ServiceContext; condenseMessagePrompt?: SimplePrompt; }) { @@ -72,7 +70,7 @@ export class CondenseQuestionChatEngine implements ChatEngine { } private async acondenseQuestion( - chatHistory: BaseMessage[], + chatHistory: ChatMessage[], question: string ) { const chatHistoryStr = messagesToHistoryStr(chatHistory); @@ -88,7 +86,7 @@ export class CondenseQuestionChatEngine implements ChatEngine { async achat( message: string, - chatHistory?: BaseMessage[] | undefined + chatHistory?: ChatMessage[] | undefined ): Promise<Response> { chatHistory = chatHistory ?? this.chatHistory; @@ -99,8 +97,8 @@ export class CondenseQuestionChatEngine implements ChatEngine { const response = await this.queryEngine.aquery(condensedQuestion); - chatHistory.push({ content: message, type: "human" }); - chatHistory.push({ content: response.response, type: "ai" }); + chatHistory.push({ content: message, role: "user" }); + chatHistory.push({ content: response.response, role: "assistant" }); return response; } @@ -117,15 +115,15 @@ export class CondenseQuestionChatEngine implements ChatEngine { export class ContextChatEngine implements ChatEngine { retriever: BaseRetriever; chatModel: BaseChatModel; - chatHistory: BaseMessage[]; + chatHistory: ChatMessage[]; constructor(init: { retriever: BaseRetriever; chatModel?: BaseChatModel; - chatHistory?: BaseMessage[]; + chatHistory?: ChatMessage[]; }) { this.retriever = init.retriever; - this.chatModel = init.chatModel ?? new ChatOpenAI("gpt-3.5-turbo-16k"); + this.chatModel = init.chatModel ?? new OpenAI("gpt-3.5-turbo-16k"); this.chatHistory = init?.chatHistory ?? []; } @@ -133,21 +131,21 @@ export class ContextChatEngine implements ChatEngine { throw new Error("Method not implemented."); } - async achat(message: string, chatHistory?: BaseMessage[] | undefined) { + async achat(message: string, chatHistory?: ChatMessage[] | undefined) { chatHistory = chatHistory ?? this.chatHistory; const sourceNodesWithScore = await this.retriever.aretrieve(message); - const systemMessage: BaseMessage = { + const systemMessage: ChatMessage = { content: contextSystemPrompt({ context: sourceNodesWithScore .map((r) => (r.node as TextNode).text) .join("\n\n"), }), - type: "system", + role: "system", }; - chatHistory.push({ content: message, type: "human" }); + chatHistory.push({ content: message, role: "user" }); const response = await this.chatModel.agenerate([ systemMessage, @@ -155,7 +153,7 @@ export class ContextChatEngine implements ChatEngine { ]); const text = response.generations[0][0].text; - chatHistory.push({ content: text, type: "ai" }); + chatHistory.push({ content: text, role: "assistant" }); this.chatHistory = chatHistory; diff --git a/packages/core/src/LanguageModel.ts b/packages/core/src/LLM.ts similarity index 51% rename from packages/core/src/LanguageModel.ts rename to packages/core/src/LLM.ts index 6e8a82ee1..ac94edaab 100644 --- a/packages/core/src/LanguageModel.ts +++ b/packages/core/src/LLM.ts @@ -8,49 +8,65 @@ import { export interface BaseLanguageModel {} -type MessageType = "human" | "ai" | "system" | "generic" | "function"; +type MessageType = "user" | "assistant" | "system" | "generic" | "function"; -export interface BaseMessage { +export interface ChatMessage { content: string; - type: MessageType; + role: MessageType; } -interface Generation { - text: string; - generationInfo?: Record<string, any>; +export interface ChatResponse { + message: ChatMessage; + raw?: Record<string, any>; + delta?: string; } -export interface LLMResult { - generations: Generation[][]; // Each input can have more than one generations -} +// NOTE in case we need CompletionResponse to diverge from ChatResponse in the future +export type CompletionResponse = ChatResponse; -export interface BaseChatModel extends BaseLanguageModel { - agenerate(messages: BaseMessage[]): Promise<LLMResult>; +export interface LLM { + achat(messages: ChatMessage[]): Promise<ChatResponse>; + acomplete(prompt: string): Promise<CompletionResponse>; } -export class ChatOpenAI implements BaseChatModel { +const GPT4_MODELS = { + "gpt-4": 8192, + "gpt-4-32k": 32768, +}; + +const TURBO_MODELS = { + "gpt-3.5-turbo": 4096, + "gpt-3.5-turbo-16k": 16384, +}; + +const ALL_AVAILABLE_MODELS = { + ...GPT4_MODELS, + ...TURBO_MODELS, +}; + +export class OpenAI implements LLM { model: string; - temperature: number = 0.7; - openAIKey: string | null = null; + temperature: number = 0; requestTimeout: number | null = null; maxRetries: number = 6; n: number = 1; maxTokens?: number; - + openAIKey: string | null = null; session: OpenAISession; constructor(model: string = "gpt-3.5-turbo") { + // NOTE default model is different from Python this.model = model; - this.session = getOpenAISession(); + this.session = getOpenAISession(this.openAIKey); } static mapMessageType( type: MessageType ): ChatCompletionRequestMessageRoleEnum { switch (type) { - case "human": + case "user": return "user"; - case "ai": + case "assistant": return "assistant"; case "system": return "system"; @@ -61,14 +77,16 @@ export class ChatOpenAI implements BaseChatModel { } } - async agenerate(messages: BaseMessage[]): Promise<LLMResult> { + async achat(messages: ChatMessage[]): Promise<ChatResponse> {} + + async acomplete(messages: ChatMessage[]): Promise<ChatResponse> { const { data } = await this.session.openai.createChatCompletion({ model: this.model, temperature: this.temperature, max_tokens: this.maxTokens, n: this.n, messages: messages.map((message) => ({ - role: ChatOpenAI.mapMessageType(message.type), + role: OpenAI.mapMessageType(message.role), content: message.content, })), }); diff --git a/packages/core/src/LLMPredictor.ts b/packages/core/src/LLMPredictor.ts index da9112311..cd14fe51d 100644 --- a/packages/core/src/LLMPredictor.ts +++ b/packages/core/src/LLMPredictor.ts @@ -1,4 +1,4 @@ -import { ChatOpenAI } from "./LanguageModel"; +import { OpenAI } from "./LLM"; import { SimplePrompt } from "./Prompt"; // TODO change this to LLM class @@ -15,7 +15,7 @@ export interface BaseLLMPredictor { export class ChatGPTLLMPredictor implements BaseLLMPredictor { llm: string; retryOnThrottling: boolean; - languageModel: ChatOpenAI; + languageModel: OpenAI; constructor( llm: string = "gpt-3.5-turbo", @@ -24,7 +24,7 @@ export class ChatGPTLLMPredictor implements BaseLLMPredictor { this.llm = llm; this.retryOnThrottling = retryOnThrottling; - this.languageModel = new ChatOpenAI(this.llm); + this.languageModel = new OpenAI(this.llm); } async getLlmMetadata() { @@ -36,10 +36,10 @@ export class ChatGPTLLMPredictor implements BaseLLMPredictor { input?: Record<string, string> ): Promise<string> { if (typeof prompt === "string") { - const result = await this.languageModel.agenerate([ + const result = await this.languageModel.acomplete([ { content: prompt, - type: "human", + role: "user", }, ]); return result.generations[0][0].text; diff --git a/packages/core/src/Prompt.ts b/packages/core/src/Prompt.ts index 3cfb6fac4..6eb994281 100644 --- a/packages/core/src/Prompt.ts +++ b/packages/core/src/Prompt.ts @@ -1,4 +1,4 @@ -import { BaseMessage } from "./LanguageModel"; +import { ChatMessage } from "./LLM"; import { SubQuestion } from "./QuestionGenerator"; import { ToolMetadata } from "./Tool"; @@ -297,10 +297,10 @@ ${question} `; }; -export function messagesToHistoryStr(messages: BaseMessage[]) { +export function messagesToHistoryStr(messages: ChatMessage[]) { return messages.reduce((acc, message) => { acc += acc ? "\n" : ""; - if (message.type === "human") { + if (message.role === "user") { acc += `Human: ${message.content}`; } else { acc += `Assistant: ${message.content}`; diff --git a/packages/core/src/ServiceContext.ts b/packages/core/src/ServiceContext.ts index 9df16f9dc..4afa4fe2b 100644 --- a/packages/core/src/ServiceContext.ts +++ b/packages/core/src/ServiceContext.ts @@ -1,6 +1,6 @@ import { BaseEmbedding, OpenAIEmbedding } from "./Embedding"; import { BaseLLMPredictor, ChatGPTLLMPredictor } from "./LLMPredictor"; -import { BaseLanguageModel } from "./LanguageModel"; +import { BaseLanguageModel } from "./LLM"; import { NodeParser, SimpleNodeParser } from "./NodeParser"; import { PromptHelper } from "./PromptHelper"; -- GitLab