From 9ace0e67e68aa5dbe9c29c2fc66d981de18469f6 Mon Sep 17 00:00:00 2001
From: Timothy Carambat <rambat1010@gmail.com>
Date: Fri, 17 May 2024 21:44:55 -0700
Subject: [PATCH] Validate max_tokens is number (#1445)

---
 server/utils/AiProviders/genericOpenAi/index.js        | 5 ++++-
 server/utils/agents/aibitat/plugins/summarize.js       | 1 -
 server/utils/agents/aibitat/providers/genericOpenAi.js | 5 ++++-
 server/utils/http/index.js                             | 6 ++++++
 4 files changed, 14 insertions(+), 3 deletions(-)

diff --git a/server/utils/AiProviders/genericOpenAi/index.js b/server/utils/AiProviders/genericOpenAi/index.js
index dc0264e48..46b8aefba 100644
--- a/server/utils/AiProviders/genericOpenAi/index.js
+++ b/server/utils/AiProviders/genericOpenAi/index.js
@@ -2,6 +2,7 @@ const { NativeEmbedder } = require("../../EmbeddingEngines/native");
 const {
   handleDefaultStreamResponseV2,
 } = require("../../helpers/chat/responses");
+const { toValidNumber } = require("../../http");
 
 class GenericOpenAiLLM {
   constructor(embedder = null, modelPreference = null) {
@@ -18,7 +19,9 @@ class GenericOpenAiLLM {
     });
     this.model =
       modelPreference ?? process.env.GENERIC_OPEN_AI_MODEL_PREF ?? null;
-    this.maxTokens = process.env.GENERIC_OPEN_AI_MAX_TOKENS ?? 1024;
+    this.maxTokens = process.env.GENERIC_OPEN_AI_MAX_TOKENS
+      ? toValidNumber(process.env.GENERIC_OPEN_AI_MAX_TOKENS, 1024)
+      : 1024;
     if (!this.model)
       throw new Error("GenericOpenAI must have a valid model set.");
     this.limits = {
diff --git a/server/utils/agents/aibitat/plugins/summarize.js b/server/utils/agents/aibitat/plugins/summarize.js
index 526de116a..de1657c9f 100644
--- a/server/utils/agents/aibitat/plugins/summarize.js
+++ b/server/utils/agents/aibitat/plugins/summarize.js
@@ -1,6 +1,5 @@
 const { Document } = require("../../../../models/documents");
 const { safeJsonParse } = require("../../../http");
-const { validate } = require("uuid");
 const { summarizeContent } = require("../utils/summarize");
 const Provider = require("../providers/ai-provider");
 
diff --git a/server/utils/agents/aibitat/providers/genericOpenAi.js b/server/utils/agents/aibitat/providers/genericOpenAi.js
index a1b2db3ea..9a753ca27 100644
--- a/server/utils/agents/aibitat/providers/genericOpenAi.js
+++ b/server/utils/agents/aibitat/providers/genericOpenAi.js
@@ -2,6 +2,7 @@ const OpenAI = require("openai");
 const Provider = require("./ai-provider.js");
 const InheritMultiple = require("./helpers/classes.js");
 const UnTooled = require("./helpers/untooled.js");
+const { toValidNumber } = require("../../../http/index.js");
 
 /**
  * The agent provider for the Generic OpenAI provider.
@@ -24,7 +25,9 @@ class GenericOpenAiProvider extends InheritMultiple([Provider, UnTooled]) {
     this._client = client;
     this.model = model;
     this.verbose = true;
-    this.maxTokens = process.env.GENERIC_OPEN_AI_MAX_TOKENS ?? 1024;
+    this.maxTokens = process.env.GENERIC_OPEN_AI_MAX_TOKENS
+      ? toValidNumber(process.env.GENERIC_OPEN_AI_MAX_TOKENS, 1024)
+      : 1024;
   }
 
   get client() {
diff --git a/server/utils/http/index.js b/server/utils/http/index.js
index 6400c36bc..e812b8abd 100644
--- a/server/utils/http/index.js
+++ b/server/utils/http/index.js
@@ -91,6 +91,11 @@ function isValidUrl(urlString = "") {
   return false;
 }
 
+function toValidNumber(number = null, fallback = null) {
+  if (isNaN(Number(number))) return fallback;
+  return Number(number);
+}
+
 module.exports = {
   reqBody,
   multiUserMode,
@@ -101,4 +106,5 @@ module.exports = {
   parseAuthHeader,
   safeJsonParse,
   isValidUrl,
+  toValidNumber,
 };
-- 
GitLab