From e9b87ef09b4bbcb571a3adac4ce25b567625280b Mon Sep 17 00:00:00 2001
From: "Huu Le (Lee)" <39040748+leehuwuj@users.noreply.github.com>
Date: Wed, 31 Jan 2024 16:42:30 +0700
Subject: [PATCH] feat(create-llama) add folder selection & support more
 context data types (#489)

---
 .changeset/honest-dragons-cough.md            |   5 +
 packages/create-llama/helpers/index.ts        |  57 ++--
 packages/create-llama/helpers/python.ts       |  21 +-
 packages/create-llama/helpers/types.ts        |   6 +-
 packages/create-llama/index.ts                |   7 +
 packages/create-llama/questions.ts            | 258 ++++++++++++------
 .../types/simple/fastapi/pyproject.toml       |   2 +-
 .../types/streaming/fastapi/pyproject.toml    |   2 +-
 8 files changed, 234 insertions(+), 124 deletions(-)
 create mode 100644 .changeset/honest-dragons-cough.md

diff --git a/.changeset/honest-dragons-cough.md b/.changeset/honest-dragons-cough.md
new file mode 100644
index 000000000..6c18a116a
--- /dev/null
+++ b/.changeset/honest-dragons-cough.md
@@ -0,0 +1,5 @@
+---
+"create-llama": patch
+---
+
+Add folder selection and support more file types
diff --git a/packages/create-llama/helpers/index.ts b/packages/create-llama/helpers/index.ts
index 51d917f49..7415929d5 100644
--- a/packages/create-llama/helpers/index.ts
+++ b/packages/create-llama/helpers/index.ts
@@ -12,6 +12,7 @@ import { isHavingPoetryLockFile, tryPoetryRun } from "./poetry";
 import { installPythonTemplate } from "./python";
 import { downloadAndExtractRepo } from "./repo";
 import {
+  FileSourceConfig,
   InstallTemplateArgs,
   TemplateDataSource,
   TemplateFramework,
@@ -120,28 +121,41 @@ const installDependencies = async (
   }
 };
 
-const copyContextData = async (root: string, contextFile?: string) => {
+const copyContextData = async (
+  root: string,
+  dataSource?: TemplateDataSource,
+) => {
   const destPath = path.join(root, "data");
-  if (contextFile) {
-    console.log(`\nCopying provided file to ${cyan(destPath)}\n`);
-    await fs.mkdir(destPath, { recursive: true });
-    await fs.copyFile(
-      contextFile,
-      path.join(destPath, path.basename(contextFile)),
-    );
-  } else {
-    const srcPath = path.join(
-      __dirname,
-      "..",
-      "templates",
-      "components",
-      "data",
-    );
-    console.log(`\nCopying test data to ${cyan(destPath)}\n`);
+
+  let dataSourceConfig = dataSource?.config as FileSourceConfig;
+
+  // Copy file
+  if (dataSource?.type === "file") {
+    if (dataSourceConfig.path) {
+      console.log(`\nCopying file to ${cyan(destPath)}\n`);
+      await fs.mkdir(destPath, { recursive: true });
+      await fs.copyFile(
+        dataSourceConfig.path,
+        path.join(destPath, path.basename(dataSourceConfig.path)),
+      );
+    } else {
+      console.log("Missing file path in config");
+      process.exit(1);
+    }
+    return;
+  }
+
+  // Copy folder
+  if (dataSource?.type === "folder") {
+    let srcPath =
+      dataSourceConfig.path ??
+      path.join(__dirname, "..", "templates", "components", "data");
+    console.log(`\nCopying data to ${cyan(destPath)}\n`);
     await copy("**", destPath, {
       parents: true,
       cwd: srcPath,
     });
+    return;
   }
 };
 
@@ -192,14 +206,7 @@ export const installTemplate = async (
     });
 
     if (props.engine === "context") {
-      if (
-        props.dataSource?.type === "file" &&
-        "contextFile" in props.dataSource.config
-      ) {
-        await copyContextData(props.root, props.dataSource.config.contextFile);
-      } else {
-        await copyContextData(props.root);
-      }
+      await copyContextData(props.root, props.dataSource);
       await installDependencies(
         props.framework,
         props.packageManager,
diff --git a/packages/create-llama/helpers/python.ts b/packages/create-llama/helpers/python.ts
index ca5f344db..7fa006427 100644
--- a/packages/create-llama/helpers/python.ts
+++ b/packages/create-llama/helpers/python.ts
@@ -169,6 +169,8 @@ export const installPythonTemplate = async ({
 
   if (engine === "context") {
     const compPath = path.join(__dirname, "..", "templates", "components");
+
+    // Copy engine code
     let vectorDbDirName = vectorDb ?? "none";
     const VectorDBPath = path.join(
       compPath,
@@ -177,17 +179,22 @@ export const installPythonTemplate = async ({
       vectorDbDirName,
     );
     const enginePath = path.join(root, "app", "engine");
-
     await copy("**", path.join(root, "app", "engine"), {
       parents: true,
       cwd: VectorDBPath,
     });
-    let dataSourceDir = dataSource?.type ?? "file";
-    const loaderPath = path.join(compPath, "loaders", "python", dataSourceDir);
-    await copy("**", enginePath, {
-      parents: true,
-      cwd: loaderPath,
-    });
+
+    const dataSourceType = dataSource?.type;
+    if (dataSourceType !== undefined && dataSourceType !== "none") {
+      let loaderPath =
+        dataSourceType === "folder"
+          ? path.join(compPath, "loaders", "python", "file")
+          : path.join(compPath, "loaders", "python", dataSourceType);
+      await copy("**", enginePath, {
+        parents: true,
+        cwd: loaderPath,
+      });
+    }
   }
 
   const addOnDependencies = getAdditionalDependencies(vectorDb);
diff --git a/packages/create-llama/helpers/types.ts b/packages/create-llama/helpers/types.ts
index e26608609..191e028f4 100644
--- a/packages/create-llama/helpers/types.ts
+++ b/packages/create-llama/helpers/types.ts
@@ -7,11 +7,13 @@ export type TemplateUI = "html" | "shadcn";
 export type TemplateVectorDB = "none" | "mongo" | "pg";
 export type TemplatePostInstallAction = "none" | "dependencies" | "runApp";
 export type TemplateDataSource = {
-  type: "none" | "file" | "web";
+  type: TemplateDataSourceType;
   config: TemplateDataSourceConfig;
 };
+export type TemplateDataSourceType = "none" | "file" | "folder" | "web";
+// Config for both file and folder
 export type FileSourceConfig = {
-  contextFile?: string;
+  path?: string;
 };
 export type WebSourceConfig = {
   baseUrl?: string;
diff --git a/packages/create-llama/index.ts b/packages/create-llama/index.ts
index 26592ec1a..601742c2c 100644
--- a/packages/create-llama/index.ts
+++ b/packages/create-llama/index.ts
@@ -83,6 +83,13 @@ const program = new Commander.Command(packageJson.name)
     `
 
   Select a framework to bootstrap the application with.
+`,
+  )
+  .option(
+    "--files <path>",
+    `
+  
+    Specify the path to a local file or folder for chatting.
 `,
   )
   .option(
diff --git a/packages/create-llama/questions.ts b/packages/create-llama/questions.ts
index eff974556..59aa51515 100644
--- a/packages/create-llama/questions.ts
+++ b/packages/create-llama/questions.ts
@@ -5,19 +5,35 @@ import path from "path";
 import { blue, green, red } from "picocolors";
 import prompts from "prompts";
 import { InstallAppArgs } from "./create-app";
-import { TemplateFramework } from "./helpers";
+import { TemplateDataSourceType, TemplateFramework } from "./helpers";
 import { COMMUNITY_OWNER, COMMUNITY_REPO } from "./helpers/constant";
 import { getAvailableLlamapackOptions } from "./helpers/llama-pack";
 import { getRepoRootFolders } from "./helpers/repo";
 
-export type QuestionArgs = Omit<InstallAppArgs, "appPath" | "packageManager">;
+export type QuestionArgs = Omit<
+  InstallAppArgs,
+  "appPath" | "packageManager"
+> & { files?: string };
+const supportedContextFileTypes = [
+  ".pdf",
+  ".doc",
+  ".docx",
+  ".xls",
+  ".xlsx",
+  ".csv",
+];
 const MACOS_FILE_SELECTION_SCRIPT = `
 osascript -l JavaScript -e '
   a = Application.currentApplication();
   a.includeStandardAdditions = true;
   a.chooseFile({ withPrompt: "Please select a file to process:" }).toString()
 '`;
-
+const MACOS_FOLDER_SELECTION_SCRIPT = `
+osascript -l JavaScript -e '
+  a = Application.currentApplication();
+  a.includeStandardAdditions = true;
+  a.chooseFolder({ withPrompt: "Please select a folder to process:" }).toString()
+'`;
 const WINDOWS_FILE_SELECTION_SCRIPT = `
 Add-Type -AssemblyName System.Windows.Forms
 $openFileDialog = New-Object System.Windows.Forms.OpenFileDialog
@@ -27,6 +43,15 @@ if ($result -eq 'OK') {
   $openFileDialog.FileName
 }
 `;
+const WINDOWS_FOLDER_SELECTION_SCRIPT = `
+Add-Type -AssemblyName System.windows.forms
+$folderBrowser = New-Object System.Windows.Forms.FolderBrowserDialog
+$dialogResult = $folderBrowser.ShowDialog()
+if ($dialogResult -eq [System.Windows.Forms.DialogResult]::OK)
+{
+    $folderBrowser.SelectedPath
+}
+`;
 
 const defaults: QuestionArgs = {
   template: "streaming",
@@ -78,39 +103,70 @@ const getVectorDbChoices = (framework: TemplateFramework) => {
   return displayedChoices;
 };
 
-const selectPDFFile = async () => {
-  // Popup to select a PDF file
+const getDataSourceChoices = (framework: TemplateFramework) => {
+  let choices = [
+    {
+      title: "No data, just a simple chat",
+      value: "simple",
+    },
+    { title: "Use an example PDF", value: "exampleFile" },
+  ];
+  if (process.platform === "win32" || process.platform === "darwin") {
+    choices.push({
+      title: `Use a local file (${supportedContextFileTypes})`,
+      value: "localFile",
+    });
+    choices.push({
+      title: `Use a local folder`,
+      value: "localFolder",
+    });
+  }
+  if (framework === "fastapi") {
+    choices.push({ title: "Use website content", value: "web" });
+  }
+  return choices;
+};
+
+const selectLocalContextData = async (type: TemplateDataSourceType) => {
   try {
-    let selectedFilePath: string = "";
+    let selectedPath: string = "";
+    let execScript: string;
+    let execOpts: any = {};
     switch (process.platform) {
       case "win32": // Windows
-        selectedFilePath = execSync(WINDOWS_FILE_SELECTION_SCRIPT, {
-          shell: "powershell.exe",
-        })
-          .toString()
-          .trim();
+        execScript =
+          type === "file"
+            ? WINDOWS_FILE_SELECTION_SCRIPT
+            : WINDOWS_FOLDER_SELECTION_SCRIPT;
+        execOpts = { shell: "powershell.exe" };
         break;
       case "darwin": // MacOS
-        selectedFilePath = execSync(MACOS_FILE_SELECTION_SCRIPT)
-          .toString()
-          .trim();
+        execScript =
+          type === "file"
+            ? MACOS_FILE_SELECTION_SCRIPT
+            : MACOS_FOLDER_SELECTION_SCRIPT;
         break;
       default: // Unsupported OS
         console.log(red("Unsupported OS error!"));
         process.exit(1);
     }
-    // Check is pdf file
-    if (!selectedFilePath.endsWith(".pdf")) {
-      console.log(
-        red("Unsupported file error! Please select a valid PDF file!"),
-      );
-      process.exit(1);
+    selectedPath = execSync(execScript, execOpts).toString().trim();
+    if (type === "file") {
+      let fileType = path.extname(selectedPath);
+      if (!supportedContextFileTypes.includes(fileType)) {
+        console.log(
+          red(
+            `Please select a supported file type: ${supportedContextFileTypes}`,
+          ),
+        );
+        process.exit(1);
+      }
     }
-    return selectedFilePath;
+    return selectedPath;
   } catch (error) {
     console.log(
       red(
-        "Got error when trying to select file! Please try again or select other options.",
+        "Got an error when trying to select local context data! Please try again or select another data source option.",
       ),
     );
     process.exit(1);
@@ -369,30 +425,32 @@ export const askQuestions = async (
     }
   }
 
+  if (program.files) {
+    // --files implies the context engine: assign it so the data source prompt below is skipped
+    program.engine = "context";
+    if (!fs.existsSync(program.files)) {
+      console.log("File or folder not found");
+      process.exit(1);
+    } else {
+      program.dataSource = {
+        type: fs.statSync(program.files).isDirectory() ? "folder" : "file",
+        config: {
+          path: program.files,
+        },
+      };
+    }
+  }
+
   if (!program.engine) {
     if (ciInfo.isCI) {
       program.engine = getPrefOrDefault("engine");
     } else {
-      let choices = [
-        {
-          title: "No data, just a simple chat",
-          value: "simple",
-        },
-        { title: "Use an example PDF", value: "exampleFile" },
-      ];
-      if (process.platform === "win32" || process.platform === "darwin") {
-        choices.push({ title: "Use a local PDF file", value: "localFile" });
-      }
-      if (program.framework === "fastapi") {
-        choices.push({ title: "Use website content", value: "web" });
-      }
-
       const { dataSource } = await prompts(
         {
           type: "select",
           name: "dataSource",
           message: "Which data source would you like to use?",
-          choices: choices,
+          choices: getDataSourceChoices(program.framework),
           initial: 1,
         },
         handlers,
@@ -403,18 +461,29 @@ export const askQuestions = async (
         switch (dataSource) {
           case "simple":
             program.engine = "simple";
+            program.dataSource = { type: "none", config: {} };
             break;
           case "exampleFile":
             program.engine = "context";
-            // example file is a context app with dataSource.type = file but has no config
-            program.dataSource = { type: "file", config: {} };
+            // Treat example as a folder data source with no config
+            program.dataSource = { type: "folder", config: {} };
             break;
           case "localFile":
             program.engine = "context";
-            program.dataSource.type = "file";
-            // If the user selected the "pdf" option, ask them to select a file
-            program.dataSource.config = {
-              contextFile: await selectPDFFile(),
+            program.dataSource = {
+              type: "file",
+              config: {
+                path: await selectLocalContextData("file"),
+              },
+            };
+            break;
+          case "localFolder":
+            program.engine = "context";
+            program.dataSource = {
+              type: "folder",
+              config: {
+                path: await selectLocalContextData("folder"),
+              },
             };
             break;
           case "web":
@@ -424,56 +493,69 @@ export const askQuestions = async (
         }
       }
     }
+  } else if (!program.dataSource) {
+    // Handle a case when engine is specified but dataSource is not
+    if (program.engine === "context") {
+      program.dataSource = {
+        type: "folder",
+        config: {},
+      };
+    } else if (program.engine === "simple") {
+      program.dataSource = {
+        type: "none",
+        config: {},
+      };
+    }
+  }
+
+  if (program.dataSource?.type === "web" && program.framework === "fastapi") {
+    let { baseUrl } = await prompts(
+      {
+        type: "text",
+        name: "baseUrl",
+        message: "Please provide base URL of the website:",
+        initial: "https://www.llamaindex.ai",
+      },
+      handlers,
+    );
+    try {
+      if (!baseUrl.includes("://")) {
+        baseUrl = `https://${baseUrl}`;
+      }
+      let checkUrl = new URL(baseUrl);
+      if (checkUrl.protocol !== "https:" && checkUrl.protocol !== "http:") {
+        throw new Error("Invalid protocol");
+      }
+    } catch (error) {
+      console.log(
+        red(
+          "Invalid URL provided! Please provide a valid URL (e.g. https://www.llamaindex.ai)",
+        ),
+      );
+      process.exit(1);
+    }
+    program.dataSource.config = {
+      baseUrl: baseUrl,
+      depth: 1,
+    };
+  }
 
-    if (program.dataSource?.type === "web" && program.framework === "fastapi") {
-      let { baseUrl } = await prompts(
+  if (program.engine !== "simple" && !program.vectorDb) {
+    if (ciInfo.isCI) {
+      program.vectorDb = getPrefOrDefault("vectorDb");
+    } else {
+      const { vectorDb } = await prompts(
         {
-          type: "text",
-          name: "baseUrl",
-          message: "Please provide base URL of the website:",
-          initial: "https://www.llamaindex.ai",
+          type: "select",
+          name: "vectorDb",
+          message: "Would you like to use a vector database?",
+          choices: getVectorDbChoices(program.framework),
+          initial: 0,
         },
         handlers,
       );
-      try {
-        if (!baseUrl.includes("://")) {
-          baseUrl = `https://${baseUrl}`;
-        }
-        let checkUrl = new URL(baseUrl);
-        if (checkUrl.protocol !== "https:" && checkUrl.protocol !== "http:") {
-          throw new Error("Invalid protocol");
-        }
-      } catch (error) {
-        console.log(
-          red(
-            "Invalid URL provided! Please provide a valid URL (e.g. https://www.llamaindex.ai)",
-          ),
-        );
-        process.exit(1);
-      }
-      program.dataSource.config = {
-        baseUrl: baseUrl,
-        depth: 1,
-      };
-    }
-
-    if (program.engine !== "simple" && !program.vectorDb) {
-      if (ciInfo.isCI) {
-        program.vectorDb = getPrefOrDefault("vectorDb");
-      } else {
-        const { vectorDb } = await prompts(
-          {
-            type: "select",
-            name: "vectorDb",
-            message: "Would you like to use a vector database?",
-            choices: getVectorDbChoices(program.framework),
-            initial: 0,
-          },
-          handlers,
-        );
-        program.vectorDb = vectorDb;
-        preferences.vectorDb = vectorDb;
-      }
+      program.vectorDb = vectorDb;
+      preferences.vectorDb = vectorDb;
     }
   }
 
diff --git a/packages/create-llama/templates/types/simple/fastapi/pyproject.toml b/packages/create-llama/templates/types/simple/fastapi/pyproject.toml
index f9bb9605b..d1952f4a2 100644
--- a/packages/create-llama/templates/types/simple/fastapi/pyproject.toml
+++ b/packages/create-llama/templates/types/simple/fastapi/pyproject.toml
@@ -12,7 +12,7 @@ uvicorn = { extras = ["standard"], version = "^0.23.2" }
 llama-index = "^0.9.19"
 pypdf = "^3.17.0"
 python-dotenv = "^1.0.0"
-
+docx2txt = "^0.8"
 
 [build-system]
 requires = ["poetry-core"]
diff --git a/packages/create-llama/templates/types/streaming/fastapi/pyproject.toml b/packages/create-llama/templates/types/streaming/fastapi/pyproject.toml
index f9bb9605b..d1952f4a2 100644
--- a/packages/create-llama/templates/types/streaming/fastapi/pyproject.toml
+++ b/packages/create-llama/templates/types/streaming/fastapi/pyproject.toml
@@ -12,7 +12,7 @@ uvicorn = { extras = ["standard"], version = "^0.23.2" }
 llama-index = "^0.9.19"
 pypdf = "^3.17.0"
 python-dotenv = "^1.0.0"
-
+docx2txt = "^0.8"
 
 [build-system]
 requires = ["poetry-core"]
-- 
GitLab