diff --git a/packages/core/src/ChatEngine.ts b/packages/core/src/ChatEngine.ts index ebd8a1525e6738c4a4675a9677aa9c3fffc788e6..00b73e6a8f02a5c2a059210d23b77f27e315fb35 100644 --- a/packages/core/src/ChatEngine.ts +++ b/packages/core/src/ChatEngine.ts @@ -1,4 +1,4 @@ -import { BaseChatModel, ChatMessage, OpenAI, ChatResponse } from "./LLM"; +import { ChatMessage, OpenAI, ChatResponse, LLM } from "./LLM"; import { TextNode } from "./Node"; import { SimplePrompt, @@ -23,7 +23,7 @@ interface ChatEngine { export class SimpleChatEngine implements ChatEngine { chatHistory: ChatMessage[]; - llm: BaseChatModel; + llm: LLM; constructor(init?: Partial<SimpleChatEngine>) { this.chatHistory = init?.chatHistory ?? []; @@ -37,13 +37,10 @@ export class SimpleChatEngine implements ChatEngine { async achat(message: string, chatHistory?: ChatMessage[]): Promise<Response> { chatHistory = chatHistory ?? this.chatHistory; chatHistory.push({ content: message, role: "user" }); - const response = await this.llm.agenerate(chatHistory); - chatHistory.push({ - content: response.generations[0][0].text, - role: "assistant", - }); + const response = await this.llm.achat(chatHistory); + chatHistory.push(response.message); this.chatHistory = chatHistory; - return new Response(response.generations[0][0].text); + return new Response(response.message.content); } reset() { @@ -116,16 +113,17 @@ export class CondenseQuestionChatEngine implements ChatEngine { export class ContextChatEngine implements ChatEngine { retriever: BaseRetriever; - chatModel: BaseChatModel; + chatModel: OpenAI; chatHistory: ChatMessage[]; constructor(init: { retriever: BaseRetriever; - chatModel?: BaseChatModel; + chatModel?: OpenAI; chatHistory?: ChatMessage[]; }) { this.retriever = init.retriever; - this.chatModel = init.chatModel ?? new OpenAI("gpt-3.5-turbo-16k"); + this.chatModel = + init.chatModel ?? new OpenAI({ model: "gpt-3.5-turbo-16k" }); this.chatHistory = init?.chatHistory ?? []; } @@ -157,18 +155,16 @@ export class ContextChatEngine implements ChatEngine { chatHistory.push({ content: message, role: "user" }); - const response = await this.chatModel.agenerate( + const response = await this.chatModel.achat( [systemMessage, ...chatHistory], parentEvent ); - const text = response.generations[0][0].text; - - chatHistory.push({ content: text, role: "assistant" }); + chatHistory.push(response.message); this.chatHistory = chatHistory; return new Response( - text, + response.message.content, sourceNodesWithScore.map((r) => r.node) ); } diff --git a/packages/core/src/LLM.ts b/packages/core/src/LLM.ts index 0dcbcef406bbc55e8257fa4d7e511a5f29590486..56d1d19354443f4094de2121d2bc5edbf508e4fa 100644 --- a/packages/core/src/LLM.ts +++ b/packages/core/src/LLM.ts @@ -1,3 +1,4 @@ +import { Key } from "readline"; import { CallbackManager, Event } from "./callbacks/CallbackManager"; import { aHandleOpenAIStream } from "./callbacks/utility/aHandleOpenAIStream"; import { @@ -7,8 +8,6 @@ import { getOpenAISession, } from "./openai"; -export interface BaseLanguageModel {} - type MessageType = "user" | "assistant" | "system" | "generic" | "function"; export interface ChatMessage { @@ -30,47 +29,45 @@ export interface LLM { acomplete(prompt: string): Promise<CompletionResponse>; } -const GPT4_MODELS = { +export const GPT4_MODELS = { "gpt-4": 8192, "gpt-4-32k": 32768, }; -const TURBO_MODELS = { +export const TURBO_MODELS = { "gpt-3.5-turbo": 4096, "gpt-3.5-turbo-16k": 16384, }; -const ALL_AVAILABLE_MODELS = { +export const ALL_AVAILABLE_MODELS = { ...GPT4_MODELS, ...TURBO_MODELS, }; export class OpenAI implements LLM { - model: string; - temperature: number = 0; - requestTimeout: number | null = null; - maxRetries: number = 6; + model: keyof typeof ALL_AVAILABLE_MODELS; + temperature: number; + requestTimeout: number | null; + maxRetries: number; n: number = 1; maxTokens?: number; openAIKey: string | null = null; session: OpenAISession; callbackManager?: CallbackManager; - constructor({ - model = "gpt-3.5-turbo", - callbackManager, - }: { - model: string; - callbackManager?: CallbackManager; - }) { - this.model = model; - this.callbackManager = callbackManager; - this.session = getOpenAISession(); + constructor(init?: Partial<OpenAI>) { + this.model = init?.model ?? "gpt-3.5-turbo"; + this.temperature = init?.temperature ?? 0; + this.requestTimeout = init?.requestTimeout ?? null; + this.maxRetries = init?.maxRetries ?? 10; + this.maxTokens = + init?.maxTokens ?? Math.floor(ALL_AVAILABLE_MODELS[this.model] / 2); + this.openAIKey = init?.openAIKey ?? null; + this.session = init?.session ?? getOpenAISession(); + this.callbackManager = init?.callbackManager; } - static mapMessageType( - type: MessageType - ): ChatCompletionRequestMessageRoleEnum { + mapMessageType(type: MessageType): ChatCompletionRequestMessageRoleEnum { switch (type) { case "user": return "user"; @@ -85,26 +82,23 @@ export class OpenAI implements LLM { } } - async achat(messages: ChatMessage[]): Promise<ChatResponse> {} - - async acomplete(messages: ChatMessage[]): Promise<ChatResponse> { - const { data } = await this.session.openai.createChatCompletion({ - async agenerate( - messages: BaseMessage[], + async achat( + messages: ChatMessage[], parentEvent?: Event - ): Promise<LLMResult> { + ): Promise<ChatResponse> { const baseRequestParams: CreateChatCompletionRequest = { model: this.model, temperature: this.temperature, max_tokens: this.maxTokens, n: this.n, messages: messages.map((message) => ({ - role: OpenAI.mapMessageType(message.role), + role: this.mapMessageType(message.role), content: message.content, })), }; if (this.callbackManager?.onLLMStream) { + // Streaming const response = await this.session.openai.createChatCompletion( { ...baseRequestParams, @@ -112,20 +106,28 @@ export class OpenAI implements LLM { }, { responseType: "stream" } ); + const fullResponse = await aHandleOpenAIStream({ response, onLLMStream: this.callbackManager.onLLMStream, parentEvent, }); - return { generations: [[{ text: fullResponse }]] }; - } + return { message: { content: fullResponse, role: "assistant" } }; + } else { + // Non-streaming + const response = await this.session.openai.createChatCompletion( + baseRequestParams + ); - const response = await this.session.openai.createChatCompletion( - baseRequestParams - ); + const content = response.data.choices[0].message?.content ?? ""; + return { message: { content, role: "assistant" } }; + } + } - const { data } = response; - const content = data.choices[0].message?.content ?? ""; - return { generations: [[{ text: content }]] }; + async acomplete( + prompt: string, + parentEvent?: Event + ): Promise<CompletionResponse> { + return this.achat([{ content: prompt, role: "user" }], parentEvent); } } diff --git a/packages/core/src/LLMPredictor.ts b/packages/core/src/LLMPredictor.ts index f0fe7b9af6567dc3f5aa4c9c2569976cb2bb5527..bb7d555737c1a6c3a5af6be878592d598cb75504 100644 --- a/packages/core/src/LLMPredictor.ts +++ b/packages/core/src/LLMPredictor.ts @@ -1,4 +1,4 @@ -import { OpenAI } from "./LLM"; +import { ALL_AVAILABLE_MODELS, OpenAI } from "./LLM"; import { SimplePrompt } from "./Prompt"; import { CallbackManager, Event } from "./callbacks/CallbackManager"; @@ -14,21 +14,12 @@ export interface BaseLLMPredictor { // TODO change this to LLM class export class ChatGPTLLMPredictor implements BaseLLMPredictor { - model: string; + model: keyof typeof ALL_AVAILABLE_MODELS; retryOnThrottling: boolean; languageModel: OpenAI; callbackManager?: CallbackManager; - constructor( - props: - | { - model?: string; - retryOnThrottling?: boolean; - callbackManager?: CallbackManager; - languageModel?: OpenAI; - } - | undefined = undefined - ) { + constructor(props?: Partial<ChatGPTLLMPredictor>) { const { model = "gpt-3.5-turbo", retryOnThrottling = true, @@ -57,14 +48,8 @@ export class ChatGPTLLMPredictor implements BaseLLMPredictor { parentEvent?: Event ): Promise<string> { if (typeof prompt === "string") { - const result = await this.languageModel.acomplete([ - { - content: prompt, - role: "user", - }, - parentEvent, - ]); - return result.generations[0][0].text; + const result = await this.languageModel.acomplete(prompt, parentEvent); + return result.message.content; } else { return this.apredict(prompt(input ?? {})); } diff --git a/packages/core/src/ServiceContext.ts b/packages/core/src/ServiceContext.ts index 69f6a8a4b982080a7d89d794d46583150ec4ef3a..931e5118cd3678b2f31049b93b6540e96651a928 100644 --- a/packages/core/src/ServiceContext.ts +++ b/packages/core/src/ServiceContext.ts @@ -1,6 +1,6 @@ import { BaseEmbedding, OpenAIEmbedding } from "./Embedding"; +import { OpenAI } from "./LLM"; import { BaseLLMPredictor, ChatGPTLLMPredictor } from "./LLMPredictor"; -import { BaseLanguageModel } from "./LLM"; import { NodeParser, SimpleNodeParser } from "./NodeParser"; import { PromptHelper } from "./PromptHelper"; import { CallbackManager } from "./callbacks/CallbackManager"; @@ -16,7 +16,7 @@ export interface ServiceContext { export interface ServiceContextOptions { llmPredictor?: BaseLLMPredictor; - llm?: BaseLanguageModel; + llm?: OpenAI; promptHelper?: PromptHelper; embedModel?: BaseEmbedding; nodeParser?: NodeParser; diff --git a/packages/core/src/tests/CallbackManager.test.ts b/packages/core/src/tests/CallbackManager.test.ts index 26bcbee2f7252baffdd218a14ddaf0db3bdede86..e66d3e6501c386899a586cff2bca1289747fe169 100644 --- a/packages/core/src/tests/CallbackManager.test.ts +++ b/packages/core/src/tests/CallbackManager.test.ts @@ -1,6 +1,6 @@ import { VectorStoreIndex } from "../BaseIndex"; import { OpenAIEmbedding } from "../Embedding"; -import { ChatOpenAI } from "../LanguageModel"; +import { OpenAI } from "../LLM"; import { Document } from "../Node"; import { ServiceContext, serviceContextFromDefaults } from "../ServiceContext"; import { @@ -35,7 +35,7 @@ describe("CallbackManager: onLLMStream and onRetrieve", () => { }, }); - const languageModel = new ChatOpenAI({ + const languageModel = new OpenAI({ model: "gpt-3.5-turbo", callbackManager, }); diff --git a/packages/core/src/tests/utility/mockOpenAI.ts b/packages/core/src/tests/utility/mockOpenAI.ts index 67631a9acded857a2b483dce1826484977d78aa7..21fd001c3f846ffc82df8207635fe3d88a219bb3 100644 --- a/packages/core/src/tests/utility/mockOpenAI.ts +++ b/packages/core/src/tests/utility/mockOpenAI.ts @@ -1,19 +1,19 @@ import { OpenAIEmbedding } from "../../Embedding"; import { globalsHelper } from "../../GlobalsHelper"; -import { BaseMessage, ChatOpenAI } from "../../LanguageModel"; +import { ChatMessage, OpenAI } from "../../LLM"; import { CallbackManager, Event } from "../../callbacks/CallbackManager"; export function mockLlmGeneration({ languageModel, callbackManager, }: { - languageModel: ChatOpenAI; + languageModel: OpenAI; callbackManager: CallbackManager; }) { jest - .spyOn(languageModel, "agenerate") + .spyOn(languageModel, "achat") .mockImplementation( - async (messages: BaseMessage[], parentEvent?: Event) => { + async (messages: ChatMessage[], parentEvent?: Event) => { const text = "MOCK_TOKEN_1-MOCK_TOKEN_2"; const event = globalsHelper.createEvent({ parentEvent, @@ -51,7 +51,10 @@ export function mockLlmGeneration({ } return new Promise((resolve) => { resolve({ - generations: [[{ text }]], + message: { + content: text, + role: "assistant", + }, }); }); }