From 409d57c8fae3622673c8528cfdd6775b1d65d3ef Mon Sep 17 00:00:00 2001
From: Tony Salomone <dadmobile@gmail.com>
Date: Fri, 17 Jan 2025 17:13:11 -0500
Subject: [PATCH] Update ImportModelBar to use a jobID so its downloads get a
 progress bar

---
 .../components/ModelZoo/ImportModelsBar.tsx   | 44 ++++++++++++-------
 .../components/ModelZoo/ModelStore.tsx        |  6 +--
 2 files changed, 32 insertions(+), 18 deletions(-)

diff --git a/src/renderer/components/ModelZoo/ImportModelsBar.tsx b/src/renderer/components/ModelZoo/ImportModelsBar.tsx
index 2f83ce28..d8d17c58 100644
--- a/src/renderer/components/ModelZoo/ImportModelsBar.tsx
+++ b/src/renderer/components/ModelZoo/ImportModelsBar.tsx
@@ -13,10 +13,10 @@ import { PlusIcon } from 'lucide-react';
 import * as chatAPI from '../../lib/transformerlab-api-sdk';
 import ImportModelsModal from './ImportModelsModal';
 
-// Needs to share currentlyDownloading with ModelsStore
+// Needs to share jobId with ModelsStore
 // If you start a download on one it should stop you from starting on the other
 // Also this is how the import bar tells teh model store to show a download progress bar
-export default function ImportModelsBar({ currentlyDownloading, setCurrentlyDownloading }) {
+export default function ImportModelsBar({ jobId, setJobId }) {
     const [importModelsModalOpen, setImportModelsModalOpen] = useState(false);
 
     return (
@@ -57,28 +57,42 @@ export default function ImportModelsBar({ currentlyDownloading, setCurrentlyDown
 
                       // only download if valid model is entered
                       if (model) {
-                        // this triggers UI changes while download is in progress
-                        setCurrentlyDownloading(model);
+                        setJobId(-1);
+                        try {
+                          const jobResponse = await fetch(
+                            chatAPI.Endpoints.Jobs.Create()
+                          );
+                          const newJobId = await jobResponse.json();
+                          setJobId(newJobId);
 
-                        // Try downloading the model
-                        const response = await chatAPI.downloadModelFromHuggingFace(model);
-                        if (response?.status == 'error') {
-                          alert('Download failed!\n' + response.message);
-                        }
+                          // Try downloading the model
+                          const response = await chatAPI.downloadModelFromHuggingFace(
+                            model,
+                            newJobId
+                          );
+                          console.log(response);
+                          if (response?.status == 'error') {
+                            alert('Download failed!\n' + response.message);
+                          }
+
+                          // download complete
+                          setJobId(null);
 
-                        // download complete
-                        setCurrentlyDownloading(null);
-                        //modelGalleryMutate();
+                        } catch (e) {
+                          setJobId(null);
+                          console.log(e);
+                          return alert('Failed to download');
+                        }
                       }
                     }}
                 startDecorator={
-                  currentlyDownloading ? (
+                  jobId ? (
                     <CircularProgress size="sm" thickness={2} />
                   ) : (
                     ""
                   )}
                   >
-                  {currentlyDownloading ? (
+                  {jobId ? (
                     "Downloading"
                   ) : (
                     "Download 🤗 Model"
@@ -86,7 +100,7 @@ export default function ImportModelsBar({ currentlyDownloading, setCurrentlyDown
                   </Button>
                 }
                 sx={{ width: '500px' }}
-                disabled={currentlyDownloading}
+                disabled={jobId != null}
               />
             </FormControl>
             <Button
diff --git a/src/renderer/components/ModelZoo/ModelStore.tsx b/src/renderer/components/ModelZoo/ModelStore.tsx
index 49213b63..4eace97f 100644
--- a/src/renderer/components/ModelZoo/ModelStore.tsx
+++ b/src/renderer/components/ModelZoo/ModelStore.tsx
@@ -117,7 +117,7 @@ export default function ModelStore() {
   } = useSWR(chatAPI.Endpoints.Models.Gallery(), fetcher);
 
   const { data: modelDownloadProgress } = useSWR(
-    currentlyDownloading && jobId && jobId != '-1'
+    jobId && jobId != '-1'
       ? chatAPI.Endpoints.Jobs.Get(jobId)
       : null,
     fetcher,
@@ -606,8 +606,8 @@ export default function ModelStore() {
         </Table>
       </Sheet>
       <ImportModelsBar
-        currentlyDownloading={currentlyDownloading}
-        setCurrentlyDownloading={setCurrentlyDownloading}
+        jobId={jobId}
+        setJobId={setJobId}
       />
     </Sheet>
   );
-- 
GitLab