Skip to content
Snippets Groups Projects
Unverified Commit a1f92525 authored by Deep Gandhi's avatar Deep Gandhi Committed by GitHub
Browse files

Merge pull request #301 from transformerlab/fix/training-others

Removing Restricted Logic for LoRA Train Type
parents eb228e99 3453d421
No related branches found
No related tags found
No related merge requests found
...@@ -49,11 +49,16 @@ export default function LoRATrainingRunButton({ ...@@ -49,11 +49,16 @@ export default function LoRATrainingRunButton({
return false; return false;
}); });
let modelInLocalList = false; let modelInLocalList = false;
if (model === "unknown")
{
modelInLocalList = true;
} else {
models_downloaded.forEach(modelData => { models_downloaded.forEach(modelData => {
if (modelData.model_id == model || modelData.local_path === model) { if (modelData.model_id == model || modelData.local_path === model) {
modelInLocalList = true; modelInLocalList = true;
} }
}); });
}
const datasets_downloaded = await fetch( const datasets_downloaded = await fetch(
chatAPI.Endpoints.Dataset.LocalList() chatAPI.Endpoints.Dataset.LocalList()
......
...@@ -63,7 +63,7 @@ function formatTemplateConfig(config): ReactElement { ...@@ -63,7 +63,7 @@ function formatTemplateConfig(config): ReactElement {
const r = ( const r = (
<> <>
<b>Model:</b> {short_model_name} <br /> {short_model_name && (<><b>Model:</b> {short_model_name} <br /></>)}
<b>Dataset:</b> {c.dataset_name} <FileTextIcon size={14} /> <b>Dataset:</b> {c.dataset_name} <FileTextIcon size={14} />
<br /> <br />
{/* <b>Adaptor:</b> {c.adaptor_name} <br /> */} {/* <b>Adaptor:</b> {c.adaptor_name} <br /> */}
...@@ -232,7 +232,9 @@ export default function TrainLoRA({ experimentInfo }) { ...@@ -232,7 +232,9 @@ export default function TrainLoRA({ experimentInfo }) {
}} }}
key={plugin.uniqueId} key={plugin.uniqueId}
disabled={ disabled={
!plugin.model_architectures?.includes(modelArchitecture) plugin.model_architectures
? !plugin.model_architectures.includes(modelArchitecture)
: false
} }
> >
<ListItemDecorator> <ListItemDecorator>
...@@ -244,7 +246,7 @@ export default function TrainLoRA({ experimentInfo }) { ...@@ -244,7 +246,7 @@ export default function TrainLoRA({ experimentInfo }) {
level="body-xs" level="body-xs"
sx={{ color: 'var(--joy-palette-neutral-400)' }} sx={{ color: 'var(--joy-palette-neutral-400)' }}
> >
{!plugin.model_architectures?.includes(modelArchitecture) {plugin.model_architectures && !plugin.model_architectures.includes(modelArchitecture)
? '(Does not support this model architecture)' ? '(Does not support this model architecture)'
: ''} : ''}
</Typography> </Typography>
......
...@@ -69,6 +69,25 @@ export default function TrainingModalLoRA({ ...@@ -69,6 +69,25 @@ export default function TrainingModalLoRA({
const [nameInput, setNameInput] = useState(''); const [nameInput, setNameInput] = useState('');
const [currentTab, setCurrentTab] = useState(0); const [currentTab, setCurrentTab] = useState(0);
// Fetch training type with useSWR
const { data: trainingTypeData } = useSWR(
experimentInfo?.id
? chatAPI.Endpoints.Experiment.ScriptGetFile(experimentInfo.id, pluginId, 'index.json')
: null,
fetcher
);
let trainingType = "LoRA";
if (trainingTypeData && trainingTypeData !== "undefined" && trainingTypeData.length > 0) {
trainingType = JSON.parse(trainingTypeData)?.train_type || "LoRA";
}
// Fetch available datasets from the API // Fetch available datasets from the API
const { const {
data: datasets, data: datasets,
...@@ -252,7 +271,7 @@ export default function TrainingModalLoRA({ ...@@ -252,7 +271,7 @@ export default function TrainingModalLoRA({
template_id, template_id,
event.currentTarget.elements['template_name'].value, event.currentTarget.elements['template_name'].value,
'Description', 'Description',
'LoRA', trainingType,
JSON.stringify(formJson) JSON.stringify(formJson)
); );
templateMutate(); //Need to mutate template data after updating templateMutate(); //Need to mutate template data after updating
...@@ -260,7 +279,7 @@ export default function TrainingModalLoRA({ ...@@ -260,7 +279,7 @@ export default function TrainingModalLoRA({
chatAPI.saveTrainingTemplate( chatAPI.saveTrainingTemplate(
event.currentTarget.elements['template_name'].value, event.currentTarget.elements['template_name'].value,
'Description', 'Description',
'LoRA', trainingType,
JSON.stringify(formJson) JSON.stringify(formJson)
); );
} }
......
...@@ -164,9 +164,7 @@ export default function Sidebar({ ...@@ -164,9 +164,7 @@ export default function Sidebar({
title="Train" title="Train"
path="/projects/training" path="/projects/training"
icon={<GraduationCapIcon />} icon={<GraduationCapIcon />}
disabled={ disabled={!experimentInfo?.name}
!experimentInfo?.name || !experimentInfo?.config?.foundation
}
/> />
<SubNavItem <SubNavItem
title="Export" title="Export"
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment