diff --git a/src/renderer/components/Experiment/DynamicPluginForm.tsx b/src/renderer/components/Experiment/DynamicPluginForm.tsx index 9e9733b5598792a966d4e52e5a9f23463a8f8b5d..a2455bad467f401268d0bbbddc49b007178be276 100644 --- a/src/renderer/components/Experiment/DynamicPluginForm.tsx +++ b/src/renderer/components/Experiment/DynamicPluginForm.tsx @@ -22,10 +22,11 @@ import { Slider, Stack, Option, - Autocomplete + Autocomplete, } from '@mui/joy'; import { useMemo } from 'react'; import ModelProviderWidget from 'renderer/components/Experiment/Widgets/ModelProviderWidget'; +import CustomEvaluationWidget from './Widgets/CustomEvaluationWidget'; import { RegistryWidgetsType, @@ -421,7 +422,13 @@ function CustomAutocompleteWidget<T = any, S extends StrictRJSFSchema = RJSFSche // Determine default value. const defaultValue = _multiple ? [] : ''; // Use the provided value or fallback to default. - const currentValue = value !== undefined ? value : defaultValue; + let currentValue = value !== undefined ? value : defaultValue; + + // Check if currentValue is an array, if a string, convert it to an array. + const isString = typeof currentValue === 'string'; + if (isString) { + currentValue = currentValue.split(','); + } // Map enumOptions into objects with label and value. const processedOptionsValues = enumOptions.map((opt) => @@ -492,6 +499,7 @@ const widgets: RegistryWidgetsType = { RangeWidget: CustomRange, SelectWidget: CustomSelectSimple, AutoCompleteWidget: CustomAutocompleteWidget, + EvaluationWidget: CustomEvaluationWidget, ModelProviderWidget: ModelProviderWidget }; diff --git a/src/renderer/components/Experiment/Eval/EvalJobsTable.tsx b/src/renderer/components/Experiment/Eval/EvalJobsTable.tsx index d72c74318744419021b23aa1f678c28316ed79b1..ba8444fafeb961d13ca888e98e9a5393619d67e8 100644 --- a/src/renderer/components/Experiment/Eval/EvalJobsTable.tsx +++ b/src/renderer/components/Experiment/Eval/EvalJobsTable.tsx @@ -8,6 +8,7 @@ import { Table, Typography, Link, + Checkbox, } from '@mui/joy'; import { ChartColumnBigIcon, @@ -15,11 +16,13 @@ import { FileDigitIcon, Grid3X3Icon, Trash2Icon, + LineChartIcon, Type, } from 'lucide-react'; import { useState, useEffect } from 'react'; import useSWR from 'swr'; import * as chatAPI from '../../../lib/transformerlab-api-sdk'; +import TensorboardModal from '../Train/TensorboardModal'; import ViewOutputModalStreaming from './ViewOutputModalStreaming'; import ViewCSVModal from './ViewCSVModal'; import ViewPlotModal from './ViewPlotModal'; @@ -89,13 +92,14 @@ function RenderScore({ score }) { } const EvalJobsTable = () => { + const [selected, setSelected] = useState<readonly string[]>([]); const [viewOutputFromJob, setViewOutputFromJob] = useState(-1); const [openCSVModal, setOpenCSVModal] = useState(false); const [openPlotModal, setOpenPlotModal] = useState(false); const [currentJobId, setCurrentJobId] = useState(''); const [currentScore, setCurrentScore] = useState(''); - const [fileNameForDetailedReport, setFileNameForDetailedReport] = - useState(''); + const [currentTensorboardForModal, setCurrentTensorboardForModal] = useState(-1); + const [fileNameForDetailedReport, setFileNameForDetailedReport] = useState(''); const fetchCSV = async (jobId) => { const response = await fetch( @@ -112,6 +116,7 @@ const EvalJobsTable = () => { mutate: jobsMutate, } = useSWR(chatAPI.Endpoints.Jobs.GetJobsOfType('EVAL', ''), fetcher, { refreshInterval: 2000, + fallbackData: [], }); const handleOpenCSVModal = (jobId) => { @@ -149,11 +154,58 @@ const EvalJobsTable = () => { setFileName={setFileNameForDetailedReport} fileName={fileNameForDetailedReport} /> - <Typography level="h3">Executions</Typography> + <TensorboardModal + currentTensorboard={currentTensorboardForModal} + setCurrentTensorboard={setCurrentTensorboardForModal} + /> + <Box + sx={{ + display: 'flex', + justifyContent: 'space-between', + alignItems: 'baseline', + }} + > + <Typography level="h3">Executions</Typography> + {selected.length > 1 && ( + <Typography + level="body-sm" + startDecorator={<ChartColumnIncreasingIcon size="20px" />} + onClick={() => { + alert('this feature coming soon'); + }} + sx={{ cursor: 'pointer' }} + > + <>Compare Selected Evals</> + </Typography> + )} + </Box> + <Sheet sx={{ overflowY: 'scroll' }}> <Table stickyHeader> <thead> <tr> + <th + style={{ width: 48, textAlign: 'center', padding: '6px 6px' }} + > + <Checkbox + size="sm" + indeterminate={ + selected.length > 0 && selected.length !== jobs.length + } + checked={selected.length === jobs.length} + onChange={(event) => { + setSelected( + event.target.checked ? jobs.map((row) => row.id) : [] + ); + }} + color={ + selected.length > 0 || selected.length === jobs.length + ? 'primary' + : undefined + } + sx={{ verticalAlign: 'text-bottom' }} + /> + </th> <th width="50px">Id</th> <th>Eval</th> <th>Progress</th> @@ -164,6 +216,24 @@ const EvalJobsTable = () => { <tbody> {jobs?.map((job) => ( <tr key={job.id}> + <td style={{ textAlign: 'center', width: 120 }}> + <Checkbox + size="sm" + checked={selected.includes(job?.id)} + color={selected.includes(job?.id) ? 'primary' : undefined} + onChange={(event) => { + setSelected((ids) => + event.target.checked + ? ids.concat(job?.id) + : ids.filter((itemId) => itemId !== job?.id) + ); + }} + slotProps={{ + checkbox: { sx: { textAlign: 'left' } }, + }} + sx={{ verticalAlign: 'text-bottom' }} + /> + </td> <td>{job.id}</td> <td> <Typography level="title-md"> @@ -179,7 +249,6 @@ const EvalJobsTable = () => { <td> <JobProgress job={job} /> </td> - <td> <RenderScore score={job?.job_data?.score} /> {job?.job_data?.additional_output_path && @@ -219,12 +288,23 @@ const EvalJobsTable = () => { </Link> )} </td> - <td> <ButtonGroup variant="soft" sx={{ justifyContent: 'flex-end' }} > + {job?.job_data?.tensorboard_output_dir && ( + <Button + size="sm" + variant="plain" + onClick={() => { + setCurrentTensorboardForModal(job?.id); + }} + startDecorator={<LineChartIcon />} + > + Tensorboard + </Button> + )} <Button onClick={() => { setViewOutputFromJob(job?.id); diff --git a/src/renderer/components/Experiment/Eval/EvalModal.tsx b/src/renderer/components/Experiment/Eval/EvalModal.tsx index 04dde6ec240c8a21929d98288010c0a4cb7b50ed..e005a54be5ab5b28e508a2fc4e64911b254a947a 100644 --- a/src/renderer/components/Experiment/Eval/EvalModal.tsx +++ b/src/renderer/components/Experiment/Eval/EvalModal.tsx @@ -287,7 +287,7 @@ export default function EvalModal({ } else { console.log('formJson:', formJson); const template_name = formJson.template_name; - delete formJson.template_name; + // delete formJson.template_name; const result = await chatAPI.EXPERIMENT_ADD_EVALUATION( experimentInfo?.id, template_name, diff --git a/src/renderer/components/Experiment/Eval/EvalTasksTable.tsx b/src/renderer/components/Experiment/Eval/EvalTasksTable.tsx index 20b8861dad8e6dd94844c63727342dcf5e4ba8ef..68220d9b06ec180e2d05df2968563d10695d82df 100644 --- a/src/renderer/components/Experiment/Eval/EvalTasksTable.tsx +++ b/src/renderer/components/Experiment/Eval/EvalTasksTable.tsx @@ -20,7 +20,20 @@ function formatTemplateConfig(script_parameters): ReactElement { // Remove the author/full path from the model name for cleanliness // const short_model_name = c.model_name.split('/').pop(); // Set main_task as either or the metric name from the script parameters - const main_task = script_parameters.tasks + const main_task = (() => { + if (script_parameters.tasks) { + try { + const tasksArray = JSON.parse(script_parameters.tasks); + if (Array.isArray(tasksArray)) { + return tasksArray.map((task) => task.name).join(', '); + } + } catch (error) { + // Invalid JSON; fall back to the original value + } + return script_parameters.tasks; + } + return script_parameters.tasks; + })(); const dataset_name = script_parameters.dataset_name ? script_parameters.dataset_name : 'N/A'; diff --git a/src/renderer/components/Experiment/Eval/ViewCSVModal.tsx b/src/renderer/components/Experiment/Eval/ViewCSVModal.tsx index 40e8b81517f5b8fe12403a7e7405fd4e3eb35962..e933f1c9f63d27f72699b673da332baefce2f65e 100644 --- a/src/renderer/components/Experiment/Eval/ViewCSVModal.tsx +++ b/src/renderer/components/Experiment/Eval/ViewCSVModal.tsx @@ -75,7 +75,8 @@ function formatEvalData(data) { } function formatArrayOfScores(scores) { - const formattedScores = scores.map((score) => { + const scoresArray = Array.isArray(scores) ? scores : [scores]; + const formattedScores = scoresArray.map((score) => { const metricName = Object.keys(score)[0]; const value = Object.values(score)[0]; diff --git a/src/renderer/components/Experiment/Train/LoRATrainingRunButton.tsx b/src/renderer/components/Experiment/Train/LoRATrainingRunButton.tsx index e61cfee3f150eccdd95b04005c69c21c54fdb29b..0c085aaf62416fb32d0906c9dd3823985f54905b 100644 --- a/src/renderer/components/Experiment/Train/LoRATrainingRunButton.tsx +++ b/src/renderer/components/Experiment/Train/LoRATrainingRunButton.tsx @@ -50,7 +50,7 @@ export default function LoRATrainingRunButton({ }); let modelInLocalList = false; models_downloaded.forEach(modelData => { - if (modelData.model_id == model) { + if (modelData.model_id == model || modelData.local_path === model) { modelInLocalList = true; } }); @@ -85,7 +85,7 @@ export default function LoRATrainingRunButton({ datasetInLocalList = true; } }); - + if(modelInLocalList && datasetInLocalList){ // Use fetch API to call endpoint await fetch( @@ -108,6 +108,7 @@ export default function LoRATrainingRunButton({ if (!datasetInLocalList) { msg += "\n- Dataset: " + dataset; } + if (!modelInLocalList) { msg += "\n- Model: " + model; } diff --git a/src/renderer/components/Experiment/Widgets/CustomEvaluationWidget.tsx b/src/renderer/components/Experiment/Widgets/CustomEvaluationWidget.tsx new file mode 100644 index 0000000000000000000000000000000000000000..2ffc64ca24df8005cc5147e27d3729e072a4eb70 --- /dev/null +++ b/src/renderer/components/Experiment/Widgets/CustomEvaluationWidget.tsx @@ -0,0 +1,134 @@ +import React from 'react'; +import { WidgetProps } from '@rjsf/core'; +import { Button, Input, Select, Option } from '@mui/joy'; + +type EvaluationField = { + name: string; + expression: string; + return_type: string; +}; + +const parseValue = (val: any): EvaluationField[] => { + if (Array.isArray(val)) { + if (val.every(item => typeof item === "string")) { + // If every element is a string: join them and parse the result. + try { + const joined = val.join(','); + const parsed = JSON.parse(joined); + return Array.isArray(parsed) ? parsed : []; + } catch (err) { + console.error("Error parsing evaluation widget value:", err); + return []; + } + } else { + // If not all elements are strings, assume it's already an array of EvaluationField. + return val; + } + } else if (typeof val === "string") { + try { + return JSON.parse(val); + } catch (err) { + console.error("Error parsing evaluation widget value string:", err); + return []; + } + } + return []; +}; + +const CustomEvaluationWidget = (props: WidgetProps<any>) => { + const { id, value, onChange, disabled, readonly } = props; + + // Directly derive evaluation metrics from the value prop. + const evalMetrics: EvaluationField[] = React.useMemo(() => parseValue(value), [value]); + + const handleAddField = () => { + const updatedMetrics = [ + ...evalMetrics, + { name: '', expression: '', return_type: 'boolean' } + ]; + onChange(updatedMetrics); + }; + + const handleFieldChange = ( + index: number, + field: keyof EvaluationField, + newValue: string + ) => { + const updated = evalMetrics.map((evaluation, i) => + i === index ? { ...evaluation, [field]: newValue } : evaluation + ); + onChange(updated); + }; + + const handleRemoveField = (index: number) => { + const updated = evalMetrics.filter((_, i) => i !== index); + onChange(updated); + }; + + return ( + <div id={id}> + {evalMetrics.map((evaluation, index) => ( + <div + key={index} + style={{ + marginBottom: '1rem', + border: '1px solid #ccc', + padding: '0.5rem' + }} + > + <Input + placeholder="Evaluation Name" + value={evaluation.name} + onChange={(e) => + handleFieldChange(index, 'name', e.target.value) + } + disabled={disabled || readonly} + style={{ marginBottom: '0.5rem' }} + /> + <textarea + placeholder="Regular Expression/String" + value={evaluation.expression} + onChange={(e) => + handleFieldChange(index, 'expression', e.target.value) + } + disabled={disabled || readonly} + style={{ marginBottom: '0.5rem' }} + /> + <Select + placeholder="Output Type" + value={evaluation.return_type} + onChange={(e, newValue) => + handleFieldChange(index, 'return_type', newValue as string) + } + disabled={disabled || readonly} + style={{ marginBottom: '0.5rem' }} + > + <Option value="boolean">Boolean</Option> + <Option value="number">Number</Option> + <Option value="contains">Contains</Option> + <Option value="isequal">IsEqual</Option> + </Select> + <Button + onClick={() => handleRemoveField(index)} + disabled={disabled || readonly} + size="sm" + variant="outlined" + > + Remove Field + </Button> + </div> + ))} + <Button + onClick={handleAddField} + disabled={disabled || readonly} + variant="solid" + > + Add Field + </Button> + {/* Hidden input to capture the JSON result on form submission */} + <input type="hidden" id={id} name={id} value={JSON.stringify(evalMetrics)} /> + </div> + ); +}; + +export default CustomEvaluationWidget; diff --git a/src/renderer/components/ModelZoo/LocalModelsTable.tsx b/src/renderer/components/ModelZoo/LocalModelsTable.tsx index d678ce7d5105469af17c75444b8ddcf01b2cf081..0a8404c4ca81350ace0d64a1123ed88a411dcc2a 100644 --- a/src/renderer/components/ModelZoo/LocalModelsTable.tsx +++ b/src/renderer/components/ModelZoo/LocalModelsTable.tsx @@ -11,6 +11,7 @@ import { Option, } from '@mui/joy'; import { + ArrowRightToLineIcon, ArrowDownIcon, FlaskRoundIcon, InfoIcon, @@ -184,6 +185,14 @@ const LocalModelsTable = ({ marginRight: '5px', }} /> + ) : (row?.source && row?.source != "transformerlab") ? ( + <ArrowRightToLineIcon + color="var(--joy-palette-success-700)" + style={{ + verticalAlign: 'middle', + marginRight: '5px', + }} + /> ) : ( '' )}{' '} @@ -191,38 +200,38 @@ const LocalModelsTable = ({ </Typography> </td> <td> - <Typography style={{overflow: 'hidden'}}> - {' '} - {row?.json_data?.architecture == 'MLX' && ( - <> - <TinyMLXLogo /> - - </> - )} - {row?.json_data?.architecture == 'GGUF' && ( - <> - <img - src="https://avatars.githubusercontent.com/ggerganov" - width="24" - valign="middle" - style={{ borderRadius: '50%' }} - />{' '} - - </> - )} - {[ - 'FalconForCausalLM', - 'Gemma2ForCausalLM', - 'GPTBigCodeForCausalLM', - 'LlamaForCausalLM', - 'MistralForCausalLM', - 'Phi3ForCausalLM', - 'Qwen2ForCausalLM', - 'T5ForConditionalGeneration' - ].includes(row?.json_data?.architecture) && ( - <>🤗 </> - )} - {row?.json_data?.architecture} + <Typography style={{ overflow: 'hidden' }}> + {' '} + {row?.json_data?.architecture == 'MLX' && ( + <> + <TinyMLXLogo /> + + </> + )} + {row?.json_data?.architecture == 'GGUF' && ( + <> + <img + src="https://avatars.githubusercontent.com/ggerganov" + width="24" + valign="middle" + style={{ borderRadius: '50%' }} + />{' '} + + </> + )} + {[ + 'FalconForCausalLM', + 'Gemma2ForCausalLM', + 'GPTBigCodeForCausalLM', + 'LlamaForCausalLM', + 'MistralForCausalLM', + 'Phi3ForCausalLM', + 'Qwen2ForCausalLM', + 'T5ForConditionalGeneration' + ].includes(row?.json_data?.architecture) && ( + <>🤗 </> + )} + {row?.json_data?.architecture} </Typography> </td> <td>{row?.json_data?.parameters}</td> @@ -257,8 +266,8 @@ const LocalModelsTable = ({ if ( confirm( "Are you sure you want to delete model '" + - row.model_id + - "'?" + row.model_id + + "'?" ) ) { await fetch( diff --git a/src/renderer/components/Settings/TransformerLabSettings.tsx b/src/renderer/components/Settings/TransformerLabSettings.tsx index 0e2654a2e5b653fe649f626c1682a45597417b06..e845ec3c9ade830810a748cf767f4a426f56c2b4 100644 --- a/src/renderer/components/Settings/TransformerLabSettings.tsx +++ b/src/renderer/components/Settings/TransformerLabSettings.tsx @@ -206,6 +206,8 @@ export default function TransformerLabSettings() { <Option value="DOWNLOAD_MODEL">Download Model</Option> <Option value="LOAD_MODEL">Load Model</Option> <Option value="TRAIN">Train</Option> + <Option value="GENERATE">Generate</Option> + <Option value="EVAL">Evaluate</Option> </Select> {showJobsOfType !== 'NONE' && ( <Table sx={{ tableLayout: 'auto', overflow: 'scroll' }}>