diff --git a/.changeset/lovely-papayas-exist.md b/.changeset/lovely-papayas-exist.md new file mode 100644 index 0000000000000000000000000000000000000000..ce01fa9bd1b6f8835b6b22b3412aee0cfdfe61a0 --- /dev/null +++ b/.changeset/lovely-papayas-exist.md @@ -0,0 +1,5 @@ +--- +"create-llama": patch +--- + +Support Anthropic and Gemini as model providers diff --git a/.changeset/ninety-clocks-ring.md b/.changeset/ninety-clocks-ring.md new file mode 100644 index 0000000000000000000000000000000000000000..cac6cb886f0a41fabc50fffd5bd5f6a2cd1840a3 --- /dev/null +++ b/.changeset/ninety-clocks-ring.md @@ -0,0 +1,5 @@ +--- +"create-llama": patch +--- + +Support new agents from LITS 0.3 diff --git a/helpers/env-variables.ts b/helpers/env-variables.ts index 5451d435880aa5d5bf8863d56271461f98c6d253..5a91da6b582fd925c4a15dea8db1057f3cbf4616 100644 --- a/helpers/env-variables.ts +++ b/helpers/env-variables.ts @@ -173,6 +173,24 @@ const getModelEnvs = (modelConfig: ModelConfig): EnvVar[] => { }, ] : []), + ...(modelConfig.provider === "anthropic" + ? [ + { + name: "ANTHROPIC_API_KEY", + description: "The Anthropic API key to use.", + value: modelConfig.apiKey, + }, + ] + : []), + ...(modelConfig.provider === "gemini" + ? [ + { + name: "GOOGLE_API_KEY", + description: "The Google API key to use.", + value: modelConfig.apiKey, + }, + ] + : []), ]; }; diff --git a/helpers/index.ts b/helpers/index.ts index b7b3991df77413b4755c58cdc349cfae469792e5..571182f186a73396faccf1dc59dbcb1cffa730cb 100644 --- a/helpers/index.ts +++ b/helpers/index.ts @@ -9,7 +9,6 @@ import { createBackendEnvFile, createFrontendEnvFile } from "./env-variables"; import { PackageManager } from "./get-pkg-manager"; import { installLlamapackProject } from "./llama-pack"; import { isHavingPoetryLockFile, tryPoetryRun } from "./poetry"; -import { isModelConfigured } from "./providers"; import { installPythonTemplate } from "./python"; import { downloadAndExtractRepo } from "./repo"; import { ConfigFileType, writeToolsConfig } from "./tools"; @@ -38,7 +37,7 @@ async function generateContextData( ? "poetry run generate" : `${packageManager} run generate`, )}`; - const modelConfigured = isModelConfigured(modelConfig); + const modelConfigured = modelConfig.isConfigured(); const llamaCloudKeyConfigured = useLlamaParse ? llamaCloudKey || process.env["LLAMA_CLOUD_API_KEY"] : true; diff --git a/helpers/providers/anthropic.ts b/helpers/providers/anthropic.ts new file mode 100644 index 0000000000000000000000000000000000000000..1239a0c4954975a077b136949e8f4c4371202fc2 --- /dev/null +++ b/helpers/providers/anthropic.ts @@ -0,0 +1,106 @@ +import ciInfo from "ci-info"; +import prompts from "prompts"; +import { ModelConfigParams } from "."; +import { questionHandlers, toChoice } from "../../questions"; + +const MODELS = [ + "claude-3-opus", + "claude-3-sonnet", + "claude-3-haiku", + "claude-2.1", + "claude-instant-1.2", +]; +const DEFAULT_MODEL = MODELS[0]; + +// TODO: get embedding vector dimensions from the anthropic sdk (currently not supported) +// Use huggingface embedding models for now +enum HuggingFaceEmbeddingModelType { + XENOVA_ALL_MINILM_L6_V2 = "all-MiniLM-L6-v2", + XENOVA_ALL_MPNET_BASE_V2 = "all-mpnet-base-v2", +} +type ModelData = { + dimensions: number; +}; +const EMBEDDING_MODELS: Record<HuggingFaceEmbeddingModelType, ModelData> = { + [HuggingFaceEmbeddingModelType.XENOVA_ALL_MINILM_L6_V2]: { + dimensions: 384, + }, + [HuggingFaceEmbeddingModelType.XENOVA_ALL_MPNET_BASE_V2]: { + dimensions: 768, + }, +}; +const DEFAULT_EMBEDDING_MODEL = Object.keys(EMBEDDING_MODELS)[0]; +const DEFAULT_DIMENSIONS = Object.values(EMBEDDING_MODELS)[0].dimensions; + +type AnthropicQuestionsParams = { + apiKey?: string; + askModels: boolean; +}; + +export async function askAnthropicQuestions({ + askModels, + apiKey, +}: AnthropicQuestionsParams): Promise<ModelConfigParams> { + const config: ModelConfigParams = { + apiKey, + model: DEFAULT_MODEL, + embeddingModel: DEFAULT_EMBEDDING_MODEL, + dimensions: DEFAULT_DIMENSIONS, + isConfigured(): boolean { + if (config.apiKey) { + return true; + } + if (process.env["ANTHROPIC_API_KEY"]) { + return true; + } + return false; + }, + }; + + if (!config.apiKey) { + const { key } = await prompts( + { + type: "text", + name: "key", + message: + "Please provide your Anthropic API key (or leave blank to use ANTHROPIC_API_KEY env variable):", + }, + questionHandlers, + ); + config.apiKey = key || process.env.ANTHROPIC_API_KEY; + } + + // use default model values in CI or if user should not be asked + const useDefaults = ciInfo.isCI || !askModels; + if (!useDefaults) { + const { model } = await prompts( + { + type: "select", + name: "model", + message: "Which LLM model would you like to use?", + choices: MODELS.map(toChoice), + initial: 0, + }, + questionHandlers, + ); + config.model = model; + + const { embeddingModel } = await prompts( + { + type: "select", + name: "embeddingModel", + message: "Which embedding model would you like to use?", + choices: Object.keys(EMBEDDING_MODELS).map(toChoice), + initial: 0, + }, + questionHandlers, + ); + config.embeddingModel = embeddingModel; + config.dimensions = + EMBEDDING_MODELS[ + embeddingModel as HuggingFaceEmbeddingModelType + ].dimensions; + } + + return config; +} diff --git a/helpers/providers/gemini.ts b/helpers/providers/gemini.ts new file mode 100644 index 0000000000000000000000000000000000000000..b0f6733f69d296b767af3f496a7686fb15e8a4de --- /dev/null +++ b/helpers/providers/gemini.ts @@ -0,0 +1,87 @@ +import ciInfo from "ci-info"; +import prompts from "prompts"; +import { ModelConfigParams } from "."; +import { questionHandlers, toChoice } from "../../questions"; + +const MODELS = ["gemini-1.5-pro-latest", "gemini-pro", "gemini-pro-vision"]; +type ModelData = { + dimensions: number; +}; +const EMBEDDING_MODELS: Record<string, ModelData> = { + "embedding-001": { dimensions: 768 }, + "text-embedding-004": { dimensions: 768 }, +}; + +const DEFAULT_MODEL = MODELS[0]; +const DEFAULT_EMBEDDING_MODEL = Object.keys(EMBEDDING_MODELS)[0]; +const DEFAULT_DIMENSIONS = Object.values(EMBEDDING_MODELS)[0].dimensions; + +type GeminiQuestionsParams = { + apiKey?: string; + askModels: boolean; +}; + +export async function askGeminiQuestions({ + askModels, + apiKey, +}: GeminiQuestionsParams): Promise<ModelConfigParams> { + const config: ModelConfigParams = { + apiKey, + model: DEFAULT_MODEL, + embeddingModel: DEFAULT_EMBEDDING_MODEL, + dimensions: DEFAULT_DIMENSIONS, + isConfigured(): boolean { + if (config.apiKey) { + return true; + } + if (process.env["GOOGLE_API_KEY"]) { + return true; + } + return false; + }, + }; + + if (!config.apiKey) { + const { key } = await prompts( + { + type: "text", + name: "key", + message: + "Please provide your Google API key (or leave blank to use GOOGLE_API_KEY env variable):", + }, + questionHandlers, + ); + config.apiKey = key || process.env.GOOGLE_API_KEY; + } + + // use default model values in CI or if user should not be asked + const useDefaults = ciInfo.isCI || !askModels; + if (!useDefaults) { + const { model } = await prompts( + { + type: "select", + name: "model", + message: "Which LLM model would you like to use?", + choices: MODELS.map(toChoice), + initial: 0, + }, + questionHandlers, + ); + config.model = model; + + const { embeddingModel } = await prompts( + { + type: "select", + name: "embeddingModel", + message: "Which embedding model would you like to use?", + choices: Object.keys(EMBEDDING_MODELS).map(toChoice), + initial: 0, + }, + questionHandlers, + ); + config.embeddingModel = embeddingModel; + config.dimensions = EMBEDDING_MODELS[embeddingModel].dimensions; + } + + return config; +} diff --git a/helpers/providers/index.ts b/helpers/providers/index.ts index bb793c090675ebcda172e3ebed1c19ff35a70c11..62202c0f6f9f4aed96de17668f9760ddc260cf15 100644 --- a/helpers/providers/index.ts +++ b/helpers/providers/index.ts @@ -2,8 +2,10 @@ import ciInfo from "ci-info"; import prompts from "prompts"; import { questionHandlers } from "../../questions"; import { ModelConfig, ModelProvider } from "../types"; +import { askAnthropicQuestions } from "./anthropic"; +import { askGeminiQuestions } from "./gemini"; import { askOllamaQuestions } from "./ollama"; -import { askOpenAIQuestions, isOpenAIConfigured } from "./openai"; +import { askOpenAIQuestions } from "./openai"; const DEFAULT_MODEL_PROVIDER = "openai"; @@ -31,6 +33,8 @@ export async function askModelConfig({ value: "openai", }, { title: "Ollama", value: "ollama" }, + { title: "Anthropic", value: "anthropic" }, + { title: "Gemini", value: "gemini" }, ], initial: 0, }, @@ -44,6 +48,12 @@ export async function askModelConfig({ case "ollama": modelConfig = await askOllamaQuestions({ askModels }); break; + case "anthropic": + modelConfig = await askAnthropicQuestions({ askModels }); + break; + case "gemini": + modelConfig = await askGeminiQuestions({ askModels }); + break; default: modelConfig = await askOpenAIQuestions({ openAiKey, @@ -55,12 +65,3 @@ export async function askModelConfig({ provider: modelProvider, }; } - -export function isModelConfigured(modelConfig: ModelConfig): boolean { - switch (modelConfig.provider) { - case "openai": - return isOpenAIConfigured(modelConfig); - default: - return true; - } -} diff --git a/helpers/providers/ollama.ts b/helpers/providers/ollama.ts index a32d9d309034a60587f54791856c7998745d53a6..e70b25f06336bb45911e4878d1328cd00dfc08eb 100644 --- a/helpers/providers/ollama.ts +++ b/helpers/providers/ollama.ts @@ -29,6 +29,9 @@ export async function askOllamaQuestions({ model: DEFAULT_MODEL, embeddingModel: DEFAULT_EMBEDDING_MODEL, dimensions: EMBEDDING_MODELS[DEFAULT_EMBEDDING_MODEL].dimensions, + isConfigured(): boolean { + return true; + }, }; // use default model values in CI or if user should not be asked diff --git a/helpers/providers/openai.ts b/helpers/providers/openai.ts index 2e13d99f0b27e4d4b8d9ff55058a3ef2532a18ba..c02fee36cdbb0afa5e90ad26d7bdf28894f45e6b 100644 --- a/helpers/providers/openai.ts +++ b/helpers/providers/openai.ts @@ -20,6 +20,15 @@ export async function askOpenAIQuestions({ model: DEFAULT_MODEL, embeddingModel: DEFAULT_EMBEDDING_MODEL, dimensions: getDimensions(DEFAULT_EMBEDDING_MODEL), + isConfigured(): boolean { + if (config.apiKey) { + return true; + } + if (process.env["OPENAI_API_KEY"]) { + return true; + } + return false; + }, }; if (!config.apiKey) { @@ -31,7 +40,6 @@ export async function askOpenAIQuestions({ ? "Please provide your OpenAI API key (or leave blank to use OPENAI_API_KEY env variable):" : "Please provide your OpenAI API key (leave blank to skip):", validate: (value: string) => { - console.log(value); if (askModels && !value) { if (process.env.OPENAI_API_KEY) { return true; @@ -78,16 +86,6 @@ export async function askOpenAIQuestions({ return config; } -export function isOpenAIConfigured(params: ModelConfigParams): boolean { - if (params.apiKey) { - return true; - } - if (process.env["OPENAI_API_KEY"]) { - return true; - } - return false; -} - async function getAvailableModelChoices( selectEmbedding: boolean, apiKey?: string, diff --git a/helpers/python.ts b/helpers/python.ts index 316fe52d51b3b821f20e2992b7d80a2e869f529c..8c2d7bce6b6a9494df9bffc832d56bbd038826e9 100644 --- a/helpers/python.ts +++ b/helpers/python.ts @@ -127,6 +127,26 @@ const getAdditionalDependencies = ( version: "0.2.2", }); break; + case "anthropic": + dependencies.push({ + name: "llama-index-llms-anthropic", + version: "0.1.10", + }); + dependencies.push({ + name: "llama-index-embeddings-huggingface", + version: "0.2.0", + }); + break; + case "gemini": + dependencies.push({ + name: "llama-index-llms-gemini", + version: "0.1.7", + }); + dependencies.push({ + name: "llama-index-embeddings-gemini", + version: "0.1.6", + }); + break; } return dependencies; diff --git a/helpers/types.ts b/helpers/types.ts index 42b571e7f16f587dd6fdf7348d24bd7277f04b5b..b70f586aac35ac7ac2206cbd2a0a593927b60b6c 100644 --- a/helpers/types.ts +++ b/helpers/types.ts @@ -1,13 +1,14 @@ import { PackageManager } from "../helpers/get-pkg-manager"; import { Tool } from "./tools"; -export type ModelProvider = "openai" | "ollama"; +export type ModelProvider = "openai" | "ollama" | "anthropic" | "gemini"; export type ModelConfig = { provider: ModelProvider; apiKey?: string; model: string; embeddingModel: string; dimensions: number; + isConfigured(): boolean; }; export type TemplateType = "streaming" | "community" | "llamapack"; export type TemplateFramework = "nextjs" | "express" | "fastapi"; diff --git a/questions.ts b/questions.ts index 31eed224772df4ccb6180ba80db824a5523d1e4f..3e0460847242bcf08a9d3066612eb69c1ec92d02 100644 --- a/questions.ts +++ b/questions.ts @@ -14,7 +14,7 @@ import { COMMUNITY_OWNER, COMMUNITY_REPO } from "./helpers/constant"; import { EXAMPLE_FILE } from "./helpers/datasources"; import { templatesDir } from "./helpers/dir"; import { getAvailableLlamapackOptions } from "./helpers/llama-pack"; -import { askModelConfig, isModelConfigured } from "./helpers/providers"; +import { askModelConfig } from "./helpers/providers"; import { getProjectOptions } from "./helpers/repo"; import { supportedTools, toolsRequireConfig } from "./helpers/tools"; @@ -257,7 +257,8 @@ export const askQuestions = async ( }, ]; - const modelConfigured = isModelConfigured(program.modelConfig); + const modelConfigured = + !program.llamapack && program.modelConfig.isConfigured(); // If using LlamaParse, require LlamaCloud API key const llamaCloudKeyConfigured = program.useLlamaParse ? program.llamaCloudKey || process.env["LLAMA_CLOUD_API_KEY"] @@ -268,8 +269,7 @@ export const askQuestions = async ( !hasVectorDb && modelConfigured && llamaCloudKeyConfigured && - !toolsRequireConfig(program.tools) && - !program.llamapack + !toolsRequireConfig(program.tools) ) { actionChoices.push({ title: @@ -398,7 +398,6 @@ export const askQuestions = async ( if (program.framework === "express" || program.framework === "fastapi") { // if a backend-only framework is selected, ask whether we should create a frontend - // (only for streaming backends) if (program.frontend === undefined) { if (ciInfo.isCI) { program.frontend = getPrefOrDefault("frontend"); diff --git a/templates/components/engines/typescript/agent/chat.ts b/templates/components/engines/typescript/agent/chat.ts index 8fdbe272c7616f09b9bb054ed2367a91cf716b58..41d7118d915ee331138ea663ecd2f98d6a590931 100644 --- a/templates/components/engines/typescript/agent/chat.ts +++ b/templates/components/engines/typescript/agent/chat.ts @@ -1,4 +1,4 @@ -import { BaseTool, OpenAIAgent, QueryEngineTool } from "llamaindex"; +import { BaseToolWithCall, OpenAIAgent, QueryEngineTool } from "llamaindex"; import { ToolsFactory } from "llamaindex/tools/ToolsFactory"; import fs from "node:fs/promises"; import path from "node:path"; @@ -6,7 +6,7 @@ import { getDataSource } from "./index"; import { STORAGE_CACHE_DIR } from "./shared"; export async function createChatEngine() { - let tools: BaseTool[] = []; + let tools: BaseToolWithCall[] = []; // Add a query engine tool if we have a data source // Delete this code if you don't have a data source diff --git a/templates/components/vectordbs/typescript/none/generate.ts b/templates/components/vectordbs/typescript/none/generate.ts index 732ba21129c8aab1dc7de78df2e03df54092a791..8c162805b132f056106f52e20b567af1a7985aa5 100644 --- a/templates/components/vectordbs/typescript/none/generate.ts +++ b/templates/components/vectordbs/typescript/none/generate.ts @@ -1,4 +1,5 @@ -import { VectorStoreIndex, storageContextFromDefaults } from "llamaindex"; +import { VectorStoreIndex } from "llamaindex"; +import { storageContextFromDefaults } from "llamaindex/storage/StorageContext"; import * as dotenv from "dotenv"; diff --git a/templates/components/vectordbs/typescript/none/index.ts b/templates/components/vectordbs/typescript/none/index.ts index 919ba3a615b454dbcc9230f00dcff57e69fd33ec..64b289750cd978e88b4b17d8194bb468026e8910 100644 --- a/templates/components/vectordbs/typescript/none/index.ts +++ b/templates/components/vectordbs/typescript/none/index.ts @@ -1,8 +1,5 @@ -import { - SimpleDocumentStore, - storageContextFromDefaults, - VectorStoreIndex, -} from "llamaindex"; +import { SimpleDocumentStore, VectorStoreIndex } from "llamaindex"; +import { storageContextFromDefaults } from "llamaindex/storage/StorageContext"; import { STORAGE_CACHE_DIR } from "./shared"; export async function getDataSource() { diff --git a/templates/types/streaming/express/package.json b/templates/types/streaming/express/package.json index 612076e5415ec558a8af7f1c05135d1265060da3..d59f18721f8a7e5cef5ff8fdaa1dfc1bab43f77e 100644 --- a/templates/types/streaming/express/package.json +++ b/templates/types/streaming/express/package.json @@ -14,7 +14,8 @@ "cors": "^2.8.5", "dotenv": "^16.3.1", "express": "^4.18.2", - "llamaindex": "0.2.10" + "llamaindex": "0.3.3", + "pdf2json": "3.0.5" }, "devDependencies": { "@types/cors": "^2.8.16", diff --git a/templates/types/streaming/express/src/controllers/chat.controller.ts b/templates/types/streaming/express/src/controllers/chat.controller.ts index bb77c6e5c46733bf4936c13e3e24f7b83aa3ba55..d9ba57dcab0ffcee1ed80a33ca9ecaa95e2c50c9 100644 --- a/templates/types/streaming/express/src/controllers/chat.controller.ts +++ b/templates/types/streaming/express/src/controllers/chat.controller.ts @@ -1,6 +1,11 @@ import { Message, StreamData, streamToResponse } from "ai"; import { Request, Response } from "express"; -import { ChatMessage, MessageContent, Settings } from "llamaindex"; +import { + CallbackManager, + ChatMessage, + MessageContent, + Settings, +} from "llamaindex"; import { createChatEngine } from "./engine/chat"; import { LlamaIndexStream } from "./llamaindex-stream"; import { appendEventData } from "./stream-helper"; @@ -45,14 +50,15 @@ export const chat = async (req: Request, res: Response) => { // Init Vercel AI StreamData const vercelStreamData = new StreamData(); - appendEventData( - vercelStreamData, - `Retrieving context for query: '${userMessage.content}'`, - ); - // Setup callback for streaming data before chatting - Settings.callbackManager.on("retrieve", (data) => { + // Setup callbacks + const callbackManager = new CallbackManager(); + callbackManager.on("retrieve", (data) => { const { nodes } = data.detail; + appendEventData( + vercelStreamData, + `Retrieving context for query: '${userMessage.content}'`, + ); appendEventData( vercelStreamData, `Retrieved ${nodes.length} sources to use as context for the query`, @@ -60,31 +66,23 @@ export const chat = async (req: Request, res: Response) => { }); // Calling LlamaIndex's ChatEngine to get a streamed response - const response = await chatEngine.chat({ - message: userMessageContent, - chatHistory: messages as ChatMessage[], - stream: true, + const response = await Settings.withCallbackManager(callbackManager, () => { + return chatEngine.chat({ + message: userMessageContent, + chatHistory: messages as ChatMessage[], + stream: true, + }); }); // Return a stream, which can be consumed by the Vercel/AI client - const { stream } = LlamaIndexStream(response, vercelStreamData, { + const stream = LlamaIndexStream(response, vercelStreamData, { parserOptions: { image_url: data?.imageUrl, }, }); - - // Pipe LlamaIndexStream to response const processedStream = stream.pipeThrough(vercelStreamData.stream); - return streamToResponse(processedStream, res, { - headers: { - // response MUST have the `X-Experimental-Stream-Data: 'true'` header - // so that the client uses the correct parsing logic, see - // https://sdk.vercel.ai/docs/api-reference/stream-data#on-the-server - "X-Experimental-Stream-Data": "true", - "Content-Type": "text/plain; charset=utf-8", - "Access-Control-Expose-Headers": "X-Experimental-Stream-Data", - }, - }); + + return streamToResponse(processedStream, res); } catch (error) { console.error("[LlamaIndex]", error); return res.status(500).json({ diff --git a/templates/types/streaming/express/src/controllers/engine/settings.ts b/templates/types/streaming/express/src/controllers/engine/settings.ts index 258043993aa14f618c62a4a196d00645047ada63..6e980afbb6137f65d7e3c6cf6b2720a11897024c 100644 --- a/templates/types/streaming/express/src/controllers/engine/settings.ts +++ b/templates/types/streaming/express/src/controllers/engine/settings.ts @@ -1,10 +1,17 @@ import { - Ollama, - OllamaEmbedding, + Anthropic, + GEMINI_EMBEDDING_MODEL, + GEMINI_MODEL, + Gemini, + GeminiEmbedding, OpenAI, OpenAIEmbedding, Settings, } from "llamaindex"; +import { HuggingFaceEmbedding } from "llamaindex/embeddings/HuggingFaceEmbedding"; +import { OllamaEmbedding } from "llamaindex/embeddings/OllamaEmbedding"; +import { ALL_AVAILABLE_ANTHROPIC_MODELS } from "llamaindex/llm/anthropic"; +import { Ollama } from "llamaindex/llm/ollama"; const CHUNK_SIZE = 512; const CHUNK_OVERLAP = 20; @@ -12,10 +19,21 @@ const CHUNK_OVERLAP = 20; export const initSettings = async () => { // HINT: you can delete the initialization code for unused model providers console.log(`Using '${process.env.MODEL_PROVIDER}' model provider`); + + if (!process.env.MODEL || !process.env.EMBEDDING_MODEL) { + throw new Error("'MODEL' and 'EMBEDDING_MODEL' env variables must be set."); + } + switch (process.env.MODEL_PROVIDER) { case "ollama": initOllama(); break; + case "anthropic": + initAnthropic(); + break; + case "gemini": + initGemini(); + break; default: initOpenAI(); break; @@ -38,11 +56,6 @@ function initOpenAI() { } function initOllama() { - if (!process.env.MODEL || !process.env.EMBEDDING_MODEL) { - throw new Error( - "Using Ollama as model provider, 'MODEL' and 'EMBEDDING_MODEL' env variables must be set.", - ); - } Settings.llm = new Ollama({ model: process.env.MODEL ?? "", }); @@ -50,3 +63,25 @@ function initOllama() { model: process.env.EMBEDDING_MODEL ?? "", }); } + +function initAnthropic() { + const embedModelMap: Record<string, string> = { + "all-MiniLM-L6-v2": "Xenova/all-MiniLM-L6-v2", + "all-mpnet-base-v2": "Xenova/all-mpnet-base-v2", + }; + Settings.llm = new Anthropic({ + model: process.env.MODEL as keyof typeof ALL_AVAILABLE_ANTHROPIC_MODELS, + }); + Settings.embedModel = new HuggingFaceEmbedding({ + modelType: embedModelMap[process.env.EMBEDDING_MODEL!], + }); +} + +function initGemini() { + Settings.llm = new Gemini({ + model: process.env.MODEL as GEMINI_MODEL, + }); + Settings.embedModel = new GeminiEmbedding({ + model: process.env.EMBEDDING_MODEL as GEMINI_EMBEDDING_MODEL, + }); +} diff --git a/templates/types/streaming/express/src/controllers/llamaindex-stream.ts b/templates/types/streaming/express/src/controllers/llamaindex-stream.ts index b689905727767fb1f17f099298d13e72f0234df7..df2c17ffc34744d044ecdcc375955dbc7ea5ba70 100644 --- a/templates/types/streaming/express/src/controllers/llamaindex-stream.ts +++ b/templates/types/streaming/express/src/controllers/llamaindex-stream.ts @@ -9,16 +9,22 @@ import { Metadata, NodeWithScore, Response, - StreamingAgentChatResponse, + ToolCallLLMMessageOptions, } from "llamaindex"; + +import { AgentStreamChatResponse } from "llamaindex/agent/base"; import { appendImageData, appendSourceData } from "./stream-helper"; +type LlamaIndexResponse = + | AgentStreamChatResponse<ToolCallLLMMessageOptions> + | Response; + type ParserOptions = { image_url?: string; }; function createParser( - res: AsyncIterable<Response>, + res: AsyncIterable<LlamaIndexResponse>, data: StreamData, opts?: ParserOptions, ) { @@ -33,17 +39,27 @@ function createParser( async pull(controller): Promise<void> { const { value, done } = await it.next(); if (done) { - appendSourceData(data, sourceNodes); + if (sourceNodes) { + appendSourceData(data, sourceNodes); + } controller.close(); data.close(); return; } - if (!sourceNodes) { - // get source nodes from the first response - sourceNodes = value.sourceNodes; + let delta; + if (value instanceof Response) { + // handle Response type + if (value.sourceNodes) { + // get source nodes from the first response + sourceNodes = value.sourceNodes; + } + delta = value.response ?? ""; + } else { + // handle other types + delta = value.response.delta; } - const text = trimStartOfStream(value.response ?? ""); + const text = trimStartOfStream(delta ?? ""); if (text) { controller.enqueue(text); } @@ -52,21 +68,14 @@ function createParser( } export function LlamaIndexStream( - response: StreamingAgentChatResponse | AsyncIterable<Response>, + response: AsyncIterable<LlamaIndexResponse>, data: StreamData, opts?: { callbacks?: AIStreamCallbacksAndOptions; parserOptions?: ParserOptions; }, -): { stream: ReadableStream; data: StreamData } { - const res = - response instanceof StreamingAgentChatResponse - ? response.response - : response; - return { - stream: createParser(res, data, opts?.parserOptions) - .pipeThrough(createCallbacksTransformer(opts?.callbacks)) - .pipeThrough(createStreamDataTransformer()), - data, - }; +): ReadableStream<string> { + return createParser(response, data, opts?.parserOptions) + .pipeThrough(createCallbacksTransformer(opts?.callbacks)) + .pipeThrough(createStreamDataTransformer()); } diff --git a/templates/types/streaming/fastapi/app/settings.py b/templates/types/streaming/fastapi/app/settings.py index be272d548c8bcbe2e960b3b476e97c836ee7d931..41ab539584d0201444f609477d620696e793b998 100644 --- a/templates/types/streaming/fastapi/app/settings.py +++ b/templates/types/streaming/fastapi/app/settings.py @@ -9,6 +9,10 @@ def init_settings(): init_openai() elif model_provider == "ollama": init_ollama() + elif model_provider == "anthropic": + init_anthropic() + elif model_provider == "gemini": + init_gemini() else: raise ValueError(f"Invalid model provider: {model_provider}") Settings.chunk_size = int(os.getenv("CHUNK_SIZE", "1024")) @@ -42,3 +46,47 @@ def init_openai(): "dimension": int(dimension) if dimension is not None else None, } Settings.embed_model = OpenAIEmbedding(**config) + + +def init_anthropic(): + from llama_index.llms.anthropic import Anthropic + from llama_index.embeddings.huggingface import HuggingFaceEmbedding + + model_map: Dict[str, str] = { + "claude-3-opus": "claude-3-opus-20240229", + "claude-3-sonnet": "claude-3-sonnet-20240229", + "claude-3-haiku": "claude-3-haiku-20240307", + "claude-2.1": "claude-2.1", + "claude-instant-1.2": "claude-instant-1.2", + } + + embed_model_map: Dict[str, str] = { + "all-MiniLM-L6-v2": "sentence-transformers/all-MiniLM-L6-v2", + "all-mpnet-base-v2": "sentence-transformers/all-mpnet-base-v2", + } + + Settings.llm = Anthropic(model=model_map[os.getenv("MODEL")]) + Settings.embed_model = HuggingFaceEmbedding( + model_name=embed_model_map[os.getenv("EMBEDDING_MODEL")] + ) + + +def init_gemini(): + from llama_index.llms.gemini import Gemini + from llama_index.embeddings.gemini import GeminiEmbedding + + model_map: Dict[str, str] = { + "gemini-1.5-pro-latest": "models/gemini-1.5-pro-latest", + "gemini-pro": "models/gemini-pro", + "gemini-pro-vision": "models/gemini-pro-vision", + } + + embed_model_map: Dict[str, str] = { + "embedding-001": "models/embedding-001", + "text-embedding-004": "models/text-embedding-004", + } + + Settings.llm = Gemini(model=model_map[os.getenv("MODEL")]) + Settings.embed_model = GeminiEmbedding( + model_name=embed_model_map[os.getenv("EMBEDDING_MODEL")] + ) diff --git a/templates/types/streaming/nextjs/app/api/chat/engine/settings.ts b/templates/types/streaming/nextjs/app/api/chat/engine/settings.ts index 258043993aa14f618c62a4a196d00645047ada63..6e980afbb6137f65d7e3c6cf6b2720a11897024c 100644 --- a/templates/types/streaming/nextjs/app/api/chat/engine/settings.ts +++ b/templates/types/streaming/nextjs/app/api/chat/engine/settings.ts @@ -1,10 +1,17 @@ import { - Ollama, - OllamaEmbedding, + Anthropic, + GEMINI_EMBEDDING_MODEL, + GEMINI_MODEL, + Gemini, + GeminiEmbedding, OpenAI, OpenAIEmbedding, Settings, } from "llamaindex"; +import { HuggingFaceEmbedding } from "llamaindex/embeddings/HuggingFaceEmbedding"; +import { OllamaEmbedding } from "llamaindex/embeddings/OllamaEmbedding"; +import { ALL_AVAILABLE_ANTHROPIC_MODELS } from "llamaindex/llm/anthropic"; +import { Ollama } from "llamaindex/llm/ollama"; const CHUNK_SIZE = 512; const CHUNK_OVERLAP = 20; @@ -12,10 +19,21 @@ const CHUNK_OVERLAP = 20; export const initSettings = async () => { // HINT: you can delete the initialization code for unused model providers console.log(`Using '${process.env.MODEL_PROVIDER}' model provider`); + + if (!process.env.MODEL || !process.env.EMBEDDING_MODEL) { + throw new Error("'MODEL' and 'EMBEDDING_MODEL' env variables must be set."); + } + switch (process.env.MODEL_PROVIDER) { case "ollama": initOllama(); break; + case "anthropic": + initAnthropic(); + break; + case "gemini": + initGemini(); + break; default: initOpenAI(); break; @@ -38,11 +56,6 @@ function initOpenAI() { } function initOllama() { - if (!process.env.MODEL || !process.env.EMBEDDING_MODEL) { - throw new Error( - "Using Ollama as model provider, 'MODEL' and 'EMBEDDING_MODEL' env variables must be set.", - ); - } Settings.llm = new Ollama({ model: process.env.MODEL ?? "", }); @@ -50,3 +63,25 @@ function initOllama() { model: process.env.EMBEDDING_MODEL ?? "", }); } + +function initAnthropic() { + const embedModelMap: Record<string, string> = { + "all-MiniLM-L6-v2": "Xenova/all-MiniLM-L6-v2", + "all-mpnet-base-v2": "Xenova/all-mpnet-base-v2", + }; + Settings.llm = new Anthropic({ + model: process.env.MODEL as keyof typeof ALL_AVAILABLE_ANTHROPIC_MODELS, + }); + Settings.embedModel = new HuggingFaceEmbedding({ + modelType: embedModelMap[process.env.EMBEDDING_MODEL!], + }); +} + +function initGemini() { + Settings.llm = new Gemini({ + model: process.env.MODEL as GEMINI_MODEL, + }); + Settings.embedModel = new GeminiEmbedding({ + model: process.env.EMBEDDING_MODEL as GEMINI_EMBEDDING_MODEL, + }); +} diff --git a/templates/types/streaming/nextjs/app/api/chat/llamaindex-stream.ts b/templates/types/streaming/nextjs/app/api/chat/llamaindex-stream.ts index b689905727767fb1f17f099298d13e72f0234df7..df2c17ffc34744d044ecdcc375955dbc7ea5ba70 100644 --- a/templates/types/streaming/nextjs/app/api/chat/llamaindex-stream.ts +++ b/templates/types/streaming/nextjs/app/api/chat/llamaindex-stream.ts @@ -9,16 +9,22 @@ import { Metadata, NodeWithScore, Response, - StreamingAgentChatResponse, + ToolCallLLMMessageOptions, } from "llamaindex"; + +import { AgentStreamChatResponse } from "llamaindex/agent/base"; import { appendImageData, appendSourceData } from "./stream-helper"; +type LlamaIndexResponse = + | AgentStreamChatResponse<ToolCallLLMMessageOptions> + | Response; + type ParserOptions = { image_url?: string; }; function createParser( - res: AsyncIterable<Response>, + res: AsyncIterable<LlamaIndexResponse>, data: StreamData, opts?: ParserOptions, ) { @@ -33,17 +39,27 @@ function createParser( async pull(controller): Promise<void> { const { value, done } = await it.next(); if (done) { - appendSourceData(data, sourceNodes); + if (sourceNodes) { + appendSourceData(data, sourceNodes); + } controller.close(); data.close(); return; } - if (!sourceNodes) { - // get source nodes from the first response - sourceNodes = value.sourceNodes; + let delta; + if (value instanceof Response) { + // handle Response type + if (value.sourceNodes) { + // get source nodes from the first response + sourceNodes = value.sourceNodes; + } + delta = value.response ?? ""; + } else { + // handle other types + delta = value.response.delta; } - const text = trimStartOfStream(value.response ?? ""); + const text = trimStartOfStream(delta ?? ""); if (text) { controller.enqueue(text); } @@ -52,21 +68,14 @@ function createParser( } export function LlamaIndexStream( - response: StreamingAgentChatResponse | AsyncIterable<Response>, + response: AsyncIterable<LlamaIndexResponse>, data: StreamData, opts?: { callbacks?: AIStreamCallbacksAndOptions; parserOptions?: ParserOptions; }, -): { stream: ReadableStream; data: StreamData } { - const res = - response instanceof StreamingAgentChatResponse - ? response.response - : response; - return { - stream: createParser(res, data, opts?.parserOptions) - .pipeThrough(createCallbacksTransformer(opts?.callbacks)) - .pipeThrough(createStreamDataTransformer()), - data, - }; +): ReadableStream<string> { + return createParser(response, data, opts?.parserOptions) + .pipeThrough(createCallbacksTransformer(opts?.callbacks)) + .pipeThrough(createStreamDataTransformer()); } diff --git a/templates/types/streaming/nextjs/app/api/chat/route.ts b/templates/types/streaming/nextjs/app/api/chat/route.ts index f7afa08399357fff5eb85a87c19090387a0e85ed..117848d37c291a3b8dece5236c29c1c926d7ba00 100644 --- a/templates/types/streaming/nextjs/app/api/chat/route.ts +++ b/templates/types/streaming/nextjs/app/api/chat/route.ts @@ -1,6 +1,11 @@ import { initObservability } from "@/app/observability"; import { Message, StreamData, StreamingTextResponse } from "ai"; -import { ChatMessage, MessageContent, Settings } from "llamaindex"; +import { + CallbackManager, + ChatMessage, + MessageContent, + Settings, +} from "llamaindex"; import { NextRequest, NextResponse } from "next/server"; import { createChatEngine } from "./engine/chat"; import { initSettings } from "./engine/settings"; @@ -57,14 +62,15 @@ export async function POST(request: NextRequest) { // Init Vercel AI StreamData const vercelStreamData = new StreamData(); - appendEventData( - vercelStreamData, - `Retrieving context for query: '${userMessage.content}'`, - ); - // Setup callback for streaming data before chatting - Settings.callbackManager.on("retrieve", (data) => { + // Setup callbacks + const callbackManager = new CallbackManager(); + callbackManager.on("retrieve", (data) => { const { nodes } = data.detail; + appendEventData( + vercelStreamData, + `Retrieving context for query: '${userMessage.content}'`, + ); appendEventData( vercelStreamData, `Retrieved ${nodes.length} sources to use as context for the query`, @@ -72,14 +78,16 @@ export async function POST(request: NextRequest) { }); // Calling LlamaIndex's ChatEngine to get a streamed response - const response = await chatEngine.chat({ - message: userMessageContent, - chatHistory: messages as ChatMessage[], - stream: true, + const response = await Settings.withCallbackManager(callbackManager, () => { + return chatEngine.chat({ + message: userMessageContent, + chatHistory: messages as ChatMessage[], + stream: true, + }); }); // Transform LlamaIndex stream to Vercel/AI format - const { stream } = LlamaIndexStream(response, vercelStreamData, { + const stream = LlamaIndexStream(response, vercelStreamData, { parserOptions: { image_url: data?.imageUrl, }, diff --git a/templates/types/streaming/nextjs/package.json b/templates/types/streaming/nextjs/package.json index 707bd689f9aab5688e1dd77adc06570047f68149..3d043a14d2fb748c7cf1e12613a791c70225e180 100644 --- a/templates/types/streaming/nextjs/package.json +++ b/templates/types/streaming/nextjs/package.json @@ -17,9 +17,10 @@ "class-variance-authority": "^0.7.0", "clsx": "^1.2.1", "dotenv": "^16.3.1", - "llamaindex": "0.2.10", + "llamaindex": "0.3.3", "lucide-react": "^0.294.0", "next": "^14.0.3", + "pdf2json": "3.0.5", "react": "^18.2.0", "react-dom": "^18.2.0", "react-markdown": "^8.0.7",