diff --git a/src/renderer/components/Computer.tsx b/src/renderer/components/Computer.tsx index d5209d26c97a1e8d5f471046624384553d187dbe..0870fdc2f98c251449491c8c81498b6015591d11 100644 --- a/src/renderer/components/Computer.tsx +++ b/src/renderer/components/Computer.tsx @@ -80,245 +80,286 @@ export default function Computer() { ); return ( - <Tabs sx={{ height: '100%', overflow: 'hidden' }}> - <TabList> - <Tab>Server Information</Tab> - <Tab>Python Libraries</Tab> - </TabList> - <TabPanel value={0}> - {server && ( - <> - {/* {JSON.stringify(server)} */} - <Typography level="h2" paddingBottom={3}> - Server Information - </Typography> - <Sheet className="OrderTableContainer"> - <Grid container spacing={2} sx={{}}> - <Grid xs={2}> - <ComputerCard - icon={<FaComputer />} - title="Machine" - description={`${server.os} - ${server.name}`} - > - <StatRow title="CPU" value={server?.cpu_percent + '%'} /> - <StatRow title="Cores" value={server?.cpu_count} /> - </ComputerCard> - </Grid> - <Grid xs={4}> - <ComputerCard - icon={<BsGpuCard />} - title={'GPU Specs (' + server.gpu?.length + ')'} - image={undefined} - > - {server.gpu?.map((g, i) => { - return ( - <Box mb={2}> - <Typography level="title-md">GPU # {i}</Typography> - {g.name.includes('NVIDIA') ? ( - <SiNvidia color="#76B900" /> - ) : ( - '🔥' - )} - - {g.name} - <StatRow - title="Total VRAM" - value={formatBytes(g?.total_memory)} - /> - <StatRow - title="Available" - value={formatBytes(g?.free_memory)} - /> - {g.total_memory !== 'n/a' && ( - <> - <StatRow - title="Used" - value={ - <> - {Math.round( - (g?.used_memory / g?.total_memory) * 100, - )} - % - <LinearProgress - determinate - value={ - (g?.used_memory / g?.total_memory) * 100 - } - variant="solid" - sx={{ minWidth: '50px' }} - /> - </> - } - /> - </> - )} - </Box> - ); - })} - </ComputerCard> - </Grid> - <Grid xs={3}> - <ComputerCard icon={<ZapIcon />} title="Acceleration"> - <StatRow - title="GPU" - value={server.gpu?.length === 0 ? 'âŒ' : '✅'} - /> - <StatRow - title="CUDA" - value={server?.device === 'cuda' ? '✅ ' : 'âŒ'} - /> - <StatRow - title="CUDA Version" - value={server?.cuda_version} - />{' '} - <StatRow - title="Python MPS" - value={server?.device === 'mps' ? '✅ ' : 'âŒ'} - />{' '} - <StatRow - title="Flash Attention" - value={ - server?.flash_attn_version && - server?.flash_attn_version != 'n/a' - ? '✅' - : 'âŒ' - } - /> - <StatRow - title="Flash Attn Version" - value={server?.flash_attn_version} - /> - </ComputerCard> - </Grid> - <Grid xs={3}> - <ComputerCard icon={<LayoutIcon />} title="Operating System"> - {server?.platform.includes('microsoft') && <FaWindows />} - {server?.platform} - </ComputerCard> - </Grid> - <Grid xs={3}> - <ComputerCard icon={<MemoryStickIcon />} title="Memory"> - <> + <Sheet + sx={{ + display: 'flex', + flexDirection: 'column', + height: '100%', + overflow: 'hidden', + paddingBottom: '1rem', + }} + > + <Tabs + sx={{ + height: '100%', + display: 'block', + overflow: 'hidden', + }} + > + <TabList> + <Tab>Server Information</Tab> + <Tab>Python Libraries</Tab> + </TabList> + <TabPanel + value={0} + sx={{ + overflow: 'hidden', + height: '100%', + }} + > + {server && ( + <Sheet + sx={{ + display: 'flex', + flexDirection: 'column', + height: '100%', + overflow: 'hidden', + paddingBottom: '1rem', + }} + > + {/* {JSON.stringify(server)} */} + <Typography level="h2" paddingBottom={1}> + Server Information + </Typography> + <Sheet + className="OrderTableContainer" + sx={{ + display: 'flex', + height: '100%', + overflowY: 'auto', + padding: '10px', + }} + > + <Grid container spacing={2} sx={{}}> + <Grid xs={2}> + <ComputerCard + icon={<FaComputer />} + title="Machine" + description={`${server.os} - ${server.name}`} + > + <StatRow title="CPU" value={server?.cpu_percent + '%'} /> + <StatRow title="Cores" value={server?.cpu_count} /> + </ComputerCard> + </Grid> + <Grid xs={4}> + <ComputerCard + icon={<BsGpuCard />} + title={'GPU Specs (' + server.gpu?.length + ')'} + image={undefined} + > + {server.gpu?.map((g, i) => { + return ( + <Box mb={2}> + <Typography level="title-md">GPU # {i}</Typography> + {g.name.includes('NVIDIA') ? ( + <SiNvidia color="#76B900" /> + ) : ( + '🔥' + )} + + {g.name} + <StatRow + title="Total VRAM" + value={formatBytes(g?.total_memory)} + /> + <StatRow + title="Available" + value={formatBytes(g?.free_memory)} + /> + {g.total_memory !== 'n/a' && ( + <> + <StatRow + title="Used" + value={ + <> + {Math.round( + (g?.used_memory / g?.total_memory) * + 100, + )} + % + <LinearProgress + determinate + value={ + (g?.used_memory / g?.total_memory) * + 100 + } + variant="solid" + sx={{ minWidth: '50px' }} + /> + </> + } + /> + </> + )} + </Box> + ); + })} + </ComputerCard> + </Grid> + <Grid xs={3}> + <ComputerCard icon={<ZapIcon />} title="Acceleration"> + <StatRow + title="GPU" + value={server.gpu?.length === 0 ? 'âŒ' : '✅'} + /> + <StatRow + title="CUDA" + value={server?.device === 'cuda' ? '✅ ' : 'âŒ'} + /> + <StatRow + title="CUDA Version" + value={server?.cuda_version} + />{' '} + <StatRow + title="Python MPS" + value={server?.device === 'mps' ? '✅ ' : 'âŒ'} + />{' '} + <StatRow + title="Flash Attention" + value={ + server?.flash_attn_version && + server?.flash_attn_version != 'n/a' + ? '✅' + : 'âŒ' + } + /> + <StatRow + title="Flash Attn Version" + value={server?.flash_attn_version} + /> + </ComputerCard> + </Grid> + <Grid xs={3}> + <ComputerCard + icon={<LayoutIcon />} + title="Operating System" + > + {server?.platform.includes('microsoft') && <FaWindows />} + {server?.platform} + </ComputerCard> + </Grid> + <Grid xs={3}> + <ComputerCard icon={<MemoryStickIcon />} title="Memory"> + <> + <StatRow + title="Total" + value={formatBytes(server.memory?.total)} + /> + <StatRow + title="Available" + value={formatBytes(server.memory?.available)} + /> + <StatRow + title="Percent" + value={server.memory?.percent + '%'} + /> + </> + </ComputerCard> + </Grid> + <Grid xs={3}> + <ComputerCard title="Disk" icon={<DatabaseIcon />}> <StatRow title="Total" - value={formatBytes(server.memory?.total)} + value={formatBytes(server.disk?.total)} + /> + <StatRow + title="Used" + value={formatBytes(server.disk?.used)} /> <StatRow - title="Available" - value={formatBytes(server.memory?.available)} + title="Free" + value={formatBytes(server.disk?.free)} /> <StatRow title="Percent" - value={server.memory?.percent + '%'} + value={ + <> + {server.disk?.percent}% + <LinearProgress + determinate + value={server.disk?.percent} + variant="solid" + sx={{ minWidth: '50px' }} + /> + </> + } /> - </> - </ComputerCard> + </ComputerCard> + </Grid> + <Grid xs={3}> + <ComputerCard icon={<FaPython />} title="Python Version"> + {server.python_version} + </ComputerCard> + </Grid> </Grid> - <Grid xs={3}> - <ComputerCard title="Disk" icon={<DatabaseIcon />}> - <StatRow - title="Total" - value={formatBytes(server.disk?.total)} - /> - <StatRow - title="Used" - value={formatBytes(server.disk?.used)} - /> - <StatRow - title="Free" - value={formatBytes(server.disk?.free)} - /> - <StatRow - title="Percent" - value={ - <> - {server.disk?.percent}% - <LinearProgress - determinate - value={server.disk?.percent} - variant="solid" - sx={{ minWidth: '50px' }} - /> - </> - } - /> - </ComputerCard> - </Grid> - <Grid xs={3}> - <ComputerCard icon={<FaPython />} title="Python Version"> - {server.python_version} - </ComputerCard> - </Grid> - </Grid> + </Sheet> </Sheet> - </> - )} - </TabPanel> - <TabPanel - value={1} - style={{ - display: 'flex', - flexDirection: 'column', - height: '100%', - overflow: 'hidden', - }} - > - <Sheet + )} + </TabPanel> + <TabPanel + value={1} style={{ - display: 'flex', - flexDirection: 'column', height: '100%', - overflow: 'hidden', - gap: '1rem', + overflow: 'auto', }} > - <Typography level="h2" paddingTop={2}> - Installed Python Libraries - </Typography> - <Typography level="title-sm" paddingBottom={0}> - Conda Environment: {server?.conda_environment} @{' '} - {server?.conda_prefix} - </Typography> - <FormControl size="sm" sx={{ width: '400px' }}> - <Input - placeholder="Search" - value={searchText} - onChange={(e) => setSearchText(e.target.value)} - startDecorator={<SearchIcon />} - /> - </FormControl> - {pythonLibraries && ( - <> - <Sheet sx={{ overflow: 'auto', width: 'fit-content' }}> - <Table borderAxis="both" sx={{ width: 'auto' }}> - <thead> - <tr> - <th>Library</th> - <th>Version</th> - </tr> - </thead> - <tbody> - {pythonLibraries - .filter((lib) => - lib.name - .toLowerCase() - .includes(searchText.toLowerCase()), - ) - .map((lib) => { - return ( - <tr> - <td>{lib.name}</td> - <td>{lib.version}</td> - </tr> - ); - })} - </tbody> - </Table> - </Sheet> - </> - )} - </Sheet> - </TabPanel> - </Tabs> + <Sheet + style={{ + display: 'flex', + flexDirection: 'column', + height: '100%', + overflow: 'hidden', + gap: '1rem', + }} + > + <Typography level="h2" paddingTop={0}> + Installed Python Libraries + </Typography> + <Typography level="title-sm" paddingBottom={0}> + Conda Environment: {server?.conda_environment} @{' '} + {server?.conda_prefix} + </Typography> + <FormControl size="sm" sx={{ width: '400px' }}> + <Input + placeholder="Search" + value={searchText} + onChange={(e) => setSearchText(e.target.value)} + startDecorator={<SearchIcon />} + /> + </FormControl> + {pythonLibraries && ( + <> + <Sheet sx={{ overflow: 'auto', width: 'fit-content' }}> + <Table borderAxis="both" sx={{ width: 'auto' }}> + <thead> + <tr> + <th>Library</th> + <th>Version</th> + </tr> + </thead> + <tbody> + {pythonLibraries + .filter((lib) => + lib.name + .toLowerCase() + .includes(searchText.toLowerCase()), + ) + .map((lib) => { + return ( + <tr> + <td>{lib.name}</td> + <td>{lib.version}</td> + </tr> + ); + })} + </tbody> + </Table> + </Sheet> + </> + )} + </Sheet> + </TabPanel> + </Tabs> + </Sheet> ); } diff --git a/src/renderer/components/Experiment/Eval/Chart.tsx b/src/renderer/components/Experiment/Eval/Chart.tsx index 8bb40e9d7247ad3be5a06754a035bcac0d088ab4..d3de4b8acd2f7af937b7df9dfe66ee7e55c75b03 100644 --- a/src/renderer/components/Experiment/Eval/Chart.tsx +++ b/src/renderer/components/Experiment/Eval/Chart.tsx @@ -23,14 +23,18 @@ const Chart = ({ metrics, compareChart }) => { // Normalize data structure regardless of compare mode // Get all unique metric types and series keys - const metricTypes = Array.from(new Set(metrics.map(m => m.type))); + const metricTypes = Array.from(new Set(metrics.map((m) => m.type))); const seriesKeys = Array.from( - new Set(metrics.map(m => compareChart ? `${m.evaluator}-${m.job_id}` : 'score')) + new Set( + metrics.map((m) => + compareChart ? `${m.evaluator}-${m.job_id}` : 'score', + ), + ), ); // Create a consistent structure for all modes const dataMap = {}; - metrics.forEach(metric => { + metrics.forEach((metric) => { const { type, evaluator, job_id, score } = metric; const seriesKey = compareChart ? `${evaluator}-${job_id}` : 'score'; @@ -43,7 +47,7 @@ const Chart = ({ metrics, compareChart }) => { const normalizedData = { dataPoints: Object.values(dataMap), metricTypes, - seriesKeys + seriesKeys, }; // Get transformed data for the right chart type @@ -53,24 +57,30 @@ const Chart = ({ metrics, compareChart }) => { if (chartType === 'line') { if (swapAxes) { // Series are metric types - return metricTypes.map(type => ({ + return metricTypes.map((type) => ({ id: type, - data: seriesKeys.map(seriesKey => { - const matchingPoint = dataPoints.find(p => p.metric === type && p[seriesKey] !== undefined); - return { - x: seriesKey, - y: matchingPoint ? matchingPoint[seriesKey] : null - }; - }).filter(point => point.y !== null) + data: seriesKeys + .map((seriesKey) => { + const matchingPoint = dataPoints.find( + (p) => p.metric === type && p[seriesKey] !== undefined, + ); + return { + x: seriesKey, + y: matchingPoint ? matchingPoint[seriesKey] : null, + }; + }) + .filter((point) => point.y !== null), })); } else { // Series are evaluators/jobs - return seriesKeys.map(seriesKey => ({ + return seriesKeys.map((seriesKey) => ({ id: seriesKey, - data: dataPoints.map(point => ({ - x: point.metric, - y: point[seriesKey] - })).filter(point => point.y !== undefined) + data: dataPoints + .map((point) => ({ + x: point.metric, + y: point[seriesKey], + })) + .filter((point) => point.y !== undefined), })); } } else if (chartType === 'bar' || chartType === 'radar') { @@ -82,8 +92,15 @@ const Chart = ({ metrics, compareChart }) => { }; return ( - <> - <Box sx={{ display: 'flex', gap: 2, mb: 2, alignItems: 'center' }}> + <Box sx={{ border: '1px solid #e0e0e0', borderRadius: 8, p: 2 }}> + <Box + sx={{ + display: 'flex', + gap: 2, + mb: 2, + alignItems: 'center', + }} + > <FormControl sx={{ width: 200 }}> <Select value={chartType} onChange={handleChartTypeChange}> <Option value="bar">Bar</Option> @@ -107,7 +124,7 @@ const Chart = ({ metrics, compareChart }) => { {chartType === 'line' && ( <ResponsiveLine data={getChartData()} - margin={{ top: 50, right: 110, bottom: 50, left: 60 }} + margin={{ top: 50, right: 200, bottom: 80, left: 60 }} xScale={{ type: 'point' }} yScale={{ type: 'linear', @@ -140,7 +157,7 @@ const Chart = ({ metrics, compareChart }) => { pointBorderColor={{ from: 'serieColor' }} legends={[ { - anchor: 'bottom-right', + anchor: 'top-right', direction: 'column', justify: false, translateX: 100, @@ -153,7 +170,7 @@ const Chart = ({ metrics, compareChart }) => { symbolSize: 12, symbolShape: 'circle', symbolBorderColor: 'rgba(0, 0, 0, .5)', - } + }, ]} /> )} @@ -165,7 +182,9 @@ const Chart = ({ metrics, compareChart }) => { indexBy="metric" margin={{ top: 50, right: 130, bottom: 50, left: 60 }} padding={0.3} - groupMode={normalizedData.seriesKeys.length > 1 ? 'grouped' : 'stacked'} + groupMode={ + normalizedData.seriesKeys.length > 1 ? 'grouped' : 'stacked' + } axisTop={null} axisRight={null} axisBottom={{ @@ -210,31 +229,35 @@ const Chart = ({ metrics, compareChart }) => { enableDotLabel={true} dotLabel="value" dotLabelYOffset={-12} - legends={normalizedData.seriesKeys.length > 1 ? [ - { - anchor: 'right', - direction: 'column', - translateX: 50, - translateY: 0, - itemWidth: 120, - itemHeight: 20, - itemTextColor: '#999', - symbolSize: 12, - symbolShape: 'circle', - effects: [ - { - on: 'hover', - style: { - itemTextColor: '#000' - } - } - ] - } - ] : []} + legends={ + normalizedData.seriesKeys.length > 0 + ? [ + { + anchor: 'top-right', + direction: 'column', + translateX: -200, + translateY: 0, + itemWidth: 120, + itemHeight: 20, + itemTextColor: '#999', + symbolSize: 12, + symbolShape: 'circle', + effects: [ + { + on: 'hover', + style: { + itemTextColor: '#000', + }, + }, + ], + }, + ] + : [] + } /> )} </div> - </> + </Box> ); }; diff --git a/src/renderer/components/Experiment/Eval/EvalModal.tsx b/src/renderer/components/Experiment/Eval/EvalModal.tsx index e005a54be5ab5b28e508a2fc4e64911b254a947a..e4c2d7fd36bd5871f38bdf40d8885b7e9b1a15ba 100644 --- a/src/renderer/components/Experiment/Eval/EvalModal.tsx +++ b/src/renderer/components/Experiment/Eval/EvalModal.tsx @@ -312,15 +312,16 @@ export default function EvalModal({ } }; + return ( <Modal open={open}> <ModalDialog sx={{ - width: '80dvw', + width: '95dvw', transform: 'translateX(-50%)', // This undoes the default translateY that centers vertically top: '5dvh', overflow: 'auto', - maxHeight: '90dvh', + maxHeight: '92dvh', minHeight: '70dvh', height: '100%', }} @@ -362,7 +363,7 @@ export default function EvalModal({ </TabPanel> <TabPanel value={2} - sx={{ p: 2, overflow: 'auto', maxWidth: '700px' }} + sx={{ p: 2, overflow: 'auto'}} keepMounted > <DynamicPluginForm diff --git a/src/renderer/components/Experiment/Eval/ViewPlotModal.tsx b/src/renderer/components/Experiment/Eval/ViewPlotModal.tsx index 91583c70846f8a3c253fe923231212e0a69ab93b..267812dd7978e25ee09ad033046c2bb001e7eca2 100644 --- a/src/renderer/components/Experiment/Eval/ViewPlotModal.tsx +++ b/src/renderer/components/Experiment/Eval/ViewPlotModal.tsx @@ -9,7 +9,13 @@ function parseJSON(data) { } } -export default function ViewPlotModal({ open, onClose, data, jobId, compareChart= false}) { +export default function ViewPlotModal({ + open, + onClose, + data, + jobId, + compareChart = false, +}) { if (!jobId) { return <></>; } @@ -29,9 +35,6 @@ export default function ViewPlotModal({ open, onClose, data, jobId, compareChart height: '100%', }} > - <Typography level="h4" mb={2}> - Chart - </Typography> <Box sx={{ width: '100%', @@ -42,9 +45,12 @@ export default function ViewPlotModal({ open, onClose, data, jobId, compareChart borderRadius: '8px', boxShadow: 1, p: 2, + display: 'flex', + flexDirection: 'column', + justifyContent: 'center', }} > - <Chart metrics={parseJSON(data)} compareChart={compareChart} /> + <Chart metrics={parseJSON(data)} compareChart={compareChart} /> </Box> </Box> </ModalDialog> diff --git a/src/renderer/components/Experiment/Foundation/CurrentFoundationInfo.tsx b/src/renderer/components/Experiment/Foundation/CurrentFoundationInfo.tsx index 4d870b32de831244ae66b6861324be75df2b2d60..3e9b9967877d43c22ab7b0b3b7626b9ed1a11e29 100644 --- a/src/renderer/components/Experiment/Foundation/CurrentFoundationInfo.tsx +++ b/src/renderer/components/Experiment/Foundation/CurrentFoundationInfo.tsx @@ -3,6 +3,7 @@ import Sheet from '@mui/joy/Sheet'; import { Box, Button, IconButton, Stack, Table, Typography } from '@mui/joy'; +import Tooltip from '@mui/joy/Tooltip'; import { BabyIcon, DotIcon, Trash2Icon, XCircleIcon } from 'lucide-react'; import useSWR from 'swr'; @@ -15,6 +16,9 @@ const fetchWithPost = ({ url, post }) => method: 'POST', body: post, }).then((res) => res.json()); + +const fetcher = (url) => fetch(url).then((res) => res.json()); + function modelNameIsInHuggingfaceFormat(modelName: string) { return modelName.includes('/'); } @@ -62,8 +66,15 @@ export default function CurrentFoundationInfo({ fetchWithPost ); const [huggingfaceData, setHugggingfaceData] = useState({}); + const [showProvenance, setShowProvenance] = useState(false); const huggingfaceId = experimentInfo?.config?.foundation; + // Fetch provenance data from your GET endpoint using chatAPI.Endpoints.Models.ModelProvenance() + const { data: provenance, error: provenanceError } = useSWR( + chatAPI.Endpoints.Models.ModelProvenance(huggingfaceId), + fetcher + ); + useMemo(() => { // This is a local model if (experimentInfo?.config?.foundation_filename) { @@ -114,6 +125,95 @@ export default function CurrentFoundationInfo({ )} </tbody> </Table> + {/* Model Provenance Collapsible */} + <Box mt={4}> + <Button + variant="soft" + onClick={() => setShowProvenance((prev) => !prev)} + > + Model Provenance {showProvenance ? 'â–²' : 'â–¼'} + </Button> + {showProvenance && ( + <Box + sx={{ + mt: 2, + overflow: 'auto', + maxHeight: 400, + maxWidth: '100%', + border: '1px solid #ccc', + borderRadius: '4px', + }} + > + {provenance ? ( + <Table + id="model-provenance-table" + sx={{ + tableLayout: 'auto', + minWidth: 600, // Ensure horizontal scroll if needed + }} + > + <thead> + <tr> + <th>Job ID</th> + <th>Base Model</th> + <th>Dataset</th> + <th>Params</th> + <th>Output Model</th> + <th>Evals</th> + </tr> + </thead> + <tbody> + {provenance.provenance_chain.map((row) => ( + <tr key={row.job_id}> + <td>{row.job_id}</td> + <td>{row.input_model}</td> + <td>{row.dataset}</td> + <td> + <pre>{JSON.stringify(row.parameters, null, 2)}</pre> + </td> + <td>{row.output_model}</td> + <td> + <Box> + {row.evals && row.evals.length > 0 ? ( + row.evals.map((evalItem) => ( + <Tooltip + key={evalItem.job_id} + title={ + <pre style={{ margin: 0 }}> + {JSON.stringify(evalItem, null, 2)} + </pre> + } + > + <Typography + level="body2" + sx={{ cursor: 'pointer', mb: 0.5 }} + > + {evalItem.job_id} -{' '} + {evalItem.template_name || + evalItem.evaluator || + 'Eval'} + </Typography> + </Tooltip> + )) + ) : ( + <Typography level="body2"> + No Evals + </Typography> + )} + </Box> + </td> + </tr> + ))} + </tbody> + </Table> + ) : provenanceError ? ( + <Typography>Error loading provenance</Typography> + ) : ( + <Typography>Loading Provenance...</Typography> + )} + </Box> + )} + </Box> </Box> <Box flex={1}> <Typography level="title-lg" marginTop={1} marginBottom={1}> diff --git a/src/renderer/lib/transformerlab-api-sdk.ts b/src/renderer/lib/transformerlab-api-sdk.ts index 7034f8b99a1954072805ab6617fd7b54c9427c02..819a33154002cbc8cfc59cf243ff344c369616f6 100644 --- a/src/renderer/lib/transformerlab-api-sdk.ts +++ b/src/renderer/lib/transformerlab-api-sdk.ts @@ -1082,6 +1082,8 @@ Endpoints.Models = { API_URL() + 'model/gallery/' + convertSlashInUrl(modelId), ModelDetailsFromFilesystem: (modelId: string) => API_URL() + 'model/details/' + convertSlashInUrl(modelId), + ModelProvenance: (modelId: string) => + API_URL() + 'model/provenance/' + convertSlashInUrl(modelId), GetLocalHFConfig: (modelId: string) => API_URL() + 'model/get_local_hfconfig?model_id=' + modelId, SearchForLocalUninstalledModels: (path: string) =>