Unverified commit 7bd3ed55, authored by Thuc Pham, committed by GitHub

feat: support anthropic and gemini model providers and update to LITS 0.3.3 (#63)

parent c981eb14
Showing 419 additions and 88 deletions
---
"create-llama": patch
---
Support Anthropic and Gemini as model providers
---
"create-llama": patch
---
Support new agents from LITS 0.3
......@@ -173,6 +173,24 @@ const getModelEnvs = (modelConfig: ModelConfig): EnvVar[] => {
},
]
: []),
...(modelConfig.provider === "anthropic"
? [
{
name: "ANTHROPIC_API_KEY",
description: "The Anthropic API key to use.",
value: modelConfig.apiKey,
},
]
: []),
...(modelConfig.provider === "gemini"
? [
{
name: "GOOGLE_API_KEY",
description: "The Google API key to use.",
value: modelConfig.apiKey,
},
]
: []),
];
};
......
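For reference, the conditional-spread pattern used in `getModelEnvs` above can be reduced to a small standalone sketch; `EnvVar`, `ModelConfig`, and `getProviderEnvs` here are simplified stand-ins, not the generator's real definitions:

```ts
// Simplified stand-ins for the generator's real types.
type EnvVar = { name: string; description: string; value?: string };
type ModelConfig = {
  provider: "openai" | "ollama" | "anthropic" | "gemini";
  apiKey?: string;
};

// Each provider contributes its own API key variable; other providers add nothing.
const getProviderEnvs = (modelConfig: ModelConfig): EnvVar[] => [
  ...(modelConfig.provider === "anthropic"
    ? [
        {
          name: "ANTHROPIC_API_KEY",
          description: "The Anthropic API key to use.",
          value: modelConfig.apiKey,
        },
      ]
    : []),
  ...(modelConfig.provider === "gemini"
    ? [
        {
          name: "GOOGLE_API_KEY",
          description: "The Google API key to use.",
          value: modelConfig.apiKey,
        },
      ]
    : []),
];

// Example: a Gemini config yields a single GOOGLE_API_KEY entry.
console.log(getProviderEnvs({ provider: "gemini", apiKey: "example-key" }));
```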
......@@ -9,7 +9,6 @@ import { createBackendEnvFile, createFrontendEnvFile } from "./env-variables";
import { PackageManager } from "./get-pkg-manager";
import { installLlamapackProject } from "./llama-pack";
import { isHavingPoetryLockFile, tryPoetryRun } from "./poetry";
import { isModelConfigured } from "./providers";
import { installPythonTemplate } from "./python";
import { downloadAndExtractRepo } from "./repo";
import { ConfigFileType, writeToolsConfig } from "./tools";
......@@ -38,7 +37,7 @@ async function generateContextData(
? "poetry run generate"
: `${packageManager} run generate`,
)}`;
const modelConfigured = isModelConfigured(modelConfig);
const modelConfigured = modelConfig.isConfigured();
const llamaCloudKeyConfigured = useLlamaParse
? llamaCloudKey || process.env["LLAMA_CLOUD_API_KEY"]
: true;
......
import ciInfo from "ci-info";
import prompts from "prompts";
import { ModelConfigParams } from ".";
import { questionHandlers, toChoice } from "../../questions";
const MODELS = [
"claude-3-opus",
"claude-3-sonnet",
"claude-3-haiku",
"claude-2.1",
"claude-instant-1.2",
];
const DEFAULT_MODEL = MODELS[0];
// TODO: get embedding vector dimensions from the Anthropic SDK (currently not supported)
// Use HuggingFace embedding models for now
enum HuggingFaceEmbeddingModelType {
XENOVA_ALL_MINILM_L6_V2 = "all-MiniLM-L6-v2",
XENOVA_ALL_MPNET_BASE_V2 = "all-mpnet-base-v2",
}
type ModelData = {
dimensions: number;
};
const EMBEDDING_MODELS: Record<HuggingFaceEmbeddingModelType, ModelData> = {
[HuggingFaceEmbeddingModelType.XENOVA_ALL_MINILM_L6_V2]: {
dimensions: 384,
},
[HuggingFaceEmbeddingModelType.XENOVA_ALL_MPNET_BASE_V2]: {
dimensions: 768,
},
};
const DEFAULT_EMBEDDING_MODEL = Object.keys(EMBEDDING_MODELS)[0];
const DEFAULT_DIMENSIONS = Object.values(EMBEDDING_MODELS)[0].dimensions;
type AnthropicQuestionsParams = {
apiKey?: string;
askModels: boolean;
};
export async function askAnthropicQuestions({
askModels,
apiKey,
}: AnthropicQuestionsParams): Promise<ModelConfigParams> {
const config: ModelConfigParams = {
apiKey,
model: DEFAULT_MODEL,
embeddingModel: DEFAULT_EMBEDDING_MODEL,
dimensions: DEFAULT_DIMENSIONS,
isConfigured(): boolean {
if (config.apiKey) {
return true;
}
if (process.env["ANTHROPIC_API_KEY"]) {
return true;
}
return false;
},
};
if (!config.apiKey) {
const { key } = await prompts(
{
type: "text",
name: "key",
message:
"Please provide your Anthropic API key (or leave blank to use ANTHROPIC_API_KEY env variable):",
},
questionHandlers,
);
config.apiKey = key || process.env.ANTHROPIC_API_KEY;
}
// use default model values in CI or if user should not be asked
const useDefaults = ciInfo.isCI || !askModels;
if (!useDefaults) {
const { model } = await prompts(
{
type: "select",
name: "model",
message: "Which LLM model would you like to use?",
choices: MODELS.map(toChoice),
initial: 0,
},
questionHandlers,
);
config.model = model;
const { embeddingModel } = await prompts(
{
type: "select",
name: "embeddingModel",
message: "Which embedding model would you like to use?",
choices: Object.keys(EMBEDDING_MODELS).map(toChoice),
initial: 0,
},
questionHandlers,
);
config.embeddingModel = embeddingModel;
config.dimensions =
EMBEDDING_MODELS[
embeddingModel as HuggingFaceEmbeddingModelType
].dimensions;
}
return config;
}
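A rough usage sketch of the new Anthropic prompt flow; in the CLI it is invoked from `askModelConfig`, so the standalone call below is purely illustrative:

```ts
// Hypothetical standalone call; in the CLI this runs inside askModelConfig().
import { askAnthropicQuestions } from "./anthropic";

async function main() {
  // With askModels = false (or in CI) the model/embedding selection prompts are
  // skipped and the defaults above apply: claude-3-opus with the
  // all-MiniLM-L6-v2 embedding model (384 dimensions). The API key prompt still
  // appears if neither an apiKey argument nor ANTHROPIC_API_KEY is available.
  const config = await askAnthropicQuestions({ askModels: false });
  console.log(config.model, config.embeddingModel, config.dimensions);
  console.log("configured:", config.isConfigured());
}

void main();
```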
import ciInfo from "ci-info";
import prompts from "prompts";
import { ModelConfigParams } from ".";
import { questionHandlers, toChoice } from "../../questions";
const MODELS = ["gemini-1.5-pro-latest", "gemini-pro", "gemini-pro-vision"];
type ModelData = {
dimensions: number;
};
const EMBEDDING_MODELS: Record<string, ModelData> = {
"embedding-001": { dimensions: 768 },
"text-embedding-004": { dimensions: 768 },
};
const DEFAULT_MODEL = MODELS[0];
const DEFAULT_EMBEDDING_MODEL = Object.keys(EMBEDDING_MODELS)[0];
const DEFAULT_DIMENSIONS = Object.values(EMBEDDING_MODELS)[0].dimensions;
type GeminiQuestionsParams = {
apiKey?: string;
askModels: boolean;
};
export async function askGeminiQuestions({
askModels,
apiKey,
}: GeminiQuestionsParams): Promise<ModelConfigParams> {
const config: ModelConfigParams = {
apiKey,
model: DEFAULT_MODEL,
embeddingModel: DEFAULT_EMBEDDING_MODEL,
dimensions: DEFAULT_DIMENSIONS,
isConfigured(): boolean {
if (config.apiKey) {
return true;
}
if (process.env["GOOGLE_API_KEY"]) {
return true;
}
return false;
},
};
if (!config.apiKey) {
const { key } = await prompts(
{
type: "text",
name: "key",
message:
"Please provide your Google API key (or leave blank to use GOOGLE_API_KEY env variable):",
},
questionHandlers,
);
config.apiKey = key || process.env.GOOGLE_API_KEY;
}
// use default model values in CI or if user should not be asked
const useDefaults = ciInfo.isCI || !askModels;
if (!useDefaults) {
const { model } = await prompts(
{
type: "select",
name: "model",
message: "Which LLM model would you like to use?",
choices: MODELS.map(toChoice),
initial: 0,
},
questionHandlers,
);
config.model = model;
const { embeddingModel } = await prompts(
{
type: "select",
name: "embeddingModel",
message: "Which embedding model would you like to use?",
choices: Object.keys(EMBEDDING_MODELS).map(toChoice),
initial: 0,
},
questionHandlers,
);
config.embeddingModel = embeddingModel;
config.dimensions = EMBEDDING_MODELS[embeddingModel].dimensions;
}
return config;
}
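One detail shared by both new provider flows: `isConfigured()` closes over the `config` object, so it also reflects an API key filled in later by the prompt. A minimal sketch of that closure behaviour, detached from the real types:

```ts
// Reduced, provider-agnostic version of the config object built above.
const config = {
  apiKey: undefined as string | undefined,
  isConfigured(): boolean {
    return Boolean(config.apiKey || process.env["GOOGLE_API_KEY"]);
  },
};

console.log(config.isConfigured()); // false (assuming no GOOGLE_API_KEY in the env)
config.apiKey = "entered-at-the-prompt";
console.log(config.isConfigured()); // true: the method sees the later assignment
```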
......@@ -2,8 +2,10 @@ import ciInfo from "ci-info";
import prompts from "prompts";
import { questionHandlers } from "../../questions";
import { ModelConfig, ModelProvider } from "../types";
import { askAnthropicQuestions } from "./anthropic";
import { askGeminiQuestions } from "./gemini";
import { askOllamaQuestions } from "./ollama";
import { askOpenAIQuestions, isOpenAIConfigured } from "./openai";
import { askOpenAIQuestions } from "./openai";
const DEFAULT_MODEL_PROVIDER = "openai";
......@@ -31,6 +33,8 @@ export async function askModelConfig({
value: "openai",
},
{ title: "Ollama", value: "ollama" },
{ title: "Anthropic", value: "anthropic" },
{ title: "Gemini", value: "gemini" },
],
initial: 0,
},
......@@ -44,6 +48,12 @@ export async function askModelConfig({
case "ollama":
modelConfig = await askOllamaQuestions({ askModels });
break;
case "anthropic":
modelConfig = await askAnthropicQuestions({ askModels });
break;
case "gemini":
modelConfig = await askGeminiQuestions({ askModels });
break;
default:
modelConfig = await askOpenAIQuestions({
openAiKey,
......@@ -55,12 +65,3 @@ export async function askModelConfig({
provider: modelProvider,
};
}
export function isModelConfigured(modelConfig: ModelConfig): boolean {
switch (modelConfig.provider) {
case "openai":
return isOpenAIConfigured(modelConfig);
default:
return true;
}
}
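With each provider's config now carrying `isConfigured()`, the free-standing `isModelConfigured` helper above disappears and call sites such as `create-app.ts` and `questions.ts` switch to the method. A reduced sketch of the call-site change (`isReadyToGenerate` is an illustrative name, not code from the repo):

```ts
// Minimal stand-in for the shape call sites rely on after this commit.
type ModelConfigLike = { isConfigured(): boolean };

// Before: const modelConfigured = isModelConfigured(modelConfig);
// After: ask the config object directly.
function isReadyToGenerate(modelConfig: ModelConfigLike): boolean {
  return modelConfig.isConfigured();
}

console.log(isReadyToGenerate({ isConfigured: () => true })); // true
```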
......@@ -29,6 +29,9 @@ export async function askOllamaQuestions({
model: DEFAULT_MODEL,
embeddingModel: DEFAULT_EMBEDDING_MODEL,
dimensions: EMBEDDING_MODELS[DEFAULT_EMBEDDING_MODEL].dimensions,
isConfigured(): boolean {
return true;
},
};
// use default model values in CI or if user should not be asked
......
......@@ -20,6 +20,15 @@ export async function askOpenAIQuestions({
model: DEFAULT_MODEL,
embeddingModel: DEFAULT_EMBEDDING_MODEL,
dimensions: getDimensions(DEFAULT_EMBEDDING_MODEL),
isConfigured(): boolean {
if (config.apiKey) {
return true;
}
if (process.env["OPENAI_API_KEY"]) {
return true;
}
return false;
},
};
if (!config.apiKey) {
......@@ -31,7 +40,6 @@ export async function askOpenAIQuestions({
? "Please provide your OpenAI API key (or leave blank to use OPENAI_API_KEY env variable):"
: "Please provide your OpenAI API key (leave blank to skip):",
validate: (value: string) => {
console.log(value);
if (askModels && !value) {
if (process.env.OPENAI_API_KEY) {
return true;
......@@ -78,16 +86,6 @@ export async function askOpenAIQuestions({
return config;
}
export function isOpenAIConfigured(params: ModelConfigParams): boolean {
if (params.apiKey) {
return true;
}
if (process.env["OPENAI_API_KEY"]) {
return true;
}
return false;
}
async function getAvailableModelChoices(
selectEmbedding: boolean,
apiKey?: string,
......
......@@ -127,6 +127,26 @@ const getAdditionalDependencies = (
version: "0.2.2",
});
break;
case "anthropic":
dependencies.push({
name: "llama-index-llms-anthropic",
version: "0.1.10",
});
dependencies.push({
name: "llama-index-embeddings-huggingface",
version: "0.2.0",
});
break;
case "gemini":
dependencies.push({
name: "llama-index-llms-gemini",
version: "0.1.7",
});
dependencies.push({
name: "llama-index-embeddings-gemini",
version: "0.1.6",
});
break;
}
return dependencies;
......
import { PackageManager } from "../helpers/get-pkg-manager";
import { Tool } from "./tools";
export type ModelProvider = "openai" | "ollama";
export type ModelProvider = "openai" | "ollama" | "anthropic" | "gemini";
export type ModelConfig = {
provider: ModelProvider;
apiKey?: string;
model: string;
embeddingModel: string;
dimensions: number;
isConfigured(): boolean;
};
export type TemplateType = "streaming" | "community" | "llamapack";
export type TemplateFramework = "nextjs" | "express" | "fastapi";
......
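For illustration, a value satisfying the widened `ModelConfig` type could look like this for the Gemini provider; the model names come from the prompt defaults above, while the object itself and the import path are hypothetical:

```ts
import { ModelConfig } from "./types";

// Hypothetical Gemini config built from the defaults in helpers/providers/gemini.ts.
const geminiConfig: ModelConfig = {
  provider: "gemini",
  apiKey: process.env.GOOGLE_API_KEY,
  model: "gemini-1.5-pro-latest",
  embeddingModel: "embedding-001",
  dimensions: 768,
  // Configured once a key was entered or GOOGLE_API_KEY is set in the environment.
  isConfigured() {
    return Boolean(this.apiKey || process.env.GOOGLE_API_KEY);
  },
};

console.log(geminiConfig.isConfigured());
```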
......@@ -14,7 +14,7 @@ import { COMMUNITY_OWNER, COMMUNITY_REPO } from "./helpers/constant";
import { EXAMPLE_FILE } from "./helpers/datasources";
import { templatesDir } from "./helpers/dir";
import { getAvailableLlamapackOptions } from "./helpers/llama-pack";
import { askModelConfig, isModelConfigured } from "./helpers/providers";
import { askModelConfig } from "./helpers/providers";
import { getProjectOptions } from "./helpers/repo";
import { supportedTools, toolsRequireConfig } from "./helpers/tools";
......@@ -257,7 +257,8 @@ export const askQuestions = async (
},
];
const modelConfigured = isModelConfigured(program.modelConfig);
const modelConfigured =
!program.llamapack && program.modelConfig.isConfigured();
// If using LlamaParse, require LlamaCloud API key
const llamaCloudKeyConfigured = program.useLlamaParse
? program.llamaCloudKey || process.env["LLAMA_CLOUD_API_KEY"]
......@@ -268,8 +269,7 @@ export const askQuestions = async (
!hasVectorDb &&
modelConfigured &&
llamaCloudKeyConfigured &&
!toolsRequireConfig(program.tools) &&
!program.llamapack
!toolsRequireConfig(program.tools)
) {
actionChoices.push({
title:
......@@ -398,7 +398,6 @@ export const askQuestions = async (
if (program.framework === "express" || program.framework === "fastapi") {
// if a backend-only framework is selected, ask whether we should create a frontend
// (only for streaming backends)
if (program.frontend === undefined) {
if (ciInfo.isCI) {
program.frontend = getPrefOrDefault("frontend");
......
import { BaseTool, OpenAIAgent, QueryEngineTool } from "llamaindex";
import { BaseToolWithCall, OpenAIAgent, QueryEngineTool } from "llamaindex";
import { ToolsFactory } from "llamaindex/tools/ToolsFactory";
import fs from "node:fs/promises";
import path from "node:path";
......@@ -6,7 +6,7 @@ import { getDataSource } from "./index";
import { STORAGE_CACHE_DIR } from "./shared";
export async function createChatEngine() {
let tools: BaseTool[] = [];
let tools: BaseToolWithCall[] = [];
// Add a query engine tool if we have a data source
// Delete this code if you don't have a data source
......
import { VectorStoreIndex, storageContextFromDefaults } from "llamaindex";
import { VectorStoreIndex } from "llamaindex";
import { storageContextFromDefaults } from "llamaindex/storage/StorageContext";
import * as dotenv from "dotenv";
......
import {
SimpleDocumentStore,
storageContextFromDefaults,
VectorStoreIndex,
} from "llamaindex";
import { SimpleDocumentStore, VectorStoreIndex } from "llamaindex";
import { storageContextFromDefaults } from "llamaindex/storage/StorageContext";
import { STORAGE_CACHE_DIR } from "./shared";
export async function getDataSource() {
......
......@@ -14,7 +14,8 @@
"cors": "^2.8.5",
"dotenv": "^16.3.1",
"express": "^4.18.2",
"llamaindex": "0.2.10"
"llamaindex": "0.3.3",
"pdf2json": "3.0.5"
},
"devDependencies": {
"@types/cors": "^2.8.16",
......
import { Message, StreamData, streamToResponse } from "ai";
import { Request, Response } from "express";
import { ChatMessage, MessageContent, Settings } from "llamaindex";
import {
CallbackManager,
ChatMessage,
MessageContent,
Settings,
} from "llamaindex";
import { createChatEngine } from "./engine/chat";
import { LlamaIndexStream } from "./llamaindex-stream";
import { appendEventData } from "./stream-helper";
......@@ -45,14 +50,15 @@ export const chat = async (req: Request, res: Response) => {
// Init Vercel AI StreamData
const vercelStreamData = new StreamData();
appendEventData(
vercelStreamData,
`Retrieving context for query: '${userMessage.content}'`,
);
// Setup callback for streaming data before chatting
Settings.callbackManager.on("retrieve", (data) => {
// Setup callbacks
const callbackManager = new CallbackManager();
callbackManager.on("retrieve", (data) => {
const { nodes } = data.detail;
appendEventData(
vercelStreamData,
`Retrieving context for query: '${userMessage.content}'`,
);
appendEventData(
vercelStreamData,
`Retrieved ${nodes.length} sources to use as context for the query`,
......@@ -60,31 +66,23 @@ export const chat = async (req: Request, res: Response) => {
});
// Calling LlamaIndex's ChatEngine to get a streamed response
const response = await chatEngine.chat({
message: userMessageContent,
chatHistory: messages as ChatMessage[],
stream: true,
const response = await Settings.withCallbackManager(callbackManager, () => {
return chatEngine.chat({
message: userMessageContent,
chatHistory: messages as ChatMessage[],
stream: true,
});
});
// Return a stream, which can be consumed by the Vercel/AI client
const { stream } = LlamaIndexStream(response, vercelStreamData, {
const stream = LlamaIndexStream(response, vercelStreamData, {
parserOptions: {
image_url: data?.imageUrl,
},
});
// Pipe LlamaIndexStream to response
const processedStream = stream.pipeThrough(vercelStreamData.stream);
return streamToResponse(processedStream, res, {
headers: {
// response MUST have the `X-Experimental-Stream-Data: 'true'` header
// so that the client uses the correct parsing logic, see
// https://sdk.vercel.ai/docs/api-reference/stream-data#on-the-server
"X-Experimental-Stream-Data": "true",
"Content-Type": "text/plain; charset=utf-8",
"Access-Control-Expose-Headers": "X-Experimental-Stream-Data",
},
});
return streamToResponse(processedStream, res);
} catch (error) {
console.error("[LlamaIndex]", error);
return res.status(500).json({
......
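The controller above now scopes its retrieval callback to a single request with `Settings.withCallbackManager` instead of registering it on the global callback manager. Stripped of the Express and streaming plumbing, the pattern looks roughly like this (`chatOnce` is an illustrative name; chat history handling is omitted):

```ts
import { CallbackManager, Settings } from "llamaindex";
import { createChatEngine } from "./engine/chat";

export async function chatOnce(message: string) {
  const chatEngine = await createChatEngine();

  // Callbacks registered here only apply inside the withCallbackManager scope,
  // so concurrent requests do not see each other's retrieval events.
  const callbackManager = new CallbackManager();
  callbackManager.on("retrieve", (data) => {
    const { nodes } = data.detail;
    console.log(`Retrieved ${nodes.length} sources for: '${message}'`);
  });

  return Settings.withCallbackManager(callbackManager, () =>
    chatEngine.chat({ message, stream: true }),
  );
}
```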
import {
Ollama,
OllamaEmbedding,
Anthropic,
GEMINI_EMBEDDING_MODEL,
GEMINI_MODEL,
Gemini,
GeminiEmbedding,
OpenAI,
OpenAIEmbedding,
Settings,
} from "llamaindex";
import { HuggingFaceEmbedding } from "llamaindex/embeddings/HuggingFaceEmbedding";
import { OllamaEmbedding } from "llamaindex/embeddings/OllamaEmbedding";
import { ALL_AVAILABLE_ANTHROPIC_MODELS } from "llamaindex/llm/anthropic";
import { Ollama } from "llamaindex/llm/ollama";
const CHUNK_SIZE = 512;
const CHUNK_OVERLAP = 20;
......@@ -12,10 +19,21 @@ const CHUNK_OVERLAP = 20;
export const initSettings = async () => {
// HINT: you can delete the initialization code for unused model providers
console.log(`Using '${process.env.MODEL_PROVIDER}' model provider`);
if (!process.env.MODEL || !process.env.EMBEDDING_MODEL) {
throw new Error("'MODEL' and 'EMBEDDING_MODEL' env variables must be set.");
}
switch (process.env.MODEL_PROVIDER) {
case "ollama":
initOllama();
break;
case "anthropic":
initAnthropic();
break;
case "gemini":
initGemini();
break;
default:
initOpenAI();
break;
......@@ -38,11 +56,6 @@ function initOpenAI() {
}
function initOllama() {
if (!process.env.MODEL || !process.env.EMBEDDING_MODEL) {
throw new Error(
"Using Ollama as model provider, 'MODEL' and 'EMBEDDING_MODEL' env variables must be set.",
);
}
Settings.llm = new Ollama({
model: process.env.MODEL ?? "",
});
......@@ -50,3 +63,25 @@ function initOllama() {
model: process.env.EMBEDDING_MODEL ?? "",
});
}
function initAnthropic() {
const embedModelMap: Record<string, string> = {
"all-MiniLM-L6-v2": "Xenova/all-MiniLM-L6-v2",
"all-mpnet-base-v2": "Xenova/all-mpnet-base-v2",
};
Settings.llm = new Anthropic({
model: process.env.MODEL as keyof typeof ALL_AVAILABLE_ANTHROPIC_MODELS,
});
Settings.embedModel = new HuggingFaceEmbedding({
modelType: embedModelMap[process.env.EMBEDDING_MODEL!],
});
}
function initGemini() {
Settings.llm = new Gemini({
model: process.env.MODEL as GEMINI_MODEL,
});
Settings.embedModel = new GeminiEmbedding({
model: process.env.EMBEDDING_MODEL as GEMINI_EMBEDDING_MODEL,
});
}
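`initSettings` now requires `MODEL` and `EMBEDDING_MODEL` for every provider (the check moved out of `initOllama`). A minimal sketch of the environment a Gemini setup would need before calling it; the values are examples, and in a generated project they come from the `.env` file written by create-llama:

```ts
import { initSettings } from "./settings"; // path as in the generated backend

// Example values only; in a generated project these live in the .env file.
process.env.MODEL_PROVIDER = "gemini";
process.env.MODEL = "gemini-1.5-pro-latest";
process.env.EMBEDDING_MODEL = "embedding-001";
process.env.GOOGLE_API_KEY = "<your-google-api-key>";

// Sets Settings.llm to Gemini and Settings.embedModel to GeminiEmbedding.
await initSettings();
```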
......@@ -9,16 +9,22 @@ import {
Metadata,
NodeWithScore,
Response,
StreamingAgentChatResponse,
ToolCallLLMMessageOptions,
} from "llamaindex";
import { AgentStreamChatResponse } from "llamaindex/agent/base";
import { appendImageData, appendSourceData } from "./stream-helper";
type LlamaIndexResponse =
| AgentStreamChatResponse<ToolCallLLMMessageOptions>
| Response;
type ParserOptions = {
image_url?: string;
};
function createParser(
res: AsyncIterable<Response>,
res: AsyncIterable<LlamaIndexResponse>,
data: StreamData,
opts?: ParserOptions,
) {
......@@ -33,17 +39,27 @@ function createParser(
async pull(controller): Promise<void> {
const { value, done } = await it.next();
if (done) {
appendSourceData(data, sourceNodes);
if (sourceNodes) {
appendSourceData(data, sourceNodes);
}
controller.close();
data.close();
return;
}
if (!sourceNodes) {
// get source nodes from the first response
sourceNodes = value.sourceNodes;
let delta;
if (value instanceof Response) {
// handle Response type
if (value.sourceNodes) {
// get source nodes from the first response
sourceNodes = value.sourceNodes;
}
delta = value.response ?? "";
} else {
// handle other types
delta = value.response.delta;
}
const text = trimStartOfStream(value.response ?? "");
const text = trimStartOfStream(delta ?? "");
if (text) {
controller.enqueue(text);
}
......@@ -52,21 +68,14 @@ function createParser(
}
export function LlamaIndexStream(
response: StreamingAgentChatResponse | AsyncIterable<Response>,
response: AsyncIterable<LlamaIndexResponse>,
data: StreamData,
opts?: {
callbacks?: AIStreamCallbacksAndOptions;
parserOptions?: ParserOptions;
},
): { stream: ReadableStream; data: StreamData } {
const res =
response instanceof StreamingAgentChatResponse
? response.response
: response;
return {
stream: createParser(res, data, opts?.parserOptions)
.pipeThrough(createCallbacksTransformer(opts?.callbacks))
.pipeThrough(createStreamDataTransformer()),
data,
};
): ReadableStream<string> {
return createParser(response, data, opts?.parserOptions)
.pipeThrough(createCallbacksTransformer(opts?.callbacks))
.pipeThrough(createStreamDataTransformer());
}
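The stream changes above boil down to extracting the text delta differently depending on whether the engine yields plain `Response` chunks or the agent stream chunks introduced in LITS 0.3. A reduced sketch of that branch, using the same imports as the file above:

```ts
import { Response, ToolCallLLMMessageOptions } from "llamaindex";
import { AgentStreamChatResponse } from "llamaindex/agent/base";

type LlamaIndexResponse =
  | AgentStreamChatResponse<ToolCallLLMMessageOptions>
  | Response;

// Returns the text to enqueue for one streamed chunk.
function extractDelta(value: LlamaIndexResponse): string {
  if (value instanceof Response) {
    // Non-agent chat engines: the chunk carries the text (and, on the first
    // chunk, the source nodes used for the answer).
    return value.response ?? "";
  }
  // Agent responses from LITS 0.3: the chunk wraps a ChatResponseChunk delta.
  return value.response.delta;
}
```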
......@@ -9,6 +9,10 @@ def init_settings():
init_openai()
elif model_provider == "ollama":
init_ollama()
elif model_provider == "anthropic":
init_anthropic()
elif model_provider == "gemini":
init_gemini()
else:
raise ValueError(f"Invalid model provider: {model_provider}")
Settings.chunk_size = int(os.getenv("CHUNK_SIZE", "1024"))
......@@ -42,3 +46,47 @@ def init_openai():
"dimension": int(dimension) if dimension is not None else None,
}
Settings.embed_model = OpenAIEmbedding(**config)
def init_anthropic():
from llama_index.llms.anthropic import Anthropic
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
model_map: Dict[str, str] = {
"claude-3-opus": "claude-3-opus-20240229",
"claude-3-sonnet": "claude-3-sonnet-20240229",
"claude-3-haiku": "claude-3-haiku-20240307",
"claude-2.1": "claude-2.1",
"claude-instant-1.2": "claude-instant-1.2",
}
embed_model_map: Dict[str, str] = {
"all-MiniLM-L6-v2": "sentence-transformers/all-MiniLM-L6-v2",
"all-mpnet-base-v2": "sentence-transformers/all-mpnet-base-v2",
}
Settings.llm = Anthropic(model=model_map[os.getenv("MODEL")])
Settings.embed_model = HuggingFaceEmbedding(
model_name=embed_model_map[os.getenv("EMBEDDING_MODEL")]
)
def init_gemini():
from llama_index.llms.gemini import Gemini
from llama_index.embeddings.gemini import GeminiEmbedding
model_map: Dict[str, str] = {
"gemini-1.5-pro-latest": "models/gemini-1.5-pro-latest",
"gemini-pro": "models/gemini-pro",
"gemini-pro-vision": "models/gemini-pro-vision",
}
embed_model_map: Dict[str, str] = {
"embedding-001": "models/embedding-001",
"text-embedding-004": "models/text-embedding-004",
}
Settings.llm = Gemini(model=model_map[os.getenv("MODEL")])
Settings.embed_model = GeminiEmbedding(
model_name=embed_model_map[os.getenv("EMBEDDING_MODEL")]
)