Skip to content
Snippets Groups Projects
Unverified Commit 0b4e566e authored by Tony Salomone's avatar Tony Salomone Committed by GitHub
Browse files

Merge pull request #215 from transformerlab/block-train-job-if-requirements-not-met

Block train start if either the dataset or the model is not downloaded, and download the necessary pieces instead
parents 0a13afc6 6551823b
No related branches found
No related tags found
No related merge requests found
......@@ -21,20 +21,131 @@ export default function LoRATrainingRunButton({
color="primary"
endDecorator={<PlayIcon size="14px" />}
onClick={async () => {
// Use fetch API to call endpoint
await fetch(
chatAPI.Endpoints.Jobs.Create(
experimentId,
'TRAIN',
'QUEUED',
JSON.stringify(job_data)
const model = job_data.model_name;
console.log(job_data)
const dataset = job_data.dataset;
const models_downloaded = await fetch(
chatAPI.Endpoints.Models.LocalList()
).then((response) => {
// First check that the API responded correctly
if (response.ok) {
return response.json();
} else {
const error_msg = `${response.statusText}`;
throw new Error(error_msg);
}
})
.then((data) => {
// Then check the API responose to see if there was an error.
console.log('Server response:', data);
if (data?.status == "error") {
throw new Error(data.message);
}
return data;
})
.catch((error) => {
alert(error);
return false;
});
let modelInLocalList = false;
models_downloaded.forEach(modelData => {
if (modelData.model_id == model) {
modelInLocalList = true;
}
});
const datasets_downloaded = await fetch(
chatAPI.Endpoints.Dataset.LocalList()
).then((response) => {
// First check that the API responded correctly
if (response.ok) {
return response.json();
} else {
const error_msg = `${response.statusText}`;
throw new Error(error_msg);
}
})
.then((data) => {
// Then check the API responose to see if there was an error.
console.log('Server response:', data);
if (data?.status == "error") {
throw new Error(data.message);
}
return data;
})
.catch((error) => {
alert(error);
return false;
});
let datasetInLocalList = false;
datasets_downloaded.forEach(datasetData => {
if (datasetData.dataset_id == dataset) {
datasetInLocalList = true;
}
});
if(modelInLocalList && datasetInLocalList){
// Use fetch API to call endpoint
await fetch(
chatAPI.Endpoints.Jobs.Create(
experimentId,
'TRAIN',
'QUEUED',
JSON.stringify(job_data)
)
)
)
.then((response) => response.json())
.then((data) => console.log(data))
.catch((error) => console.log(error));
jobsMutate();
}}
.then((response) => response.json())
.then((data) => console.log(data))
.catch((error) => console.log(error));
jobsMutate();
}
else{
let msg = "Warning: To use this recipe you will need to download the following:";
let shouldDownload = false;
if (!datasetInLocalList) {
msg += "\n- Dataset: " + dataset;
}
if (!modelInLocalList) {
msg += "\n- Model: " + model;
}
msg += "\n\nDo you want to download these now?";
if (confirm(msg)) { // Use confirm() to get Accept/Cancel
if (!datasetInLocalList) {
fetch(chatAPI.Endpoints.Dataset.Download(dataset))
.then((response) => {
if (!response.ok) {
console.log(response);
throw new Error(`HTTP Status: ${response.status}`);
}
return response.json();
})
.catch((error) => {
alert('Dataset download failed:\n' + error);
});
}
if (!modelInLocalList) {
chatAPI.downloadModelFromHuggingFace(model)
.then((response) => {
if (response.status == "error") {
console.log(response);
throw new Error(`${response.message}`);
}
return response;
})
.catch((error) => {
alert('Model download failed:\n' + error);
});
}
} else {
// User pressed Cancel
alert("Downloads cancelled. This recipe might not work correctly.");
}
}
}
}
>
{initialMessage}
</Button>
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment