diff --git a/frontend/src/pages/WorkspaceSettings/AgentConfig/AgentLLMSelection/index.jsx b/frontend/src/pages/WorkspaceSettings/AgentConfig/AgentLLMSelection/index.jsx index 408d60a026f13303cfc3fe749313224812965a5b..fcb12d94d5d0cf8cf720202d93bd7126c8735ac1 100644 --- a/frontend/src/pages/WorkspaceSettings/AgentConfig/AgentLLMSelection/index.jsx +++ b/frontend/src/pages/WorkspaceSettings/AgentConfig/AgentLLMSelection/index.jsx @@ -5,8 +5,8 @@ import { AVAILABLE_LLM_PROVIDERS } from "@/pages/GeneralSettings/LLMPreference"; import { CaretUpDown, Gauge, MagnifyingGlass, X } from "@phosphor-icons/react"; import AgentModelSelection from "../AgentModelSelection"; -const ENABLED_PROVIDERS = ["openai", "anthropic", "lmstudio"]; -const WARN_PERFORMANCE = ["lmstudio"]; +const ENABLED_PROVIDERS = ["openai", "anthropic", "lmstudio", "ollama"]; +const WARN_PERFORMANCE = ["lmstudio", "ollama"]; const LLM_DEFAULT = { name: "Please make a selection", diff --git a/server/package.json b/server/package.json index afe6926de0382e1d710a29f686e92da71870dcaa..edee71b023c0532361f71077562e193d1aa662e6 100644 --- a/server/package.json +++ b/server/package.json @@ -46,6 +46,7 @@ "dotenv": "^16.0.3", "express": "^4.18.2", "express-ws": "^5.0.2", + "extract-json-from-string": "^1.0.1", "extract-zip": "^2.0.1", "graphql": "^16.7.1", "joi": "^17.11.0", @@ -59,6 +60,7 @@ "multer": "^1.4.5-lts.1", "node-html-markdown": "^1.3.0", "node-llama-cpp": "^2.8.0", + "ollama": "^0.5.0", "openai": "4.38.5", "pinecone-client": "^1.1.0", "pluralize": "^8.0.0", diff --git a/server/utils/agents/aibitat/index.js b/server/utils/agents/aibitat/index.js index c3ad5428f5775ca912db2136a93e03b87c131b46..9cf2170b7cfcd0994c030728cc2d4e218c8a10cc 100644 --- a/server/utils/agents/aibitat/index.js +++ b/server/utils/agents/aibitat/index.js @@ -741,6 +741,8 @@ ${this.getHistory({ to: route.to }) return new Providers.AnthropicProvider({ model: config.model }); case "lmstudio": return new Providers.LMStudioProvider({}); + case "ollama": + return new Providers.OllamaProvider({ model: config.model }); default: throw new Error( diff --git a/server/utils/agents/aibitat/providers/helpers/untooled.js b/server/utils/agents/aibitat/providers/helpers/untooled.js index a84aad77c922cbcb1f31efa6ee62a213f5af697b..37ecb5599f58f10c34420aa74c1d6b761bc3e65f 100644 --- a/server/utils/agents/aibitat/providers/helpers/untooled.js +++ b/server/utils/agents/aibitat/providers/helpers/untooled.js @@ -102,48 +102,34 @@ ${JSON.stringify(def.parameters.properties, null, 4)}\n`; return { valid: true, reason: null }; } - async functionCall(messages, functions) { + async functionCall(messages, functions, chatCb = null) { const history = [...messages].filter((msg) => ["user", "assistant"].includes(msg.role) ); if (history[history.length - 1].role !== "user") return null; - - const response = await this.client.chat.completions - .create({ - model: this.model, - temperature: 0, - messages: [ - { - content: `You are a program which picks the most optimal function and parameters to call. -DO NOT HAVE TO PICK A FUNCTION IF IT WILL NOT HELP ANSWER OR FULFILL THE USER'S QUERY. -When a function is selection, respond in JSON with no additional text. -When there is no relevant function to call - return with a regular chat text response. -Your task is to pick a **single** function that we will use to call, if any seem useful or relevant for the user query. - -All JSON responses should have two keys. -'name': this is the name of the function name to call. eg: 'web-scraper', 'rag-memory', etc.. 
-'arguments': this is an object with the function properties to invoke the function.
-DO NOT INCLUDE ANY OTHER KEYS IN JSON RESPONSES.
-
-Here are the available tools you can use an examples of a query and response so you can understand how each one works.
-${this.showcaseFunctions(functions)}
-
-Now pick a function if there is an appropriate one to use given the last user message and the given conversation so far.`,
-          role: "system",
-        },
-        ...history,
-      ],
-    })
-    .then((result) => {
-      if (!result.hasOwnProperty("choices"))
-        throw new Error("LMStudio chat: No results!");
-      if (result.choices.length === 0)
-        throw new Error("LMStudio chat: No results length!");
-      return result.choices[0].message.content;
-    })
-    .catch((_) => {
-      return null;
-    });
+    const response = await chatCb({
+      messages: [
+        {
+          content: `You are a program which picks the most optimal function and parameters to call.
+          YOU DO NOT HAVE TO PICK A FUNCTION IF IT WILL NOT HELP ANSWER OR FULFILL THE USER'S QUERY.
+          When a function is selected, respond in JSON with no additional text.
+          When there is no relevant function to call - return with a regular chat text response.
+          Your task is to pick a **single** function to call, if any seem useful or relevant for the user query.
+
+          All JSON responses should have two keys.
+          'name': this is the name of the function to call. eg: 'web-scraper', 'rag-memory', etc..
+          'arguments': this is an object with the function properties to invoke the function.
+          DO NOT INCLUDE ANY OTHER KEYS IN JSON RESPONSES.
+
+          Here are the available tools you can use and examples of a query and response so you can understand how each one works.
+          ${this.showcaseFunctions(functions)}
+
+          Now pick a function if there is an appropriate one to use given the last user message and the given conversation so far.`,
+          role: "system",
+        },
+        ...history,
+      ],
+    });

    const call = safeJsonParse(response, null);
    if (call === null) return { toolCall: null, text: response }; // failed to parse, so must be text.
diff --git a/server/utils/agents/aibitat/providers/index.js b/server/utils/agents/aibitat/providers/index.js
index ebe4de33f06a1128e3f3c97591136e5173617200..fda8b51360f8c5a909fc51703d6780457610a3e9 100644
--- a/server/utils/agents/aibitat/providers/index.js
+++ b/server/utils/agents/aibitat/providers/index.js
@@ -1,9 +1,11 @@
 const OpenAIProvider = require("./openai.js");
 const AnthropicProvider = require("./anthropic.js");
 const LMStudioProvider = require("./lmstudio.js");
+const OllamaProvider = require("./ollama.js");

 module.exports = {
   OpenAIProvider,
   AnthropicProvider,
   LMStudioProvider,
+  OllamaProvider,
 };
diff --git a/server/utils/agents/aibitat/providers/lmstudio.js b/server/utils/agents/aibitat/providers/lmstudio.js
index 49387e43b41b0ec27481c25712792da987db705f..d3aa4346a7111d7b83dad15c99bce3fb2b2fbb1b 100644
--- a/server/utils/agents/aibitat/providers/lmstudio.js
+++ b/server/utils/agents/aibitat/providers/lmstudio.js
@@ -27,6 +27,25 @@ class LMStudioProvider extends InheritMultiple([Provider, UnTooled]) {
     return this._client;
   }

+  async #handleFunctionCallChat({ messages = [] }) {
+    return await this.client.chat.completions
+      .create({
+        model: this.model,
+        temperature: 0,
+        messages,
+      })
+      .then((result) => {
+        if (!result.hasOwnProperty("choices"))
+          throw new Error("LMStudio chat: No results!");
+        if (result.choices.length === 0)
+          throw new Error("LMStudio chat: No results length!");
+        return result.choices[0].message.content;
+      })
+      .catch((_) => {
+        return null;
+      });
+  }
+
   /**
    * Create a completion based on the received messages.
    *
@@ -38,7 +57,11 @@ class LMStudioProvider extends InheritMultiple([Provider, UnTooled]) {
     try {
       let completion;
       if (functions.length > 0) {
-        const { toolCall, text } = await this.functionCall(messages, functions);
+        const { toolCall, text } = await this.functionCall(
+          messages,
+          functions,
+          this.#handleFunctionCallChat.bind(this)
+        );

         if (toolCall !== null) {
           this.providerLog(`Valid tool call found - running ${toolCall.name}.`);
diff --git a/server/utils/agents/aibitat/providers/ollama.js b/server/utils/agents/aibitat/providers/ollama.js
new file mode 100644
index 0000000000000000000000000000000000000000..d52d80caa08a887cd9f3a89633737c2ee437a2e9
--- /dev/null
+++ b/server/utils/agents/aibitat/providers/ollama.js
@@ -0,0 +1,107 @@
+const Provider = require("./ai-provider.js");
+const InheritMultiple = require("./helpers/classes.js");
+const UnTooled = require("./helpers/untooled.js");
+const { Ollama } = require("ollama");
+
+/**
+ * The agent provider for Ollama.
+ */
+class OllamaProvider extends InheritMultiple([Provider, UnTooled]) {
+  model;
+
+  constructor(config = {}) {
+    const {
+      // options = {},
+      model = null,
+    } = config;
+
+    super();
+    this._client = new Ollama({ host: process.env.OLLAMA_BASE_PATH });
+    this.model = model;
+    this.verbose = true;
+  }
+
+  get client() {
+    return this._client;
+  }
+
+  async #handleFunctionCallChat({ messages = [] }) {
+    const response = await this.client.chat({
+      model: this.model,
+      messages,
+      options: {
+        temperature: 0,
+      },
+    });
+    return response?.message?.content || null;
+  }
+
+  /**
+   * Create a completion based on the received messages.
+   *
+   * @param messages A list of messages to send to the API.
+   * @param functions A list of functions the model may select from.
+   * @returns The completion.
+   */
+  async complete(messages, functions = null) {
+    try {
+      let completion;
+      if (functions?.length > 0) {
+        const { toolCall, text } = await this.functionCall(
+          messages,
+          functions,
+          this.#handleFunctionCallChat.bind(this)
+        );
+
+        if (toolCall !== null) {
+          this.providerLog(`Valid tool call found - running ${toolCall.name}.`);
+          this.deduplicator.trackRun(toolCall.name, toolCall.arguments);
+          return {
+            result: null,
+            functionCall: {
+              name: toolCall.name,
+              arguments: toolCall.arguments,
+            },
+            cost: 0,
+          };
+        }
+        completion = { content: text };
+      }
+
+      if (!completion?.content) {
+        this.providerLog(
+          "Will assume chat completion without tool call inputs."
+        );
+        const response = await this.client.chat({
+          model: this.model,
+          messages: this.cleanMsgs(messages),
+          options: {
+            use_mlock: true,
+            temperature: 0.5,
+          },
+        });
+        completion = response.message;
+      }
+
+      return {
+        result: completion.content,
+        cost: 0,
+      };
+    } catch (error) {
+      throw error;
+    }
+  }
+
+  /**
+   * Get the cost of the completion.
+   *
+   * @param _usage The completion to get the cost for.
+   * @returns The cost of the completion.
+   * Stubbed since Ollama has no cost basis.
+   */
+  getCost(_usage) {
+    return 0;
+  }
+}
+
+module.exports = OllamaProvider;
diff --git a/server/utils/agents/index.js b/server/utils/agents/index.js
index 5e54c0b3f67af47d8a5e7bd3fe750dad2a07025b..e18b8b7bb945187f370c37afcca6953cb5ebd350 100644
--- a/server/utils/agents/index.js
+++ b/server/utils/agents/index.js
@@ -79,7 +79,11 @@ class AgentHandler {
         break;
       case "lmstudio":
         if (!process.env.LMSTUDIO_BASE_PATH)
-          throw new Error("LMStudio bash path must be provided to use agents.");
+          throw new Error("LMStudio base path must be provided to use agents.");
+        break;
+      case "ollama":
+        if (!process.env.OLLAMA_BASE_PATH)
+          throw new Error("Ollama base path must be provided to use agents.");
         break;
       default:
         throw new Error("No provider found to power agent cluster.");
@@ -94,6 +98,8 @@ class AgentHandler {
         return "claude-3-sonnet-20240229";
       case "lmstudio":
         return "server-default";
+      case "ollama":
+        return "llama3:latest";
       default:
         return "unknown";
     }
diff --git a/server/utils/http/index.js b/server/utils/http/index.js
index 1fc8c5b961cd908dde6a253285fbe98285fcfb6c..6400c36bcf72725fc3cd49d4a2b5682e7e4f0f7e 100644
--- a/server/utils/http/index.js
+++ b/server/utils/http/index.js
@@ -4,6 +4,7 @@ process.env.NODE_ENV === "development"
 const JWT = require("jsonwebtoken");
 const { User } = require("../../models/user");
 const { jsonrepair } = require("jsonrepair");
+const extract = require("extract-json-from-string");

 function reqBody(request) {
   return typeof request.body === "string"
@@ -67,8 +68,6 @@ function safeJsonParse(jsonString, fallback = null) {
     return JSON.parse(jsonString);
   } catch {}

-  // If the jsonString does not look like an Obj or Array, dont attempt
-  // to repair it.
   if (jsonString?.startsWith("[") || jsonString?.startsWith("{")) {
     try {
       const repairedJson = jsonrepair(jsonString);
@@ -76,6 +75,10 @@ function safeJsonParse(jsonString, fallback = null) {
     } catch {}
   }

+  try {
+    return extract(jsonString)[0] ?? fallback;
+  } catch {}
+
   return fallback;
 }
diff --git a/server/yarn.lock b/server/yarn.lock
index 1911849d453a6533e1ac44d060791d0189865daf..5edd09a351fbc0b1aebfa9c35db1d60574c4205b 100644
--- a/server/yarn.lock
+++ b/server/yarn.lock
@@ -2678,6 +2678,11 @@ extract-files@^9.0.0:
   resolved "https://registry.yarnpkg.com/extract-files/-/extract-files-9.0.0.tgz#8a7744f2437f81f5ed3250ed9f1550de902fe54a"
   integrity sha512-CvdFfHkC95B4bBBk36hcEmvdR2awOdhhVUYH6S/zrVj3477zven/fJMYg7121h4T1xHZC+tetUpubpAhxwI7hQ==

+extract-json-from-string@^1.0.1:
+  version "1.0.1"
+  resolved "https://registry.yarnpkg.com/extract-json-from-string/-/extract-json-from-string-1.0.1.tgz#5001f17e6c905826dcd5989564e130959de60c96"
+  integrity sha512-xfQOSFYbELVs9QVkKsV9FZAjlAmXQ2SLR6FpfFX1kpn4QAvaGBJlrnVOblMLwrLPYc26H+q9qxo6JTd4E7AwgQ==
+
 extract-zip@^2.0.1:
   version "2.0.1"
   resolved "https://registry.yarnpkg.com/extract-zip/-/extract-zip-2.0.1.tgz#663dca56fe46df890d5f131ef4a06d22bb8ba13a"
@@ -4560,6 +4565,13 @@ octokit@^3.1.0:
     "@octokit/request-error" "^5.0.0"
     "@octokit/types" "^12.0.0"

+ollama@^0.5.0:
+  version "0.5.0"
+  resolved "https://registry.yarnpkg.com/ollama/-/ollama-0.5.0.tgz#cb9bc709d4d3278c9f484f751b0d9b98b06f4859"
+  integrity sha512-CRtRzsho210EGdK52GrUMohA2pU+7NbgEaBG3DcYeRmvQthDO7E2LHOkLlUUeaYUlNmEd8icbjC02ug9meSYnw==
+  dependencies:
+    whatwg-fetch "^3.6.20"
+
 on-finished@2.4.1:
   version "2.4.1"
   resolved "https://registry.yarnpkg.com/on-finished/-/on-finished-2.4.1.tgz#58c8c44116e54845ad57f14ab10b03533184ac3f"
@@ -5980,7 +5992,7 @@ webidl-conversions@^3.0.0:
   resolved "https://registry.yarnpkg.com/webidl-conversions/-/webidl-conversions-3.0.1.tgz#24534275e2a7bc6be7bc86611cc16ae0a5654871"
   integrity sha512-2JAn3z8AR6rjK8Sm8orRC0h/bcl/DqL7tRPdGZ4I1CjdF+EaMLmYxBHyXuKL849eucPFhvBoxMsflfOb8kxaeQ==

-whatwg-fetch@^3.4.1:
+whatwg-fetch@^3.4.1, whatwg-fetch@^3.6.20:
   version "3.6.20"
   resolved "https://registry.yarnpkg.com/whatwg-fetch/-/whatwg-fetch-3.6.20.tgz#580ce6d791facec91d37c72890995a0b48d31c70"
   integrity sha512-EqhiFU6daOA8kpjOWTL0olhVOF3i7OrFzSYiGsEMB8GcXS+RrzauAERX65xMeNWVqxA6HXH2m69Z9LaKKdisfg==
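For local verification, a minimal smoke test of the new provider wiring could look like the sketch below. This is not part of the patch: the require path, the OLLAMA_BASE_PATH value, and the llama3:latest model tag are assumptions for illustration, chosen to match the defaults the patch introduces (OLLAMA_BASE_PATH is read in the OllamaProvider constructor, and llama3:latest is the fallback model in AgentHandler).

// Hypothetical smoke test for the new OllamaProvider (not part of this patch).
// Assumes a local Ollama server on its default port and a pulled llama3 model.
process.env.OLLAMA_BASE_PATH = "http://127.0.0.1:11434"; // assumed local default
const { OllamaProvider } = require("./server/utils/agents/aibitat/providers");

(async () => {
  const provider = new OllamaProvider({ model: "llama3:latest" });
  // Passing an empty functions array skips the tool-call branch and takes
  // the plain chat-completion path through client.chat().
  const { result, cost } = await provider.complete(
    [{ role: "user", content: "Reply with the single word: ready" }],
    []
  );
  console.log(result); // model output
  console.log(cost); // always 0; getCost() is stubbed for Ollama
})();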