From 8422f9254278f6430a350b808b877d42e6061ec0 Mon Sep 17 00:00:00 2001
From: Sean Hatfield <seanhatfield5@gmail.com>
Date: Wed, 8 May 2024 15:17:54 -0700
Subject: [PATCH] Agent support for LLMs with no function calling (#1295)

* add LMStudio agent support (generic)
"works" with non-tool-calling LLMs, though results are highly dependent on system specs

* add comments

* enable few-shot prompting per function for OSS models (see the sketch below)

* Add Agent support for Ollama models

* azure, groq, koboldcpp agent support complete + WIP togetherai

* WIP gemini agent support

* WIP gemini blocked and will not fix for now

* azure fix

* merge fix

* add localai agent support

* azure untooled agent support

* merge fix

* refactor implementation of several agent providers

* update bad merge comment

---------
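
The core of the few-shot prompting mentioned above lives in server/utils/agents/aibitat/providers/helpers/untooled.js. A minimal sketch of that pattern, assuming an OpenAI-compatible local endpoint, looks roughly like this; the function schema, model name, and plain JSON.parse are illustrative stand-ins (the helper uses its own prompt template and safeJsonParse):

const OpenAI = require("openai");

// Illustrative function definition - real definitions come from the agent's registered plugins.
const functions = [
  {
    name: "web-search",
    description: "Search the web and return relevant results.",
    parameters: { type: "object", properties: { query: { type: "string" } } },
  },
];

async function untooledFunctionCall(messages) {
  const client = new OpenAI({
    baseURL: process.env.LMSTUDIO_BASE_PATH?.replace(/\/+$/, ""),
    apiKey: null, // local endpoints need no key
  });

  // Describe every available function in the system prompt so a model without
  // native tool-calling can still "pick" one by replying with JSON.
  const schema = functions
    .map(
      (def) =>
        `${def.name} - ${def.description}\n${JSON.stringify(def.parameters.properties, null, 4)}`
    )
    .join("\n");

  const response = await client.chat.completions.create({
    model: "server-default", // stand-in model name
    temperature: 0,
    messages: [
      {
        role: "system",
        content: `You are a program which picks the most optimal function and parameters to call.\nWhen a function is selected, respond in JSON with no additional text.\nWhen there is no relevant function, reply with a regular chat response.\n\nAvailable functions:\n${schema}`,
      },
      ...messages,
    ],
  });

  const text = response.choices?.[0]?.message?.content ?? null;
  if (!text) return { toolCall: null, text: null };
  try {
    const call = JSON.parse(text); // expected shape: { "name": "...", "arguments": { ... } }
    if (call?.name) return { toolCall: call, text: null };
  } catch {}
  // Not valid JSON (or no function picked) - treat the reply as a normal chat completion.
  return { toolCall: null, text };
}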

Co-authored-by: timothycarambat <rambat1010@gmail.com>
---
 .vscode/settings.json                         |   1 +
 .../AgentConfig/AgentLLMSelection/index.jsx   |  21 +++-
 server/utils/agents/aibitat/index.js          |  14 ++-
 .../agents/aibitat/providers/ai-provider.js   |   3 +
 .../utils/agents/aibitat/providers/azure.js   | 105 ++++++++++++++++
 server/utils/agents/aibitat/providers/groq.js | 110 +++++++++++++++++
 .../aibitat/providers/helpers/untooled.js     |   3 +-
 .../utils/agents/aibitat/providers/index.js   |  10 ++
 .../agents/aibitat/providers/koboldcpp.js     | 113 +++++++++++++++++
 .../agents/aibitat/providers/lmstudio.js      |   2 +-
 .../utils/agents/aibitat/providers/localai.js | 114 ++++++++++++++++++
 .../agents/aibitat/providers/togetherai.js    | 113 +++++++++++++++++
 server/utils/agents/index.js                  |  42 +++++++
 server/utils/helpers/customModels.js          |   2 +-
 14 files changed, 645 insertions(+), 8 deletions(-)
 create mode 100644 server/utils/agents/aibitat/providers/azure.js
 create mode 100644 server/utils/agents/aibitat/providers/groq.js
 create mode 100644 server/utils/agents/aibitat/providers/koboldcpp.js
 create mode 100644 server/utils/agents/aibitat/providers/localai.js
 create mode 100644 server/utils/agents/aibitat/providers/togetherai.js

diff --git a/.vscode/settings.json b/.vscode/settings.json
index f850bbb00..eecaa83fd 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -28,6 +28,7 @@
     "openrouter",
     "Qdrant",
     "Serper",
+    "togetherai",
     "vectordbs",
     "Weaviate",
     "Zilliz"
diff --git a/frontend/src/pages/WorkspaceSettings/AgentConfig/AgentLLMSelection/index.jsx b/frontend/src/pages/WorkspaceSettings/AgentConfig/AgentLLMSelection/index.jsx
index fcb12d94d..400eef02d 100644
--- a/frontend/src/pages/WorkspaceSettings/AgentConfig/AgentLLMSelection/index.jsx
+++ b/frontend/src/pages/WorkspaceSettings/AgentConfig/AgentLLMSelection/index.jsx
@@ -5,8 +5,25 @@ import { AVAILABLE_LLM_PROVIDERS } from "@/pages/GeneralSettings/LLMPreference";
 import { CaretUpDown, Gauge, MagnifyingGlass, X } from "@phosphor-icons/react";
 import AgentModelSelection from "../AgentModelSelection";
 
-const ENABLED_PROVIDERS = ["openai", "anthropic", "lmstudio", "ollama"];
-const WARN_PERFORMANCE = ["lmstudio", "ollama"];
+const ENABLED_PROVIDERS = [
+  "openai",
+  "anthropic",
+  "lmstudio",
+  "ollama",
+  "localai",
+  "groq",
+  "azure",
+  "koboldcpp",
+  "togetherai",
+];
+const WARN_PERFORMANCE = [
+  "lmstudio",
+  "groq",
+  "azure",
+  "koboldcpp",
+  "ollama",
+  "localai",
+];
 
 const LLM_DEFAULT = {
   name: "Please make a selection",
diff --git a/server/utils/agents/aibitat/index.js b/server/utils/agents/aibitat/index.js
index 9cf2170b7..3413bd359 100644
--- a/server/utils/agents/aibitat/index.js
+++ b/server/utils/agents/aibitat/index.js
@@ -480,7 +480,7 @@ Read the following conversation.
 CHAT HISTORY
 ${history.map((c) => `@${c.from}: ${c.content}`).join("\n")}
 
-Then select the next role from that is going to speak next. 
+Then select the next role that is going to speak next.
 Only return the role.
 `,
       },
@@ -522,7 +522,7 @@ Only return the role.
         ? [
             {
               role: "user",
-              content: `You are in a whatsapp group. Read the following conversation and then reply. 
+              content: `You are in a whatsapp group. Read the following conversation and then reply.
 Do not add introduction or conclusion to your reply because this will be a continuous conversation. Don't introduce yourself.
 
 CHAT HISTORY
@@ -743,6 +743,16 @@ ${this.getHistory({ to: route.to })
         return new Providers.LMStudioProvider({});
       case "ollama":
         return new Providers.OllamaProvider({ model: config.model });
+      case "groq":
+        return new Providers.GroqProvider({ model: config.model });
+      case "togetherai":
+        return new Providers.TogetherAIProvider({ model: config.model });
+      case "azure":
+        return new Providers.AzureOpenAiProvider({ model: config.model });
+      case "koboldcpp":
+        return new Providers.KoboldCPPProvider({});
+      case "localai":
+        return new Providers.LocalAIProvider({ model: config.model });
 
       default:
         throw new Error(
diff --git a/server/utils/agents/aibitat/providers/ai-provider.js b/server/utils/agents/aibitat/providers/ai-provider.js
index 0e871b36e..91a81ebfa 100644
--- a/server/utils/agents/aibitat/providers/ai-provider.js
+++ b/server/utils/agents/aibitat/providers/ai-provider.js
@@ -58,6 +58,9 @@ class Provider {
     }
   }
 
+  // For some providers we may want to override the system prompt to be more verbose.
+  // Currently we only do this for lmstudio, but we will probably want to expand this
+  // to any Untooled LLM.
   static systemPrompt(provider = null) {
     switch (provider) {
       case "lmstudio":
diff --git a/server/utils/agents/aibitat/providers/azure.js b/server/utils/agents/aibitat/providers/azure.js
new file mode 100644
index 000000000..cdcf7618b
--- /dev/null
+++ b/server/utils/agents/aibitat/providers/azure.js
@@ -0,0 +1,105 @@
+const { OpenAIClient, AzureKeyCredential } = require("@azure/openai");
+const Provider = require("./ai-provider.js");
+const InheritMultiple = require("./helpers/classes.js");
+const UnTooled = require("./helpers/untooled.js");
+
+/**
+ * The provider for the Azure OpenAI API.
+ */
+class AzureOpenAiProvider extends InheritMultiple([Provider, UnTooled]) {
+  model;
+
+  constructor(_config = {}) {
+    super();
+    const client = new OpenAIClient(
+      process.env.AZURE_OPENAI_ENDPOINT,
+      new AzureKeyCredential(process.env.AZURE_OPENAI_KEY)
+    );
+    this._client = client;
+    this.model = process.env.OPEN_MODEL_PREF ?? "gpt-3.5-turbo";
+    this.verbose = true;
+  }
+
+  get client() {
+    return this._client;
+  }
+
+  async #handleFunctionCallChat({ messages = [] }) {
+    return await this.client
+      .getChatCompletions(this.model, messages, {
+        temperature: 0,
+      })
+      .then((result) => {
+        if (!result.hasOwnProperty("choices"))
+          throw new Error("Azure OpenAI chat: No results!");
+        if (result.choices.length === 0)
+          throw new Error("Azure OpenAI chat: No results length!");
+        return result.choices[0].message.content;
+      })
+      .catch((_) => {
+        return null;
+      });
+  }
+
+  /**
+   * Create a completion based on the received messages.
+   *
+   * @param messages A list of messages to send to the API.
+   * @param functions
+   * @returns The completion.
+   */
+  async complete(messages, functions = null) {
+    try {
+      let completion;
+      if (functions.length > 0) {
+        const { toolCall, text } = await this.functionCall(
+          messages,
+          functions,
+          this.#handleFunctionCallChat.bind(this)
+        );
+        if (toolCall !== null) {
+          this.providerLog(`Valid tool call found - running ${toolCall.name}.`);
+          this.deduplicator.trackRun(toolCall.name, toolCall.arguments);
+          return {
+            result: null,
+            functionCall: {
+              name: toolCall.name,
+              arguments: toolCall.arguments,
+            },
+            cost: 0,
+          };
+        }
+        completion = { content: text };
+      }
+      if (!completion?.content) {
+        this.providerLog(
+          "Will assume chat completion without tool call inputs."
+        );
+        const response = await this.client.getChatCompletions(
+          this.model,
+          this.cleanMsgs(messages),
+          {
+            temperature: 0.7,
+          }
+        );
+        completion = response.choices[0].message;
+      }
+      return { result: completion.content, cost: 0 };
+    } catch (error) {
+      throw error;
+    }
+  }
+
+  /**
+   * Get the cost of the completion.
+   * Stubbed since Azure OpenAI has no public cost basis.
+   *
+   * @param _usage The completion to get the cost for.
+   * @returns The cost of the completion.
+   */
+  getCost(_usage) {
+    return 0;
+  }
+}
+
+module.exports = AzureOpenAiProvider;
diff --git a/server/utils/agents/aibitat/providers/groq.js b/server/utils/agents/aibitat/providers/groq.js
new file mode 100644
index 000000000..3b87ba510
--- /dev/null
+++ b/server/utils/agents/aibitat/providers/groq.js
@@ -0,0 +1,110 @@
+const OpenAI = require("openai");
+const Provider = require("./ai-provider.js");
+const { RetryError } = require("../error.js");
+
+/**
+ * The provider for the Groq API.
+ */
+class GroqProvider extends Provider {
+  model;
+
+  constructor(config = {}) {
+    const { model = "llama3-8b-8192" } = config;
+    const client = new OpenAI({
+      baseURL: "https://api.groq.com/openai/v1",
+      apiKey: process.env.GROQ_API_KEY,
+      maxRetries: 3,
+    });
+    super(client);
+    this.model = model;
+    this.verbose = true;
+  }
+
+  /**
+   * Create a completion based on the received messages.
+   *
+   * @param messages A list of messages to send to the API.
+   * @param functions
+   * @returns The completion.
+   */
+  async complete(messages, functions = null) {
+    try {
+      const response = await this.client.chat.completions.create({
+        model: this.model,
+        // stream: true,
+        messages,
+        ...(Array.isArray(functions) && functions?.length > 0
+          ? { functions }
+          : {}),
+      });
+
+      // Right now, we only support one completion,
+      // so we just take the first one in the list
+      const completion = response.choices[0].message;
+      const cost = this.getCost(response.usage);
+      // treat function calls
+      if (completion.function_call) {
+        let functionArgs = {};
+        try {
+          functionArgs = JSON.parse(completion.function_call.arguments);
+        } catch (error) {
+          // call the complete function again in case it gets a json error
+          return this.complete(
+            [
+              ...messages,
+              {
+                role: "function",
+                name: completion.function_call.name,
+                function_call: completion.function_call,
+                content: error?.message,
+              },
+            ],
+            functions
+          );
+        }
+
+        // console.log(completion, { functionArgs })
+        return {
+          result: null,
+          functionCall: {
+            name: completion.function_call.name,
+            arguments: functionArgs,
+          },
+          cost,
+        };
+      }
+
+      return {
+        result: completion.content,
+        cost,
+      };
+    } catch (error) {
+      // If invalid Auth error we need to abort because no amount of waiting
+      // will make auth better.
+      if (error instanceof OpenAI.AuthenticationError) throw error;
+
+      if (
+        error instanceof OpenAI.RateLimitError ||
+        error instanceof OpenAI.InternalServerError ||
+        error instanceof OpenAI.APIError // Also will catch AuthenticationError!!!
+      ) {
+        throw new RetryError(error.message);
+      }
+
+      throw error;
+    }
+  }
+
+  /**
+   * Get the cost of the completion.
+   *
+   * @param _usage The completion to get the cost for.
+   * @returns The cost of the completion.
+   * Stubbed since Groq has no cost basis.
+   */
+  getCost(_usage) {
+    return 0;
+  }
+}
+
+module.exports = GroqProvider;
diff --git a/server/utils/agents/aibitat/providers/helpers/untooled.js b/server/utils/agents/aibitat/providers/helpers/untooled.js
index 37ecb5599..11fbfec8b 100644
--- a/server/utils/agents/aibitat/providers/helpers/untooled.js
+++ b/server/utils/agents/aibitat/providers/helpers/untooled.js
@@ -110,7 +110,7 @@ ${JSON.stringify(def.parameters.properties, null, 4)}\n`;
     const response = await chatCb({
       messages: [
         {
-          content: `You are a program which picks the most optimal function and parameters to call. 
+          content: `You are a program which picks the most optimal function and parameters to call.
       DO NOT HAVE TO PICK A FUNCTION IF IT WILL NOT HELP ANSWER OR FULFILL THE USER'S QUERY.
       When a function is selection, respond in JSON with no additional text.
       When there is no relevant function to call - return with a regular chat text response.
@@ -130,7 +130,6 @@ ${JSON.stringify(def.parameters.properties, null, 4)}\n`;
         ...history,
       ],
     });
-
     const call = safeJsonParse(response, null);
     if (call === null) return { toolCall: null, text: response }; // failed to parse, so must be text.
 
diff --git a/server/utils/agents/aibitat/providers/index.js b/server/utils/agents/aibitat/providers/index.js
index fda8b5136..6f8a2da0b 100644
--- a/server/utils/agents/aibitat/providers/index.js
+++ b/server/utils/agents/aibitat/providers/index.js
@@ -2,10 +2,20 @@ const OpenAIProvider = require("./openai.js");
 const AnthropicProvider = require("./anthropic.js");
 const LMStudioProvider = require("./lmstudio.js");
 const OllamaProvider = require("./ollama.js");
+const GroqProvider = require("./groq.js");
+const TogetherAIProvider = require("./togetherai.js");
+const AzureOpenAiProvider = require("./azure.js");
+const KoboldCPPProvider = require("./koboldcpp.js");
+const LocalAIProvider = require("./localai.js");
 
 module.exports = {
   OpenAIProvider,
   AnthropicProvider,
   LMStudioProvider,
   OllamaProvider,
+  GroqProvider,
+  TogetherAIProvider,
+  AzureOpenAiProvider,
+  KoboldCPPProvider,
+  LocalAIProvider,
 };
diff --git a/server/utils/agents/aibitat/providers/koboldcpp.js b/server/utils/agents/aibitat/providers/koboldcpp.js
new file mode 100644
index 000000000..77088263c
--- /dev/null
+++ b/server/utils/agents/aibitat/providers/koboldcpp.js
@@ -0,0 +1,113 @@
+const OpenAI = require("openai");
+const Provider = require("./ai-provider.js");
+const InheritMultiple = require("./helpers/classes.js");
+const UnTooled = require("./helpers/untooled.js");
+
+/**
+ * The provider for the KoboldCPP API.
+ */
+class KoboldCPPProvider extends InheritMultiple([Provider, UnTooled]) {
+  model;
+
+  constructor(_config = {}) {
+    super();
+    const model = process.env.KOBOLD_CPP_MODEL_PREF ?? null;
+    const client = new OpenAI({
+      baseURL: process.env.KOBOLD_CPP_BASE_PATH?.replace(/\/+$/, ""),
+      apiKey: null,
+      maxRetries: 3,
+    });
+
+    this._client = client;
+    this.model = model;
+    this.verbose = true;
+  }
+
+  get client() {
+    return this._client;
+  }
+
+  async #handleFunctionCallChat({ messages = [] }) {
+    return await this.client.chat.completions
+      .create({
+        model: this.model,
+        temperature: 0,
+        messages,
+      })
+      .then((result) => {
+        if (!result.hasOwnProperty("choices"))
+          throw new Error("KoboldCPP chat: No results!");
+        if (result.choices.length === 0)
+          throw new Error("KoboldCPP chat: No results length!");
+        return result.choices[0].message.content;
+      })
+      .catch((_) => {
+        return null;
+      });
+  }
+
+  /**
+   * Create a completion based on the received messages.
+   *
+   * @param messages A list of messages to send to the API.
+   * @param functions
+   * @returns The completion.
+   */
+  async complete(messages, functions = null) {
+    try {
+      let completion;
+      if (functions.length > 0) {
+        const { toolCall, text } = await this.functionCall(
+          messages,
+          functions,
+          this.#handleFunctionCallChat.bind(this)
+        );
+
+        if (toolCall !== null) {
+          this.providerLog(`Valid tool call found - running ${toolCall.name}.`);
+          this.deduplicator.trackRun(toolCall.name, toolCall.arguments);
+          return {
+            result: null,
+            functionCall: {
+              name: toolCall.name,
+              arguments: toolCall.arguments,
+            },
+            cost: 0,
+          };
+        }
+        completion = { content: text };
+      }
+
+      if (!completion?.content) {
+        this.providerLog(
+          "Will assume chat completion without tool call inputs."
+        );
+        const response = await this.client.chat.completions.create({
+          model: this.model,
+          messages: this.cleanMsgs(messages),
+        });
+        completion = response.choices[0].message;
+      }
+
+      return {
+        result: completion.content,
+        cost: 0,
+      };
+    } catch (error) {
+      throw error;
+    }
+  }
+
+  /**
+   * Get the cost of the completion.
+   *
+   * @param _usage The completion to get the cost for.
+   * @returns The cost of the completion.
+   * Stubbed since KoboldCPP has no cost basis.
+   */
+  getCost(_usage) {
+    return 0;
+  }
+}
+
+module.exports = KoboldCPPProvider;
diff --git a/server/utils/agents/aibitat/providers/lmstudio.js b/server/utils/agents/aibitat/providers/lmstudio.js
index d3aa4346a..f5c4a2e82 100644
--- a/server/utils/agents/aibitat/providers/lmstudio.js
+++ b/server/utils/agents/aibitat/providers/lmstudio.js
@@ -16,8 +16,8 @@ class LMStudioProvider extends InheritMultiple([Provider, UnTooled]) {
       baseURL: process.env.LMSTUDIO_BASE_PATH?.replace(/\/+$/, ""), // here is the URL to your LMStudio instance
       apiKey: null,
       maxRetries: 3,
-      model,
     });
+
     this._client = client;
     this.model = model;
     this.verbose = true;
diff --git a/server/utils/agents/aibitat/providers/localai.js b/server/utils/agents/aibitat/providers/localai.js
new file mode 100644
index 000000000..161172c21
--- /dev/null
+++ b/server/utils/agents/aibitat/providers/localai.js
@@ -0,0 +1,114 @@
+const OpenAI = require("openai");
+const Provider = require("./ai-provider.js");
+const InheritMultiple = require("./helpers/classes.js");
+const UnTooled = require("./helpers/untooled.js");
+
+/**
+ * The provider for the LocalAI API.
+ */
+class LocalAiProvider extends InheritMultiple([Provider, UnTooled]) {
+  model;
+
+  constructor(config = {}) {
+    const { model = null } = config;
+    super();
+    const client = new OpenAI({
+      baseURL: process.env.LOCAL_AI_BASE_PATH,
+      apiKey: process.env.LOCAL_AI_API_KEY ?? null,
+      maxRetries: 3,
+    });
+
+    this._client = client;
+    this.model = model;
+    this.verbose = true;
+  }
+
+  get client() {
+    return this._client;
+  }
+
+  async #handleFunctionCallChat({ messages = [] }) {
+    return await this.client.chat.completions
+      .create({
+        model: this.model,
+        temperature: 0,
+        messages,
+      })
+      .then((result) => {
+        if (!result.hasOwnProperty("choices"))
+          throw new Error("LocalAI chat: No results!");
+
+        if (result.choices.length === 0)
+          throw new Error("LocalAI chat: No results length!");
+
+        return result.choices[0].message.content;
+      })
+      .catch((_) => {
+        return null;
+      });
+  }
+
+  /**
+   * Create a completion based on the received messages.
+   *
+   * @param messages A list of messages to send to the API.
+   * @param functions
+   * @returns The completion.
+   */
+  async complete(messages, functions = null) {
+    try {
+      let completion;
+
+      if (functions.length > 0) {
+        const { toolCall, text } = await this.functionCall(
+          messages,
+          functions,
+          this.#handleFunctionCallChat.bind(this)
+        );
+
+        if (toolCall !== null) {
+          this.providerLog(`Valid tool call found - running ${toolCall.name}.`);
+          this.deduplicator.trackRun(toolCall.name, toolCall.arguments);
+          return {
+            result: null,
+            functionCall: {
+              name: toolCall.name,
+              arguments: toolCall.arguments,
+            },
+            cost: 0,
+          };
+        }
+
+        completion = { content: text };
+      }
+
+      if (!completion?.content) {
+        this.providerLog(
+          "Will assume chat completion without tool call inputs."
+        );
+        const response = await this.client.chat.completions.create({
+          model: this.model,
+          messages: this.cleanMsgs(messages),
+        });
+        completion = response.choices[0].message;
+      }
+
+      return { result: completion.content, cost: 0 };
+    } catch (error) {
+      throw error;
+    }
+  }
+
+  /**
+   * Get the cost of the completion.
+   *
+   * @param _usage The completion to get the cost for.
+   * @returns The cost of the completion.
+   * Stubbed since LocalAI has no cost basis.
+   */
+  getCost(_usage) {
+    return 0;
+  }
+}
+
+module.exports = LocalAiProvider;
diff --git a/server/utils/agents/aibitat/providers/togetherai.js b/server/utils/agents/aibitat/providers/togetherai.js
new file mode 100644
index 000000000..4ea5e11c2
--- /dev/null
+++ b/server/utils/agents/aibitat/providers/togetherai.js
@@ -0,0 +1,113 @@
+const OpenAI = require("openai");
+const Provider = require("./ai-provider.js");
+const InheritMultiple = require("./helpers/classes.js");
+const UnTooled = require("./helpers/untooled.js");
+
+/**
+ * The provider for the TogetherAI API.
+ */
+class TogetherAIProvider extends InheritMultiple([Provider, UnTooled]) {
+  model;
+
+  constructor(config = {}) {
+    const { model = "mistralai/Mistral-7B-Instruct-v0.1" } = config;
+    super();
+    const client = new OpenAI({
+      baseURL: "https://api.together.xyz/v1",
+      apiKey: process.env.TOGETHER_AI_API_KEY,
+      maxRetries: 3,
+    });
+
+    this._client = client;
+    this.model = model;
+    this.verbose = true;
+  }
+
+  get client() {
+    return this._client;
+  }
+
+  async #handleFunctionCallChat({ messages = [] }) {
+    return await this.client.chat.completions
+      .create({
+        model: this.model,
+        temperature: 0,
+        messages,
+      })
+      .then((result) => {
+        if (!result.hasOwnProperty("choices"))
+          throw new Error("LMStudio chat: No results!");
+        if (result.choices.length === 0)
+          throw new Error("LMStudio chat: No results length!");
+        return result.choices[0].message.content;
+      })
+      .catch((_) => {
+        return null;
+      });
+  }
+
+  /**
+   * Create a completion based on the received messages.
+   *
+   * @param messages A list of messages to send to the API.
+   * @param functions
+   * @returns The completion.
+   */
+  async complete(messages, functions = null) {
+    try {
+      let completion;
+      if (functions.length > 0) {
+        const { toolCall, text } = await this.functionCall(
+          messages,
+          functions,
+          this.#handleFunctionCallChat.bind(this)
+        );
+
+        if (toolCall !== null) {
+          this.providerLog(`Valid tool call found - running ${toolCall.name}.`);
+          this.deduplicator.trackRun(toolCall.name, toolCall.arguments);
+          return {
+            result: null,
+            functionCall: {
+              name: toolCall.name,
+              arguments: toolCall.arguments,
+            },
+            cost: 0,
+          };
+        }
+        completion = { content: text };
+      }
+
+      if (!completion?.content) {
+        this.providerLog(
+          "Will assume chat completion without tool call inputs."
+        );
+        const response = await this.client.chat.completions.create({
+          model: this.model,
+          messages: this.cleanMsgs(messages),
+        });
+        completion = response.choices[0].message;
+      }
+
+      return {
+        result: completion.content,
+        cost: 0,
+      };
+    } catch (error) {
+      throw error;
+    }
+  }
+
+  /**
+   * Get the cost of the completion.
+   *
+   * @param _usage The completion to get the cost for.
+   * @returns The cost of the completion.
+   * Stubbed since TogetherAI has no cost basis.
+   */
+  getCost(_usage) {
+    return 0;
+  }
+}
+
+module.exports = TogetherAIProvider;
diff --git a/server/utils/agents/index.js b/server/utils/agents/index.js
index e18b8b7bb..768ad8199 100644
--- a/server/utils/agents/index.js
+++ b/server/utils/agents/index.js
@@ -85,6 +85,36 @@ class AgentHandler {
         if (!process.env.OLLAMA_BASE_PATH)
           throw new Error("Ollama base path must be provided to use agents.");
         break;
+      case "groq":
+        if (!process.env.GROQ_API_KEY)
+          throw new Error("Groq API key must be provided to use agents.");
+        break;
+      case "togetherai":
+        if (!process.env.TOGETHER_AI_API_KEY)
+          throw new Error("TogetherAI API key must be provided to use agents.");
+        break;
+      case "azure":
+        if (!process.env.AZURE_OPENAI_ENDPOINT || !process.env.AZURE_OPENAI_KEY)
+          throw new Error(
+            "Azure OpenAI API endpoint and key must be provided to use agents."
+          );
+        break;
+      case "koboldcpp":
+        if (!process.env.KOBOLD_CPP_BASE_PATH)
+          throw new Error(
+            "KoboldCPP must have a valid base path to use for the api."
+          );
+        break;
+      case "localai":
+        if (!process.env.LOCAL_AI_BASE_PATH)
+          throw new Error(
+            "LocalAI must have a valid base path to use for the api."
+          );
+        break;
+      case "gemini":
+        if (!process.env.GEMINI_API_KEY)
+          throw new Error("Gemini API key must be provided to use agents.");
+        break;
       default:
         throw new Error("No provider found to power agent cluster.");
     }
@@ -100,6 +130,18 @@ class AgentHandler {
         return "server-default";
       case "ollama":
         return "llama3:latest";
+      case "groq":
+        return "llama3-70b-8192";
+      case "togetherai":
+        return "mistralai/Mixtral-8x7B-Instruct-v0.1";
+      case "azure":
+        return "gpt-3.5-turbo";
+      case "koboldcpp":
+        return null;
+      case "gemini":
+        return "gemini-pro";
+      case "localai":
+        return null;
       default:
         return "unknown";
     }
diff --git a/server/utils/helpers/customModels.js b/server/utils/helpers/customModels.js
index ce690ae47..3743ffad7 100644
--- a/server/utils/helpers/customModels.js
+++ b/server/utils/helpers/customModels.js
@@ -178,7 +178,7 @@ async function getKoboldCPPModels(basePath = null) {
   try {
     const { OpenAI: OpenAIApi } = require("openai");
     const openai = new OpenAIApi({
-      baseURL: basePath || process.env.LMSTUDIO_BASE_PATH,
+      baseURL: basePath || process.env.KOBOLD_CPP_BASE_PATH,
       apiKey: null,
     });
     const models = await openai.models
-- 
GitLab