diff --git a/apps/simple/llamadeuce.ts b/apps/simple/llamadeuce.ts new file mode 100644 index 0000000000000000000000000000000000000000..9edae4437449c727343bda314e6a09728e619bd2 --- /dev/null +++ b/apps/simple/llamadeuce.ts @@ -0,0 +1,7 @@ +import { LlamaDeuce } from "llamaindex/src/llm/LLM"; + +(async () => { + const deuce = new LlamaDeuce(); + const result = await deuce.chat([{ content: "Hello, world!", role: "user" }]); + console.log(result); +})(); diff --git a/apps/simple/openai.ts b/apps/simple/openai.ts index 3019deacf1f645ce2dee6014227aff995c78561f..12e91f67a8f4461aa5fb68b6185642ebdff6066d 100644 --- a/apps/simple/openai.ts +++ b/apps/simple/openai.ts @@ -1,6 +1,6 @@ // @ts-ignore import process from "node:process"; -import { Configuration, OpenAIWrapper } from "llamaindex/src/openai"; +import { Configuration, OpenAIWrapper } from "llamaindex/src/llm/openai"; (async () => { const configuration = new Configuration({ diff --git a/packages/core/package.json b/packages/core/package.json index a0f7f4ee69bde9b9d3393a3042122a488d5e2b94..36799d68fd75c6019ff5a254412d249575d4e2a6 100644 --- a/packages/core/package.json +++ b/packages/core/package.json @@ -7,6 +7,7 @@ "lodash": "^4.17.21", "openai": "^3.3.0", "pdf-parse": "^1.1.1", + "replicate": "^0.12.3", "tiktoken-node": "^0.0.6", "uuid": "^9.0.0", "wink-nlp": "^1.14.1" diff --git a/packages/core/src/ChatEngine.ts b/packages/core/src/ChatEngine.ts index 94b9362e1cdd8aa9aa708b7bfec65fdf9ea0faf6..cc37aa2e6c0f668bb717ab357678f15a3300830e 100644 --- a/packages/core/src/ChatEngine.ts +++ b/packages/core/src/ChatEngine.ts @@ -1,4 +1,4 @@ -import { ChatMessage, OpenAI, ChatResponse, LLM } from "./LLM"; +import { ChatMessage, OpenAI, ChatResponse, LLM } from "./llm/LLM"; import { TextNode } from "./Node"; import { SimplePrompt, diff --git a/packages/core/src/Embedding.ts b/packages/core/src/Embedding.ts index 8e3ded3a2f88390b69fc49a5f6757b38766be127..ec50ca4fa1efb268d252fd843fd32e0f256a62a3 100644 --- 
a/packages/core/src/Embedding.ts +++ b/packages/core/src/Embedding.ts @@ -1,5 +1,5 @@ import { DEFAULT_SIMILARITY_TOP_K } from "./constants"; -import { OpenAISession, getOpenAISession } from "./openai"; +import { OpenAISession, getOpenAISession } from "./llm/openai"; import { VectorStoreQueryMode } from "./storage/vectorStore/types"; /** diff --git a/packages/core/src/Prompt.ts b/packages/core/src/Prompt.ts index a82c4c517c02e2550f5d9a2f6c1cd6b3bb693298..0a1c6c0b02b6be8dd4bd51c35faa9a658114f4d6 100644 --- a/packages/core/src/Prompt.ts +++ b/packages/core/src/Prompt.ts @@ -1,4 +1,4 @@ -import { ChatMessage } from "./LLM"; +import { ChatMessage } from "./llm/LLM"; import { SubQuestion } from "./QuestionGenerator"; import { ToolMetadata } from "./Tool"; diff --git a/packages/core/src/QuestionGenerator.ts b/packages/core/src/QuestionGenerator.ts index 46bdb60ff0f307e80302707a5d0cc5ed8868c1b5..dd669c2236edc969fffc2336872f4036689a3e9b 100644 --- a/packages/core/src/QuestionGenerator.ts +++ b/packages/core/src/QuestionGenerator.ts @@ -1,4 +1,4 @@ -import { BaseLLMPredictor, ChatGPTLLMPredictor } from "./LLMPredictor"; +import { BaseLLMPredictor, ChatGPTLLMPredictor } from "./llm/LLMPredictor"; import { BaseOutputParser, StructuredOutput, diff --git a/packages/core/src/ResponseSynthesizer.ts b/packages/core/src/ResponseSynthesizer.ts index 1077310f9560d1c63d262ecc7a14366fb812ecdd..eeecbbd5958d60755ad28dd33c35a439d91a3d19 100644 --- a/packages/core/src/ResponseSynthesizer.ts +++ b/packages/core/src/ResponseSynthesizer.ts @@ -1,4 +1,4 @@ -import { ChatGPTLLMPredictor, BaseLLMPredictor } from "./LLMPredictor"; +import { ChatGPTLLMPredictor, BaseLLMPredictor } from "./llm/LLMPredictor"; import { MetadataMode, NodeWithScore } from "./Node"; import { SimplePrompt, diff --git a/packages/core/src/ServiceContext.ts b/packages/core/src/ServiceContext.ts index 01e13e9633217c51574a8dd109d5653e72a8e1f9..51423c917bfec3db9774c1b2670599c98299a8d6 100644 --- 
a/packages/core/src/ServiceContext.ts +++ b/packages/core/src/ServiceContext.ts @@ -1,6 +1,6 @@ import { BaseEmbedding, OpenAIEmbedding } from "./Embedding"; -import { OpenAI } from "./LLM"; -import { BaseLLMPredictor, ChatGPTLLMPredictor } from "./LLMPredictor"; +import { OpenAI } from "./llm/LLM"; +import { BaseLLMPredictor, ChatGPTLLMPredictor } from "./llm/LLMPredictor"; import { NodeParser, SimpleNodeParser } from "./NodeParser"; import { PromptHelper } from "./PromptHelper"; import { CallbackManager } from "./callbacks/CallbackManager"; diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index 9209dd14b1be5781cf3f4745ed5b6bed64a7d7e8..2218e2d5d221e6aedd3e52a9e82625eebb591e2f 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -2,8 +2,8 @@ export * from "./ChatEngine"; export * from "./constants"; export * from "./Embedding"; export * from "./GlobalsHelper"; -export * from "./LLM"; -export * from "./LLMPredictor"; +export * from "./llm/LLM"; +export * from "./llm/LLMPredictor"; export * from "./Node"; export * from "./NodeParser"; // export * from "./openai"; Don't export OpenAIWrapper diff --git a/packages/core/src/LLM.ts b/packages/core/src/llm/LLM.ts similarity index 52% rename from packages/core/src/LLM.ts rename to packages/core/src/llm/LLM.ts index 0913d64fc0bb894c1e32aae92f81488ad2b3d361..5a721093bc4093616c5d898b0840c6fea01d9f36 100644 --- a/packages/core/src/LLM.ts +++ b/packages/core/src/llm/LLM.ts @@ -1,12 +1,12 @@ -import { Key } from "readline"; -import { CallbackManager, Event } from "./callbacks/CallbackManager"; -import { aHandleOpenAIStream } from "./callbacks/utility/aHandleOpenAIStream"; +import { CallbackManager, Event } from "../callbacks/CallbackManager"; +import { aHandleOpenAIStream } from "../callbacks/utility/aHandleOpenAIStream"; import { ChatCompletionRequestMessageRoleEnum, CreateChatCompletionRequest, OpenAISession, getOpenAISession, } from "./openai"; +import { ReplicateSession } from 
"./replicate"; type MessageType = "user" | "assistant" | "system" | "generic" | "function"; @@ -42,19 +42,19 @@ export interface LLM { } export const GPT4_MODELS = { - "gpt-4": 8192, - "gpt-4-32k": 32768, + "gpt-4": { contextWindow: 8192 }, + "gpt-4-32k": { contextWindow: 32768 }, }; export const TURBO_MODELS = { - "gpt-3.5-turbo": 4096, - "gpt-3.5-turbo-16k": 16384, + "gpt-3.5-turbo": { contextWindow: 4097 }, + "gpt-3.5-turbo-16k": { contextWindow: 16384 }, }; /** * We currently support GPT-3.5 and GPT-4 models */ -export const ALL_AVAILABLE_MODELS = { +export const ALL_AVAILABLE_OPENAI_MODELS = { ...GPT4_MODELS, ...TURBO_MODELS, }; @@ -63,14 +63,13 @@ export const ALL_AVAILABLE_MODELS = { * OpenAI LLM implementation */ export class OpenAI implements LLM { - model: keyof typeof ALL_AVAILABLE_MODELS; + model: keyof typeof ALL_AVAILABLE_OPENAI_MODELS; temperature: number; - requestTimeout: number | null; - maxRetries: number; n: number = 1; maxTokens?: number; - openAIKey: string | null = null; session: OpenAISession; + maxRetries: number; + requestTimeout: number | null; callbackManager?: CallbackManager; constructor(init?: Partial<OpenAI>) { @@ -79,13 +78,14 @@ export class OpenAI implements LLM { this.requestTimeout = init?.requestTimeout ?? null; this.maxRetries = init?.maxRetries ?? 10; this.maxTokens = init?.maxTokens ?? undefined; - this.openAIKey = init?.openAIKey ?? null; this.session = init?.session ?? 
getOpenAISession(); this.callbackManager = init?.callbackManager; } - mapMessageType(type: MessageType): ChatCompletionRequestMessageRoleEnum { - switch (type) { + mapMessageType( + messageType: MessageType + ): ChatCompletionRequestMessageRoleEnum { + switch (messageType) { case "user": return "user"; case "assistant": @@ -148,3 +148,83 @@ export class OpenAI implements LLM { return this.chat([{ content: prompt, role: "user" }], parentEvent); } } + +export const ALL_AVAILABLE_LLAMADEUCE_MODELS = { + "Llama-2-70b-chat": { + contextWindow: 4000, + replicateApi: + "replicate/llama70b-v2-chat:e951f18578850b652510200860fc4ea62b3b16fac280f83ff32282f87bbd2e48", + }, + "Llama-2-13b-chat": { + contextWindow: 4000, + replicateApi: + "a16z-infra/llama13b-v2-chat:df7690f1994d94e96ad9d568eac121aecf50684a0b0963b25a41cc40061269e5", + }, + "Llama-2-7b-chat": { + contextWindow: 4000, + replicateApi: + "a16z-infra/llama7b-v2-chat:4f0a4744c7295c024a1de15e1a63c880d3da035fa1f49bfd344fe076074c8eea", + }, +}; + +/** + * Llama2 LLM implementation + */ +export class LlamaDeuce implements LLM { + model: keyof typeof ALL_AVAILABLE_LLAMADEUCE_MODELS; + temperature: number; + maxTokens?: number; + session: ReplicateSession; + + constructor(init?: Partial<LlamaDeuce>) { + this.model = init?.model ?? "Llama-2-70b-chat"; + this.temperature = init?.temperature ?? 0; + this.maxTokens = init?.maxTokens ?? undefined; + this.session = init?.session ?? 
new ReplicateSession(); + } + + mapMessageType(messageType: MessageType): string { + switch (messageType) { + case "user": + return "User: "; + case "assistant": + return "Assistant: "; + case "system": + return ""; + default: + throw new Error("Unsupported LlamaDeuce message type"); + } + } + + async chat( + messages: ChatMessage[], + _parentEvent?: Event + ): Promise<ChatResponse> { + const api = ALL_AVAILABLE_LLAMADEUCE_MODELS[this.model] + .replicateApi as `${string}/${string}:${string}`; + const response = await this.session.replicate.run(api, { + input: { + prompt: + messages.reduce((acc, message) => { + return ( + (acc && `${acc}\n\n`) + + `${this.mapMessageType(message.role)}${message.content}` + ); + }, "") + "\n\nAssistant:", // Here we're differing from A16Z by omitting the space. Generally spaces at the end of prompts decrease performance due to tokenization + }, + }); + return { + message: { + content: (response as Array<string>).join(""), + role: "assistant", + }, + }; + } + + async complete( + prompt: string, + parentEvent?: Event + ): Promise<CompletionResponse> { + return this.chat([{ content: prompt, role: "system" }], parentEvent); // Using system prompt here to avoid giving it a prefix + } +} diff --git a/packages/core/src/LLMPredictor.ts b/packages/core/src/llm/LLMPredictor.ts similarity index 87% rename from packages/core/src/LLMPredictor.ts rename to packages/core/src/llm/LLMPredictor.ts index 5bd95ec03c1beeb179de8189f461e9318194ecb8..8dcdeb83ed6228b4500ff9acecbe3e1d33d81a22 100644 --- a/packages/core/src/LLMPredictor.ts +++ b/packages/core/src/llm/LLMPredictor.ts @@ -1,6 +1,6 @@ -import { ALL_AVAILABLE_MODELS, OpenAI } from "./LLM"; -import { SimplePrompt } from "./Prompt"; -import { CallbackManager, Event } from "./callbacks/CallbackManager"; +import { ALL_AVAILABLE_OPENAI_MODELS, OpenAI } from "./LLM"; +import { SimplePrompt } from "../Prompt"; +import { CallbackManager, Event } from "../callbacks/CallbackManager"; /** * LLM Predictors are 
an abstraction to predict the response to a prompt. @@ -18,7 +18,7 @@ export interface BaseLLMPredictor { * ChatGPTLLMPredictor is a predictor that uses GPT. */ export class ChatGPTLLMPredictor implements BaseLLMPredictor { - model: keyof typeof ALL_AVAILABLE_MODELS; + model: keyof typeof ALL_AVAILABLE_OPENAI_MODELS; retryOnThrottling: boolean; languageModel: OpenAI; callbackManager?: CallbackManager; diff --git a/packages/core/src/fetchAdapter.d.ts b/packages/core/src/llm/fetchAdapter.d.ts similarity index 100% rename from packages/core/src/fetchAdapter.d.ts rename to packages/core/src/llm/fetchAdapter.d.ts diff --git a/packages/core/src/fetchAdapter.js b/packages/core/src/llm/fetchAdapter.js similarity index 100% rename from packages/core/src/fetchAdapter.js rename to packages/core/src/llm/fetchAdapter.js diff --git a/packages/core/src/openai.ts b/packages/core/src/llm/openai.ts similarity index 100% rename from packages/core/src/openai.ts rename to packages/core/src/llm/openai.ts diff --git a/packages/core/src/llm/replicate.ts b/packages/core/src/llm/replicate.ts new file mode 100644 index 0000000000000000000000000000000000000000..18bb4b6c4bf988cd96f7be05cf43a9c977940463 --- /dev/null +++ b/packages/core/src/llm/replicate.ts @@ -0,0 +1,32 @@ +import Replicate from "replicate"; + +export class ReplicateSession { + replicateKey: string | null = null; + replicate: Replicate; + + constructor(replicateKey: string | null = null) { + if (replicateKey) { + this.replicateKey = replicateKey; + } else if (process.env.REPLICATE_API_TOKEN) { + this.replicateKey = process.env.REPLICATE_API_TOKEN; + } else { + throw new Error( + "Set Replicate token in REPLICATE_API_TOKEN env variable" + ); + } + + this.replicate = new Replicate({ auth: this.replicateKey }); + } +} + +let defaultReplicateSession: ReplicateSession | null = null; + +export function getReplicateSession(replicateKey: string | null = null) { + if (!defaultReplicateSession) { + defaultReplicateSession = new 
ReplicateSession(replicateKey); + } + + return defaultReplicateSession; +} + +export * from "openai"; diff --git a/packages/core/src/tests/CallbackManager.test.ts b/packages/core/src/tests/CallbackManager.test.ts index b0dad0b54472c6e2ab4e5799514e8daf7e7109bf..6471067bfc21472e26666ae125b3a4d9c9581ad8 100644 --- a/packages/core/src/tests/CallbackManager.test.ts +++ b/packages/core/src/tests/CallbackManager.test.ts @@ -1,6 +1,6 @@ import { VectorStoreIndex } from "../indices/vectorStore/VectorStoreIndex"; import { OpenAIEmbedding } from "../Embedding"; -import { OpenAI } from "../LLM"; +import { OpenAI } from "../llm/LLM"; import { Document } from "../Node"; import { ServiceContext, serviceContextFromDefaults } from "../ServiceContext"; import { @@ -16,7 +16,7 @@ import { import { mockEmbeddingModel, mockLlmGeneration } from "./utility/mockOpenAI"; // Mock the OpenAI getOpenAISession function during testing -jest.mock("../openai", () => { +jest.mock("../llm/openai", () => { return { getOpenAISession: jest.fn().mockImplementation(() => null), }; diff --git a/packages/core/src/tests/utility/mockOpenAI.ts b/packages/core/src/tests/utility/mockOpenAI.ts index c80617862d3a80d1cc6cb6c1a28f2b69500ded3e..6198caab700c84f3bab58742a225c2cc3f3d9f6c 100644 --- a/packages/core/src/tests/utility/mockOpenAI.ts +++ b/packages/core/src/tests/utility/mockOpenAI.ts @@ -1,6 +1,6 @@ import { OpenAIEmbedding } from "../../Embedding"; import { globalsHelper } from "../../GlobalsHelper"; -import { ChatMessage, OpenAI } from "../../LLM"; +import { ChatMessage, OpenAI } from "../../llm/LLM"; import { CallbackManager, Event } from "../../callbacks/CallbackManager"; export function mockLlmGeneration({ diff --git a/packages/eslint-config-custom/index.js b/packages/eslint-config-custom/index.js index c2eed7e73b65222eb8cd23b8d74935666899498e..b38436a656616122e88db03ada93393b5fe10fd7 100644 --- a/packages/eslint-config-custom/index.js +++ b/packages/eslint-config-custom/index.js @@ -5,7 +5,7 @@ 
module.exports = { "turbo/no-undeclared-env-vars": [ "error", { - allowList: ["OPENAI_API_KEY"], + allowList: ["OPENAI_API_KEY", "REPLICATE_API_TOKEN"], }, ], }, diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 2008a25163a5d2e82fd2f6891e0277361d38dfa3..54ae38e317797a3f0153d062f728be1e6ef3261a 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -121,6 +121,9 @@ importers: pdf-parse: specifier: ^1.1.1 version: 1.1.1 + replicate: + specifier: ^0.12.3 + version: 0.12.3 tiktoken-node: specifier: ^0.0.6 version: 0.0.6 @@ -11041,6 +11044,11 @@ packages: engines: {node: '>=0.10'} dev: false + /replicate@0.12.3: + resolution: {integrity: sha512-HVWKPoVhWVTONlWk+lUXmq9Vy2J8MxBJMtDBQq3dA5uq71ZzKTh0xvJfvzW4+VLBjhBeL7tkdua6hZJmKfzAPQ==} + engines: {git: '>=2.11.0', node: '>=16.6.0', npm: '>=7.19.0', yarn: '>=1.7.0'} + dev: false + /require-directory@2.1.1: resolution: {integrity: sha512-fGxEI7+wsG9xrvdjsrlmL22OMTTiHRwAMroiEeMgq8gzoLC/PQr7RsRDSTLUg/bZAZtF+TVIkHc6/4RIKrui+Q==} engines: {node: '>=0.10.0'}