import ciInfo from "ci-info"; import ollama, { type ModelResponse } from "ollama"; import { red } from "picocolors"; import prompts from "prompts"; import { ModelConfigParams } from "."; import { questionHandlers, toChoice } from "../../questions"; type ModelData = { dimensions: number; }; const MODELS = ["llama3:8b", "wizardlm2:7b", "gemma:7b", "phi3"]; const DEFAULT_MODEL = MODELS[0]; // TODO: get embedding vector dimensions from the ollama sdk (currently not supported) const EMBEDDING_MODELS: Record<string, ModelData> = { "nomic-embed-text": { dimensions: 768 }, "mxbai-embed-large": { dimensions: 1024 }, "all-minilm": { dimensions: 384 }, }; const DEFAULT_EMBEDDING_MODEL: string = Object.keys(EMBEDDING_MODELS)[0]; type OllamaQuestionsParams = { askModels: boolean; }; export async function askOllamaQuestions({ askModels, }: OllamaQuestionsParams): Promise<ModelConfigParams> { const config: ModelConfigParams = { model: DEFAULT_MODEL, embeddingModel: DEFAULT_EMBEDDING_MODEL, dimensions: EMBEDDING_MODELS[DEFAULT_EMBEDDING_MODEL].dimensions, }; // use default model values in CI or if user should not be asked const useDefaults = ciInfo.isCI || !askModels; if (!useDefaults) { const { model } = await prompts( { type: "select", name: "model", message: "Which LLM model would you like to use?", choices: MODELS.map(toChoice), initial: 0, }, questionHandlers, ); await ensureModel(model); config.model = model; const { embeddingModel } = await prompts( { type: "select", name: "embeddingModel", message: "Which embedding model would you like to use?", choices: Object.keys(EMBEDDING_MODELS).map(toChoice), initial: 0, }, questionHandlers, ); await ensureModel(embeddingModel); config.embeddingModel = embeddingModel; config.dimensions = EMBEDDING_MODELS[embeddingModel].dimensions; } return config; } async function ensureModel(modelName: string) { try { if (modelName.split(":").length === 1) { // model doesn't have a version suffix, use latest modelName = modelName + ":latest"; } const { models } = await ollama.list(); const found = models.find((model: ModelResponse) => model.name === modelName) !== undefined; if (!found) { console.log( red( `Model ${modelName} was not pulled yet. Call 'ollama pull ${modelName}' and try again.`, ), ); process.exit(1); } } catch (error) { console.log( red("Listing Ollama models failed. Is 'ollama' running? " + error), ); process.exit(1); } }