From 8fb2d880bd5d2a6dd44ab2b1d1381d49a01fdea1 Mon Sep 17 00:00:00 2001
From: ali asaria <aliasaria@users.noreply.github.com>
Date: Fri, 12 Jul 2024 14:00:05 -0400
Subject: [PATCH] add stop button to UI

---
 .../components/Experiment/Train/TrainLoRA.tsx | 136 ++++++++++++++++--
 src/renderer/lib/transformerlab-api-sdk.ts    |   1 +
 2 files changed, 127 insertions(+), 10 deletions(-)

diff --git a/src/renderer/components/Experiment/Train/TrainLoRA.tsx b/src/renderer/components/Experiment/Train/TrainLoRA.tsx
index 197bb6be..64a61fa9 100644
--- a/src/renderer/components/Experiment/Train/TrainLoRA.tsx
+++ b/src/renderer/components/Experiment/Train/TrainLoRA.tsx
@@ -29,6 +29,8 @@ import {
   LineChartIcon,
   Plug2Icon,
   PlusIcon,
+  StopCircle,
+  StopCircleIcon,
   Trash2Icon,
 } from 'lucide-react';
 
@@ -281,6 +283,68 @@ export default function TrainLoRA({ experimentInfo }) {
                     </tr>
                   );
                 })}
+              {
+                // Format of template data by column:
+                // 0 = id, 1 = name, 2 = description, 3 = type, 4 = datasets, 5 = config, 6 = created, 7 = updated
+                data &&
+                  data?.map((row) => {
+                    return (
+                      <tr key={row[0]}>
+                        <td>
+                          <Typography level="title-sm">{row[1]}</Typography>
+                        </td>
+                        {/* <td>{row[2]}</td> */}
+                        <td>
+                          {row[4]} <FileTextIcon size={14} />
+                        </td>
+                        <td style={{ overflow: 'clip' }}>
+                          {formatTemplateConfig(row[5])}
+                        </td>
+                        <td style={{}}>
+                          <ButtonGroup sx={{ justifyContent: 'flex-end' }}>
+                            <LoRATrainingRunButton
+                              initialMessage="Queue"
+                              trainingTemplate={{
+                                template_id: row[0],
+                                template_name: row[1],
+                                model_name: row[5]?.model_name || 'unknown',
+                                dataset: row[4],
+                                config: row[5],
+                              }}
+                              jobsMutate={jobsMutate}
+                              experimentId={experimentInfo?.id}
+                            />
+                            <Button
+                              onClick={() => {
+                                setTemplateID(row[0]);
+                                setCurrentPlugin(
+                                  JSON.parse(row[5])?.plugin_name
+                                );
+                                setOpen(true);
+                              }}
+                              variant="plain"
+                            >
+                              Edit
+                            </Button>
+                            <IconButton
+                              onClick={async () => {
+                                await fetch(
+                                  chatAPI.API_URL() +
+                                    'train/template/' +
+                                    row[0] +
+                                    '/delete'
+                                );
+                                mutate();
+                              }}
+                            >
+                              <Trash2Icon />
+                            </IconButton>
+                          </ButtonGroup>
+                        </td>
+                      </tr>
+                    );
+                  })
+              }
             </tbody>
           </Table>
         </Sheet>
@@ -369,17 +433,69 @@ export default function TrainLoRA({ experimentInfo }) {
                                 .duration(
                                   dayjs(job?.job_data?.end_time).diff(
                                     dayjs(job?.job_data?.start_time)
+                        {' '}
+                        <Stack
+                          direction={'column'}
+                          justifyContent={'space-between'}
+                        >
+                          <Chip color={jobChipColor(job.status)}>
+                            {job.status}
+                            {job.progress == '-1'
+                              ? ''
+                              : ' - ' +
+                                Number.parseFloat(job.progress).toFixed(1) +
+                                '%'}
+                          </Chip>
+                          {job?.job_data?.start_time && (
+                            <>
+                              Started:{' '}
+                              {dayjs(job?.job_data?.start_time).fromNow()}
+                            </>
+                          )}
+                          <br />
+                          {/* {job?.job_data?.end_time &&
+                          dayjs(job?.job_data?.end_time).fromNow()} */}
+                          {job?.job_data?.start_time &&
+                          job?.job_data?.end_time ? (
+                            <>
+                              Completed in:{' '}
+                              {job?.job_data?.end_time &&
+                                job?.job_data?.end_time &&
+                                dayjs
+                                  .duration(
+                                    dayjs(job?.job_data?.end_time).diff(
+                                      dayjs(job?.job_data?.start_time)
+                                    )
                                   )
-                                )
-                                .humanize()}
-                          </>
-                        ) : (
-                          <LinearProgress
-                            determinate
-                            value={job.progress}
-                            sx={{ my: 1 }}
-                          />
-                        )}
+                                  .humanize()}
+                            </>
+                          ) : (
+                            <>
+                              {job.status == 'RUNNING' && (
+                                <Stack direction={'row'}>
+                                  <LinearProgress
+                                    determinate
+                                    value={job.progress}
+                                    sx={{ my: 1 }}
+                                  />
+                                  <IconButton
+                                    color="danger"
+                                    onClick={async () => {
+                                      confirm(
+                                        'Are you sure you want to stop this job?'
+                                      ) &&
+                                        (await fetch(
+                                          chatAPI.Endpoints.Jobs.Stop(job.id)
+                                        ));
+                                    }}
+                                  >
+                                    <StopCircleIcon />
+                                  </IconButton>
+                                </Stack>
+                              )}
+                            </>
+                          )}
+                        </Stack>
                       </td>
                       <td style={{}}>
                         <ButtonGroup sx={{ justifyContent: 'flex-end' }}>
diff --git a/src/renderer/lib/transformerlab-api-sdk.ts b/src/renderer/lib/transformerlab-api-sdk.ts
index df8984ac..f82ef9e1 100644
--- a/src/renderer/lib/transformerlab-api-sdk.ts
+++ b/src/renderer/lib/transformerlab-api-sdk.ts
@@ -899,6 +899,7 @@ Endpoints.Jobs = {
     type +
     '&config=' +
     config,
+  Stop: (jobId: string) => API_URL() + 'jobs/' + jobId + '/stop',
 };
 
 Endpoints.Global = {
-- 
GitLab