From 2b8aaa835dd1dec68b2dee38d5f16f75dd984eba Mon Sep 17 00:00:00 2001
From: Marcus Schiesser <mail@marcusschiesser.de>
Date: Mon, 4 Nov 2024 16:39:27 +0700
Subject: [PATCH] Add support for local models via Hugging Face (#414)

---
 .changeset/plenty-pumpkins-fold.md            |  5 ++
 helpers/env-variables.ts                      | 14 +++++
 helpers/providers/huggingface.ts              | 61 +++++++++++++++++++
 helpers/providers/index.ts                    |  5 ++
 helpers/python.ts                             | 15 +++++
 helpers/types.ts                              |  1 +
 .../components/settings/python/settings.py    | 38 ++++++++++++
 7 files changed, 139 insertions(+)
 create mode 100644 .changeset/plenty-pumpkins-fold.md
 create mode 100644 helpers/providers/huggingface.ts

diff --git a/.changeset/plenty-pumpkins-fold.md b/.changeset/plenty-pumpkins-fold.md
new file mode 100644
index 00000000..3d18a1a8
--- /dev/null
+++ b/.changeset/plenty-pumpkins-fold.md
@@ -0,0 +1,5 @@
+---
+"create-llama": patch
+---
+
+Add support for local models via Hugging Face
diff --git a/helpers/env-variables.ts b/helpers/env-variables.ts
index f1c23de2..07ae88e0 100644
--- a/helpers/env-variables.ts
+++ b/helpers/env-variables.ts
@@ -336,6 +336,20 @@ const getModelEnvs = (modelConfig: ModelConfig): EnvVar[] => {
           },
         ]
       : []),
+    ...(modelConfig.provider === "huggingface"
+      ? [
+          {
+            name: "EMBEDDING_BACKEND",
+            description:
+              "The backend to use for the Sentence Transformers embedding model, either 'torch', 'onnx', or 'openvino'. Defaults to 'onnx'.",
+          },
+          {
+            name: "EMBEDDING_TRUST_REMOTE_CODE",
+            description:
+              "Whether to trust remote code for the embedding model, required for some models with custom code.",
+          },
+        ]
+      : []),
     ...(modelConfig.provider === "t-systems"
       ? [
           {
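[Note: for context, these entries surface in the generated project's .env. Assuming the helper renders each variable's description as a comment above it, as create-llama's .env generator typically does, and with values filled in here for illustration (the documented defaults), the output looks roughly like:

# The backend to use for the Sentence Transformers embedding model, either 'torch', 'onnx', or 'openvino'. Defaults to 'onnx'.
EMBEDDING_BACKEND=onnx
# Whether to trust remote code when loading the embedding model; required for some models that define custom code.
EMBEDDING_TRUST_REMOTE_CODE=false
]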
diff --git a/helpers/providers/huggingface.ts b/helpers/providers/huggingface.ts
new file mode 100644
index 00000000..1a3a4a06
--- /dev/null
+++ b/helpers/providers/huggingface.ts
@@ -0,0 +1,61 @@
+import prompts from "prompts";
+import { ModelConfigParams } from ".";
+import { questionHandlers, toChoice } from "../../questions/utils";
+
+const MODELS = ["HuggingFaceH4/zephyr-7b-alpha"];
+type ModelData = {
+  dimensions: number;
+};
+const EMBEDDING_MODELS: Record<string, ModelData> = {
+  "BAAI/bge-small-en-v1.5": { dimensions: 384 },
+};
+
+const DEFAULT_MODEL = MODELS[0];
+const DEFAULT_EMBEDDING_MODEL = Object.keys(EMBEDDING_MODELS)[0];
+const DEFAULT_DIMENSIONS = Object.values(EMBEDDING_MODELS)[0].dimensions;
+
+type HuggingfaceQuestionsParams = {
+  askModels: boolean;
+};
+
+export async function askHuggingfaceQuestions({
+  askModels,
+}: HuggingfaceQuestionsParams): Promise<ModelConfigParams> {
+  const config: ModelConfigParams = {
+    model: DEFAULT_MODEL,
+    embeddingModel: DEFAULT_EMBEDDING_MODEL,
+    dimensions: DEFAULT_DIMENSIONS,
+    isConfigured(): boolean {
+      return true;
+    },
+  };
+
+  if (askModels) {
+    const { model } = await prompts(
+      {
+        type: "select",
+        name: "model",
+        message: "Which Hugging Face model would you like to use?",
+        choices: MODELS.map(toChoice),
+        initial: 0,
+      },
+      questionHandlers,
+    );
+    config.model = model;
+
+    const { embeddingModel } = await prompts(
+      {
+        type: "select",
+        name: "embeddingModel",
+        message: "Which embedding model would you like to use?",
+        choices: Object.keys(EMBEDDING_MODELS).map(toChoice),
+        initial: 0,
+      },
+      questionHandlers,
+    );
+    config.embeddingModel = embeddingModel;
+    config.dimensions = EMBEDDING_MODELS[embeddingModel].dimensions;
+  }
+
+  return config;
+}
diff --git a/helpers/providers/index.ts b/helpers/providers/index.ts
index 06977f6f..a7530298 100644
--- a/helpers/providers/index.ts
+++ b/helpers/providers/index.ts
@@ -5,6 +5,7 @@ import { askAnthropicQuestions } from "./anthropic";
 import { askAzureQuestions } from "./azure";
 import { askGeminiQuestions } from "./gemini";
 import { askGroqQuestions } from "./groq";
+import { askHuggingfaceQuestions } from "./huggingface";
 import { askLLMHubQuestions } from "./llmhub";
 import { askMistralQuestions } from "./mistral";
 import { askOllamaQuestions } from "./ollama";
@@ -39,6 +40,7 @@ export async function askModelConfig({
 
     if (framework === "fastapi") {
       choices.push({ title: "T-Systems", value: "t-systems" });
+      choices.push({ title: "Hugging Face", value: "huggingface" });
     }
     const { provider } = await prompts(
       {
@@ -76,6 +78,9 @@ export async function askModelConfig({
     case "t-systems":
       modelConfig = await askLLMHubQuestions({ askModels });
       break;
+    case "huggingface":
+      modelConfig = await askHuggingfaceQuestions({ askModels });
+      break;
     default:
       modelConfig = await askOpenAIQuestions({
         openAiKey,
diff --git a/helpers/python.ts b/helpers/python.ts
index 6305739a..9dd686d8 100644
--- a/helpers/python.ts
+++ b/helpers/python.ts
@@ -234,6 +234,21 @@ const getAdditionalDependencies = (
         version: "0.2.4",
       });
       break;
+    case "huggingface":
+      dependencies.push({
+        name: "llama-index-llms-huggingface",
+        version: "^0.3.5",
+      });
+      dependencies.push({
+        name: "llama-index-embeddings-huggingface",
+        version: "^0.3.1",
+      });
+      dependencies.push({
+        name: "optimum",
+        version: "^1.23.3",
+        extras: ["onnxruntime"],
+      });
+      break;
     case "t-systems":
       dependencies.push({
         name: "llama-index-agent-openai",
diff --git a/helpers/types.ts b/helpers/types.ts
index cef8ce3b..bcaf5b06 100644
--- a/helpers/types.ts
+++ b/helpers/types.ts
@@ -9,6 +9,7 @@ export type ModelProvider =
   | "gemini"
   | "mistral"
   | "azure-openai"
+  | "huggingface"
   | "t-systems";
 export type ModelConfig = {
   provider: ModelProvider;
diff --git a/templates/components/settings/python/settings.py b/templates/components/settings/python/settings.py
index 681974ce..bc7270bd 100644
--- a/templates/components/settings/python/settings.py
+++ b/templates/components/settings/python/settings.py
@@ -21,6 +21,8 @@ def init_settings():
             init_mistral()
         case "azure-openai":
             init_azure_openai()
+        case "huggingface":
+            init_huggingface()
         case "t-systems":
             from .llmhub import init_llmhub
 
@@ -138,6 +140,42 @@ def init_fastembed():
     )
 
 
+def init_huggingface_embedding():
+    try:
+        from llama_index.embeddings.huggingface import HuggingFaceEmbedding
+    except ImportError:
+        raise ImportError(
+            "Hugging Face support is not installed. Please install it with `poetry add llama-index-embeddings-huggingface`"
+        )
+
+    embedding_model = os.getenv("EMBEDDING_MODEL", "all-MiniLM-L6-v2")
+    backend = os.getenv("EMBEDDING_BACKEND", "onnx")  # "torch", "onnx", or "openvino"
+    trust_remote_code = (
+        os.getenv("EMBEDDING_TRUST_REMOTE_CODE", "false").lower() == "true"
+    )
+
+    Settings.embed_model = HuggingFaceEmbedding(
+        model_name=embedding_model,
+        trust_remote_code=trust_remote_code,
+        backend=backend,
+    )
+
+
+def init_huggingface():
+    try:
+        from llama_index.llms.huggingface import HuggingFaceLLM
+    except ImportError:
+        raise ImportError(
+            "Hugging Face support is not installed. Please install it with `poetry add llama-index-llms-huggingface` and `poetry add llama-index-embeddings-huggingface`"
+        )
+
+    Settings.llm = HuggingFaceLLM(
+        model_name=os.getenv("MODEL"),
+        tokenizer_name=os.getenv("MODEL"),
+    )
+    init_huggingface_embedding()
+
+
 def init_groq():
     try:
         from llama_index.llms.groq import Groq
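[Note: taken together, init_huggingface() and init_huggingface_embedding() can be exercised end-to-end inside a generated project with something like the sketch below. This is hedged, not part of the patch: the app.settings import path is assumed from the template layout, and MODEL_PROVIDER is assumed to be the selector that init_settings() matches on. Model values are the defaults chosen by the prompt flow.

import os

# Environment as create-llama would generate it for this provider.
os.environ["MODEL_PROVIDER"] = "huggingface"
os.environ["MODEL"] = "HuggingFaceH4/zephyr-7b-alpha"
os.environ["EMBEDDING_MODEL"] = "BAAI/bge-small-en-v1.5"

from app.settings import init_settings  # module path assumed from the template layout
from llama_index.core import Settings

# Note: constructing HuggingFaceLLM downloads the LLM weights on first run.
init_settings()
vector = Settings.embed_model.get_text_embedding("hello world")
print(len(vector))  # 384, matching the dimensions recorded for bge-small-en-v1.5
]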
-- 
GitLab