From 331d3741c9a675f8015c31bdf2b1cc37ff172c45 Mon Sep 17 00:00:00 2001 From: Timothy Carambat <rambat1010@gmail.com> Date: Tue, 7 May 2024 18:06:31 -0700 Subject: [PATCH] Ollama agents (#1270) * add LMStudio agent support (generic) support "work" with non-tool callable LLMs, highly dependent on system specs * add comments * enable few-shot prompting per function for OSS models * Add Agent support for Ollama models * improve json parsing for ollama text responses --- .../AgentConfig/AgentLLMSelection/index.jsx | 4 +- server/package.json | 2 + server/utils/agents/aibitat/index.js | 2 + .../aibitat/providers/helpers/untooled.js | 62 ++++------ .../utils/agents/aibitat/providers/index.js | 2 + .../agents/aibitat/providers/lmstudio.js | 25 +++- .../utils/agents/aibitat/providers/ollama.js | 107 ++++++++++++++++++ server/utils/agents/index.js | 8 +- server/utils/http/index.js | 7 +- server/yarn.lock | 14 ++- 10 files changed, 188 insertions(+), 45 deletions(-) create mode 100644 server/utils/agents/aibitat/providers/ollama.js diff --git a/frontend/src/pages/WorkspaceSettings/AgentConfig/AgentLLMSelection/index.jsx b/frontend/src/pages/WorkspaceSettings/AgentConfig/AgentLLMSelection/index.jsx index 408d60a02..fcb12d94d 100644 --- a/frontend/src/pages/WorkspaceSettings/AgentConfig/AgentLLMSelection/index.jsx +++ b/frontend/src/pages/WorkspaceSettings/AgentConfig/AgentLLMSelection/index.jsx @@ -5,8 +5,8 @@ import { AVAILABLE_LLM_PROVIDERS } from "@/pages/GeneralSettings/LLMPreference"; import { CaretUpDown, Gauge, MagnifyingGlass, X } from "@phosphor-icons/react"; import AgentModelSelection from "../AgentModelSelection"; -const ENABLED_PROVIDERS = ["openai", "anthropic", "lmstudio"]; -const WARN_PERFORMANCE = ["lmstudio"]; +const ENABLED_PROVIDERS = ["openai", "anthropic", "lmstudio", "ollama"]; +const WARN_PERFORMANCE = ["lmstudio", "ollama"]; const LLM_DEFAULT = { name: "Please make a selection", diff --git a/server/package.json b/server/package.json index afe6926de..edee71b02 100644 --- a/server/package.json +++ b/server/package.json @@ -46,6 +46,7 @@ "dotenv": "^16.0.3", "express": "^4.18.2", "express-ws": "^5.0.2", + "extract-json-from-string": "^1.0.1", "extract-zip": "^2.0.1", "graphql": "^16.7.1", "joi": "^17.11.0", @@ -59,6 +60,7 @@ "multer": "^1.4.5-lts.1", "node-html-markdown": "^1.3.0", "node-llama-cpp": "^2.8.0", + "ollama": "^0.5.0", "openai": "4.38.5", "pinecone-client": "^1.1.0", "pluralize": "^8.0.0", diff --git a/server/utils/agents/aibitat/index.js b/server/utils/agents/aibitat/index.js index c3ad5428f..9cf2170b7 100644 --- a/server/utils/agents/aibitat/index.js +++ b/server/utils/agents/aibitat/index.js @@ -741,6 +741,8 @@ ${this.getHistory({ to: route.to }) return new Providers.AnthropicProvider({ model: config.model }); case "lmstudio": return new Providers.LMStudioProvider({}); + case "ollama": + return new Providers.OllamaProvider({ model: config.model }); default: throw new Error( diff --git a/server/utils/agents/aibitat/providers/helpers/untooled.js b/server/utils/agents/aibitat/providers/helpers/untooled.js index a84aad77c..37ecb5599 100644 --- a/server/utils/agents/aibitat/providers/helpers/untooled.js +++ b/server/utils/agents/aibitat/providers/helpers/untooled.js @@ -102,48 +102,34 @@ ${JSON.stringify(def.parameters.properties, null, 4)}\n`; return { valid: true, reason: null }; } - async functionCall(messages, functions) { + async functionCall(messages, functions, chatCb = null) { const history = [...messages].filter((msg) => ["user", "assistant"].includes(msg.role) ); if (history[history.length - 1].role !== "user") return null; - - const response = await this.client.chat.completions - .create({ - model: this.model, - temperature: 0, - messages: [ - { - content: `You are a program which picks the most optimal function and parameters to call. -DO NOT HAVE TO PICK A FUNCTION IF IT WILL NOT HELP ANSWER OR FULFILL THE USER'S QUERY. -When a function is selection, respond in JSON with no additional text. -When there is no relevant function to call - return with a regular chat text response. -Your task is to pick a **single** function that we will use to call, if any seem useful or relevant for the user query. - -All JSON responses should have two keys. -'name': this is the name of the function name to call. eg: 'web-scraper', 'rag-memory', etc.. -'arguments': this is an object with the function properties to invoke the function. -DO NOT INCLUDE ANY OTHER KEYS IN JSON RESPONSES. - -Here are the available tools you can use an examples of a query and response so you can understand how each one works. -${this.showcaseFunctions(functions)} - -Now pick a function if there is an appropriate one to use given the last user message and the given conversation so far.`, - role: "system", - }, - ...history, - ], - }) - .then((result) => { - if (!result.hasOwnProperty("choices")) - throw new Error("LMStudio chat: No results!"); - if (result.choices.length === 0) - throw new Error("LMStudio chat: No results length!"); - return result.choices[0].message.content; - }) - .catch((_) => { - return null; - }); + const response = await chatCb({ + messages: [ + { + content: `You are a program which picks the most optimal function and parameters to call. + DO NOT HAVE TO PICK A FUNCTION IF IT WILL NOT HELP ANSWER OR FULFILL THE USER'S QUERY. + When a function is selection, respond in JSON with no additional text. + When there is no relevant function to call - return with a regular chat text response. + Your task is to pick a **single** function that we will use to call, if any seem useful or relevant for the user query. + + All JSON responses should have two keys. + 'name': this is the name of the function name to call. eg: 'web-scraper', 'rag-memory', etc.. + 'arguments': this is an object with the function properties to invoke the function. + DO NOT INCLUDE ANY OTHER KEYS IN JSON RESPONSES. + + Here are the available tools you can use an examples of a query and response so you can understand how each one works. + ${this.showcaseFunctions(functions)} + + Now pick a function if there is an appropriate one to use given the last user message and the given conversation so far.`, + role: "system", + }, + ...history, + ], + }); const call = safeJsonParse(response, null); if (call === null) return { toolCall: null, text: response }; // failed to parse, so must be text. diff --git a/server/utils/agents/aibitat/providers/index.js b/server/utils/agents/aibitat/providers/index.js index ebe4de33f..fda8b5136 100644 --- a/server/utils/agents/aibitat/providers/index.js +++ b/server/utils/agents/aibitat/providers/index.js @@ -1,9 +1,11 @@ const OpenAIProvider = require("./openai.js"); const AnthropicProvider = require("./anthropic.js"); const LMStudioProvider = require("./lmstudio.js"); +const OllamaProvider = require("./ollama.js"); module.exports = { OpenAIProvider, AnthropicProvider, LMStudioProvider, + OllamaProvider, }; diff --git a/server/utils/agents/aibitat/providers/lmstudio.js b/server/utils/agents/aibitat/providers/lmstudio.js index 49387e43b..d3aa4346a 100644 --- a/server/utils/agents/aibitat/providers/lmstudio.js +++ b/server/utils/agents/aibitat/providers/lmstudio.js @@ -27,6 +27,25 @@ class LMStudioProvider extends InheritMultiple([Provider, UnTooled]) { return this._client; } + async #handleFunctionCallChat({ messages = [] }) { + return await this.client.chat.completions + .create({ + model: this.model, + temperature: 0, + messages, + }) + .then((result) => { + if (!result.hasOwnProperty("choices")) + throw new Error("LMStudio chat: No results!"); + if (result.choices.length === 0) + throw new Error("LMStudio chat: No results length!"); + return result.choices[0].message.content; + }) + .catch((_) => { + return null; + }); + } + /** * Create a completion based on the received messages. * @@ -38,7 +57,11 @@ class LMStudioProvider extends InheritMultiple([Provider, UnTooled]) { try { let completion; if (functions.length > 0) { - const { toolCall, text } = await this.functionCall(messages, functions); + const { toolCall, text } = await this.functionCall( + messages, + functions, + this.#handleFunctionCallChat.bind(this) + ); if (toolCall !== null) { this.providerLog(`Valid tool call found - running ${toolCall.name}.`); diff --git a/server/utils/agents/aibitat/providers/ollama.js b/server/utils/agents/aibitat/providers/ollama.js new file mode 100644 index 000000000..d52d80caa --- /dev/null +++ b/server/utils/agents/aibitat/providers/ollama.js @@ -0,0 +1,107 @@ +const Provider = require("./ai-provider.js"); +const InheritMultiple = require("./helpers/classes.js"); +const UnTooled = require("./helpers/untooled.js"); +const { Ollama } = require("ollama"); + +/** + * The provider for the Ollama provider. + */ +class OllamaProvider extends InheritMultiple([Provider, UnTooled]) { + model; + + constructor(config = {}) { + const { + // options = {}, + model = null, + } = config; + + super(); + this._client = new Ollama({ host: process.env.OLLAMA_BASE_PATH }); + this.model = model; + this.verbose = true; + } + + get client() { + return this._client; + } + + async #handleFunctionCallChat({ messages = [] }) { + const response = await this.client.chat({ + model: this.model, + messages, + options: { + temperature: 0, + }, + }); + return response?.message?.content || null; + } + + /** + * Create a completion based on the received messages. + * + * @param messages A list of messages to send to the API. + * @param functions + * @returns The completion. + */ + async complete(messages, functions = null) { + try { + let completion; + if (functions.length > 0) { + const { toolCall, text } = await this.functionCall( + messages, + functions, + this.#handleFunctionCallChat.bind(this) + ); + + if (toolCall !== null) { + this.providerLog(`Valid tool call found - running ${toolCall.name}.`); + this.deduplicator.trackRun(toolCall.name, toolCall.arguments); + return { + result: null, + functionCall: { + name: toolCall.name, + arguments: toolCall.arguments, + }, + cost: 0, + }; + } + completion = { content: text }; + } + + if (!completion?.content) { + this.providerLog( + "Will assume chat completion without tool call inputs." + ); + const response = await this.client.chat({ + model: this.model, + messages: this.cleanMsgs(messages), + options: { + use_mlock: true, + temperature: 0.5, + }, + }); + completion = response.message; + } + + return { + result: completion.content, + cost: 0, + }; + } catch (error) { + throw error; + } + } + + /** + * Get the cost of the completion. + * + * @param _usage The completion to get the cost for. + * @returns The cost of the completion. + * Stubbed since LMStudio has no cost basis. + */ + getCost(_usage) { + return 0; + } +} + +module.exports = OllamaProvider; diff --git a/server/utils/agents/index.js b/server/utils/agents/index.js index 5e54c0b3f..e18b8b7bb 100644 --- a/server/utils/agents/index.js +++ b/server/utils/agents/index.js @@ -79,7 +79,11 @@ class AgentHandler { break; case "lmstudio": if (!process.env.LMSTUDIO_BASE_PATH) - throw new Error("LMStudio bash path must be provided to use agents."); + throw new Error("LMStudio base path must be provided to use agents."); + break; + case "ollama": + if (!process.env.OLLAMA_BASE_PATH) + throw new Error("Ollama base path must be provided to use agents."); break; default: throw new Error("No provider found to power agent cluster."); @@ -94,6 +98,8 @@ class AgentHandler { return "claude-3-sonnet-20240229"; case "lmstudio": return "server-default"; + case "ollama": + return "llama3:latest"; default: return "unknown"; } diff --git a/server/utils/http/index.js b/server/utils/http/index.js index 1fc8c5b96..6400c36bc 100644 --- a/server/utils/http/index.js +++ b/server/utils/http/index.js @@ -4,6 +4,7 @@ process.env.NODE_ENV === "development" const JWT = require("jsonwebtoken"); const { User } = require("../../models/user"); const { jsonrepair } = require("jsonrepair"); +const extract = require("extract-json-from-string"); function reqBody(request) { return typeof request.body === "string" @@ -67,8 +68,6 @@ function safeJsonParse(jsonString, fallback = null) { return JSON.parse(jsonString); } catch {} - // If the jsonString does not look like an Obj or Array, dont attempt - // to repair it. if (jsonString?.startsWith("[") || jsonString?.startsWith("{")) { try { const repairedJson = jsonrepair(jsonString); @@ -76,6 +75,10 @@ function safeJsonParse(jsonString, fallback = null) { } catch {} } + try { + return extract(jsonString)[0]; + } catch {} + return fallback; } diff --git a/server/yarn.lock b/server/yarn.lock index 1911849d4..5edd09a35 100644 --- a/server/yarn.lock +++ b/server/yarn.lock @@ -2678,6 +2678,11 @@ extract-files@^9.0.0: resolved "https://registry.yarnpkg.com/extract-files/-/extract-files-9.0.0.tgz#8a7744f2437f81f5ed3250ed9f1550de902fe54a" integrity sha512-CvdFfHkC95B4bBBk36hcEmvdR2awOdhhVUYH6S/zrVj3477zven/fJMYg7121h4T1xHZC+tetUpubpAhxwI7hQ== +extract-json-from-string@^1.0.1: + version "1.0.1" + resolved "https://registry.yarnpkg.com/extract-json-from-string/-/extract-json-from-string-1.0.1.tgz#5001f17e6c905826dcd5989564e130959de60c96" + integrity sha512-xfQOSFYbELVs9QVkKsV9FZAjlAmXQ2SLR6FpfFX1kpn4QAvaGBJlrnVOblMLwrLPYc26H+q9qxo6JTd4E7AwgQ== + extract-zip@^2.0.1: version "2.0.1" resolved "https://registry.yarnpkg.com/extract-zip/-/extract-zip-2.0.1.tgz#663dca56fe46df890d5f131ef4a06d22bb8ba13a" @@ -4560,6 +4565,13 @@ octokit@^3.1.0: "@octokit/request-error" "^5.0.0" "@octokit/types" "^12.0.0" +ollama@^0.5.0: + version "0.5.0" + resolved "https://registry.yarnpkg.com/ollama/-/ollama-0.5.0.tgz#cb9bc709d4d3278c9f484f751b0d9b98b06f4859" + integrity sha512-CRtRzsho210EGdK52GrUMohA2pU+7NbgEaBG3DcYeRmvQthDO7E2LHOkLlUUeaYUlNmEd8icbjC02ug9meSYnw== + dependencies: + whatwg-fetch "^3.6.20" + on-finished@2.4.1: version "2.4.1" resolved "https://registry.yarnpkg.com/on-finished/-/on-finished-2.4.1.tgz#58c8c44116e54845ad57f14ab10b03533184ac3f" @@ -5980,7 +5992,7 @@ webidl-conversions@^3.0.0: resolved "https://registry.yarnpkg.com/webidl-conversions/-/webidl-conversions-3.0.1.tgz#24534275e2a7bc6be7bc86611cc16ae0a5654871" integrity sha512-2JAn3z8AR6rjK8Sm8orRC0h/bcl/DqL7tRPdGZ4I1CjdF+EaMLmYxBHyXuKL849eucPFhvBoxMsflfOb8kxaeQ== -whatwg-fetch@^3.4.1: +whatwg-fetch@^3.4.1, whatwg-fetch@^3.6.20: version "3.6.20" resolved "https://registry.yarnpkg.com/whatwg-fetch/-/whatwg-fetch-3.6.20.tgz#580ce6d791facec91d37c72890995a0b48d31c70" integrity sha512-EqhiFU6daOA8kpjOWTL0olhVOF3i7OrFzSYiGsEMB8GcXS+RrzauAERX65xMeNWVqxA6HXH2m69Z9LaKKdisfg== -- GitLab