Skip to content
Snippets Groups Projects
Commit 6551823b authored by sanjaycal's avatar sanjaycal
Browse files

Block train start if either the dataset or the model is not downloaded, and then download them

parent 7364b828
Branches block-train-job-if-requirements-not-met
No related tags found
No related merge requests found
...@@ -21,20 +21,131 @@ export default function LoRATrainingRunButton({ ...@@ -21,20 +21,131 @@ export default function LoRATrainingRunButton({
color="primary" color="primary"
endDecorator={<PlayIcon size="14px" />} endDecorator={<PlayIcon size="14px" />}
onClick={async () => { onClick={async () => {
// Use fetch API to call endpoint const model = job_data.model_name;
await fetch( console.log(job_data)
chatAPI.Endpoints.Jobs.Create( const dataset = job_data.dataset;
experimentId,
'TRAIN', const models_downloaded = await fetch(
'QUEUED', chatAPI.Endpoints.Models.LocalList()
JSON.stringify(job_data) ).then((response) => {
// First check that the API responded correctly
if (response.ok) {
return response.json();
} else {
const error_msg = `${response.statusText}`;
throw new Error(error_msg);
}
})
.then((data) => {
// Then check the API responose to see if there was an error.
console.log('Server response:', data);
if (data?.status == "error") {
throw new Error(data.message);
}
return data;
})
.catch((error) => {
alert(error);
return false;
});
let modelInLocalList = false;
models_downloaded.forEach(modelData => {
if (modelData.model_id == model) {
modelInLocalList = true;
}
});
const datasets_downloaded = await fetch(
chatAPI.Endpoints.Dataset.LocalList()
).then((response) => {
// First check that the API responded correctly
if (response.ok) {
return response.json();
} else {
const error_msg = `${response.statusText}`;
throw new Error(error_msg);
}
})
.then((data) => {
// Then check the API responose to see if there was an error.
console.log('Server response:', data);
if (data?.status == "error") {
throw new Error(data.message);
}
return data;
})
.catch((error) => {
alert(error);
return false;
});
let datasetInLocalList = false;
datasets_downloaded.forEach(datasetData => {
if (datasetData.dataset_id == dataset) {
datasetInLocalList = true;
}
});
if(modelInLocalList && datasetInLocalList){
// Use fetch API to call endpoint
await fetch(
chatAPI.Endpoints.Jobs.Create(
experimentId,
'TRAIN',
'QUEUED',
JSON.stringify(job_data)
)
) )
) .then((response) => response.json())
.then((response) => response.json()) .then((data) => console.log(data))
.then((data) => console.log(data)) .catch((error) => console.log(error));
.catch((error) => console.log(error)); jobsMutate();
jobsMutate(); }
}} else{
let msg = "Warning: To use this recipe you will need to download the following:";
let shouldDownload = false;
if (!datasetInLocalList) {
msg += "\n- Dataset: " + dataset;
}
if (!modelInLocalList) {
msg += "\n- Model: " + model;
}
msg += "\n\nDo you want to download these now?";
if (confirm(msg)) { // Use confirm() to get Accept/Cancel
if (!datasetInLocalList) {
fetch(chatAPI.Endpoints.Dataset.Download(dataset))
.then((response) => {
if (!response.ok) {
console.log(response);
throw new Error(`HTTP Status: ${response.status}`);
}
return response.json();
})
.catch((error) => {
alert('Dataset download failed:\n' + error);
});
}
if (!modelInLocalList) {
chatAPI.downloadModelFromHuggingFace(model)
.then((response) => {
if (response.status == "error") {
console.log(response);
throw new Error(`${response.message}`);
}
return response;
})
.catch((error) => {
alert('Model download failed:\n' + error);
});
}
} else {
// User pressed Cancel
alert("Downloads cancelled. This recipe might not work correctly.");
}
}
}
}
> >
{initialMessage} {initialMessage}
</Button> </Button>
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment