From 409d57c8fae3622673c8528cfdd6775b1d65d3ef Mon Sep 17 00:00:00 2001 From: Tony Salomone <dadmobile@gmail.com> Date: Fri, 17 Jan 2025 17:13:11 -0500 Subject: [PATCH] Update ImportModelBar to use a jobID so its downloads get a progress bar --- .../components/ModelZoo/ImportModelsBar.tsx | 44 ++++++++++++------- .../components/ModelZoo/ModelStore.tsx | 6 +-- 2 files changed, 32 insertions(+), 18 deletions(-) diff --git a/src/renderer/components/ModelZoo/ImportModelsBar.tsx b/src/renderer/components/ModelZoo/ImportModelsBar.tsx index 2f83ce28..d8d17c58 100644 --- a/src/renderer/components/ModelZoo/ImportModelsBar.tsx +++ b/src/renderer/components/ModelZoo/ImportModelsBar.tsx @@ -13,10 +13,10 @@ import { PlusIcon } from 'lucide-react'; import * as chatAPI from '../../lib/transformerlab-api-sdk'; import ImportModelsModal from './ImportModelsModal'; -// Needs to share currentlyDownloading with ModelsStore +// Needs to share jobId with ModelsStore // If you start a download on one it should stop you from starting on the other // Also this is how the import bar tells teh model store to show a download progress bar -export default function ImportModelsBar({ currentlyDownloading, setCurrentlyDownloading }) { +export default function ImportModelsBar({ jobId, setJobId }) { const [importModelsModalOpen, setImportModelsModalOpen] = useState(false); return ( @@ -57,28 +57,42 @@ export default function ImportModelsBar({ currentlyDownloading, setCurrentlyDown // only download if valid model is entered if (model) { - // this triggers UI changes while download is in progress - setCurrentlyDownloading(model); + setJobId(-1); + try { + const jobResponse = await fetch( + chatAPI.Endpoints.Jobs.Create() + ); + const newJobId = await jobResponse.json(); + setJobId(newJobId); - // Try downloading the model - const response = await chatAPI.downloadModelFromHuggingFace(model); - if (response?.status == 'error') { - alert('Download failed!\n' + response.message); - } + // Try downloading the model + const response = await chatAPI.downloadModelFromHuggingFace( + model, + newJobId + ); + console.log(response); + if (response?.status == 'error') { + alert('Download failed!\n' + response.message); + } + + // download complete + setJobId(null); - // download complete - setCurrentlyDownloading(null); - //modelGalleryMutate(); + } catch (e) { + setJobId(null); + console.log(e); + return alert('Failed to download'); + } } }} startDecorator={ - currentlyDownloading ? ( + jobId ? ( <CircularProgress size="sm" thickness={2} /> ) : ( "" )} > - {currentlyDownloading ? ( + {jobId ? ( "Downloading" ) : ( "Download 🤗 Model" @@ -86,7 +100,7 @@ export default function ImportModelsBar({ currentlyDownloading, setCurrentlyDown </Button> } sx={{ width: '500px' }} - disabled={currentlyDownloading} + disabled={jobId != null} /> </FormControl> <Button diff --git a/src/renderer/components/ModelZoo/ModelStore.tsx b/src/renderer/components/ModelZoo/ModelStore.tsx index 49213b63..4eace97f 100644 --- a/src/renderer/components/ModelZoo/ModelStore.tsx +++ b/src/renderer/components/ModelZoo/ModelStore.tsx @@ -117,7 +117,7 @@ export default function ModelStore() { } = useSWR(chatAPI.Endpoints.Models.Gallery(), fetcher); const { data: modelDownloadProgress } = useSWR( - currentlyDownloading && jobId && jobId != '-1' + jobId && jobId != '-1' ? chatAPI.Endpoints.Jobs.Get(jobId) : null, fetcher, @@ -606,8 +606,8 @@ export default function ModelStore() { </Table> </Sheet> <ImportModelsBar - currentlyDownloading={currentlyDownloading} - setCurrentlyDownloading={setCurrentlyDownloading} + jobId={jobId} + setJobId={setJobId} /> </Sheet> ); -- GitLab