diff --git a/src/renderer/components/Experiment/Train/LoRATrainingRunButton.tsx b/src/renderer/components/Experiment/Train/LoRATrainingRunButton.tsx index 0c085aaf62416fb32d0906c9dd3823985f54905b..67ebfe720470abd7ca336d5aad0c8d6b4bcf89df 100644 --- a/src/renderer/components/Experiment/Train/LoRATrainingRunButton.tsx +++ b/src/renderer/components/Experiment/Train/LoRATrainingRunButton.tsx @@ -49,11 +49,16 @@ export default function LoRATrainingRunButton({ return false; }); let modelInLocalList = false; + if (model === "unknown") + { + modelInLocalList = true; + } else { models_downloaded.forEach(modelData => { if (modelData.model_id == model || modelData.local_path === model) { modelInLocalList = true; } }); + } const datasets_downloaded = await fetch( chatAPI.Endpoints.Dataset.LocalList() diff --git a/src/renderer/components/Experiment/Train/TrainLoRA.tsx b/src/renderer/components/Experiment/Train/TrainLoRA.tsx index 83ff9d34a4ed24c2e03e22fe1773632cc8e9952e..8082caa661b34c68c7d867b1eaf39f10582a3d2d 100644 --- a/src/renderer/components/Experiment/Train/TrainLoRA.tsx +++ b/src/renderer/components/Experiment/Train/TrainLoRA.tsx @@ -63,7 +63,7 @@ function formatTemplateConfig(config): ReactElement { const r = ( <> - <b>Model:</b> {short_model_name} <br /> + {short_model_name && (<><b>Model:</b> {short_model_name} <br /></>)} <b>Dataset:</b> {c.dataset_name} <FileTextIcon size={14} /> <br /> {/* <b>Adaptor:</b> {c.adaptor_name} <br /> */} @@ -232,7 +232,9 @@ export default function TrainLoRA({ experimentInfo }) { }} key={plugin.uniqueId} disabled={ - !plugin.model_architectures?.includes(modelArchitecture) + plugin.model_architectures + ? !plugin.model_architectures.includes(modelArchitecture) + : false } > <ListItemDecorator> @@ -244,7 +246,7 @@ export default function TrainLoRA({ experimentInfo }) { level="body-xs" sx={{ color: 'var(--joy-palette-neutral-400)' }} > - {!plugin.model_architectures?.includes(modelArchitecture) + {plugin.model_architectures && !plugin.model_architectures.includes(modelArchitecture) ? '(Does not support this model architecture)' : ''} </Typography> diff --git a/src/renderer/components/Experiment/Train/TrainingModalLoRA.tsx b/src/renderer/components/Experiment/Train/TrainingModalLoRA.tsx index df7a4bab24c1442e8fd130bca5985fdbd9ed1d13..7d71b314bf551d8d7534f5b0c83f266b643d3b04 100644 --- a/src/renderer/components/Experiment/Train/TrainingModalLoRA.tsx +++ b/src/renderer/components/Experiment/Train/TrainingModalLoRA.tsx @@ -69,6 +69,25 @@ export default function TrainingModalLoRA({ const [nameInput, setNameInput] = useState(''); const [currentTab, setCurrentTab] = useState(0); + + // Fetch training type with useSWR + const { data: trainingTypeData } = useSWR( + experimentInfo?.id + ? chatAPI.Endpoints.Experiment.ScriptGetFile(experimentInfo.id, pluginId, 'index.json') + : null, + fetcher + ); + + let trainingType = "LoRA"; + if (trainingTypeData && trainingTypeData !== "undefined" && trainingTypeData.length > 0) { + + trainingType = JSON.parse(trainingTypeData)?.train_type || "LoRA"; + + } + + + + // Fetch available datasets from the API const { data: datasets, @@ -252,7 +271,7 @@ export default function TrainingModalLoRA({ template_id, event.currentTarget.elements['template_name'].value, 'Description', - 'LoRA', + trainingType, JSON.stringify(formJson) ); templateMutate(); //Need to mutate template data after updating @@ -260,7 +279,7 @@ export default function TrainingModalLoRA({ chatAPI.saveTrainingTemplate( event.currentTarget.elements['template_name'].value, 'Description', - 'LoRA', + trainingType, JSON.stringify(formJson) ); } diff --git a/src/renderer/components/Nav/Sidebar.tsx b/src/renderer/components/Nav/Sidebar.tsx index d996748ea952b341ac959212d468ebcf2959d155..0b6671ab81fdac71e784cc41ef20e4579aeb2e0f 100644 --- a/src/renderer/components/Nav/Sidebar.tsx +++ b/src/renderer/components/Nav/Sidebar.tsx @@ -164,9 +164,7 @@ export default function Sidebar({ title="Train" path="/projects/training" icon={<GraduationCapIcon />} - disabled={ - !experimentInfo?.name || !experimentInfo?.config?.foundation - } + disabled={!experimentInfo?.name} /> <SubNavItem title="Export"