diff --git a/src/renderer/components/Experiment/Eval/Chart.tsx b/src/renderer/components/Experiment/Eval/Chart.tsx index 8bb40e9d7247ad3be5a06754a035bcac0d088ab4..d3de4b8acd2f7af937b7df9dfe66ee7e55c75b03 100644 --- a/src/renderer/components/Experiment/Eval/Chart.tsx +++ b/src/renderer/components/Experiment/Eval/Chart.tsx @@ -23,14 +23,18 @@ const Chart = ({ metrics, compareChart }) => { // Normalize data structure regardless of compare mode // Get all unique metric types and series keys - const metricTypes = Array.from(new Set(metrics.map(m => m.type))); + const metricTypes = Array.from(new Set(metrics.map((m) => m.type))); const seriesKeys = Array.from( - new Set(metrics.map(m => compareChart ? `${m.evaluator}-${m.job_id}` : 'score')) + new Set( + metrics.map((m) => + compareChart ? `${m.evaluator}-${m.job_id}` : 'score', + ), + ), ); // Create a consistent structure for all modes const dataMap = {}; - metrics.forEach(metric => { + metrics.forEach((metric) => { const { type, evaluator, job_id, score } = metric; const seriesKey = compareChart ? `${evaluator}-${job_id}` : 'score'; @@ -43,7 +47,7 @@ const Chart = ({ metrics, compareChart }) => { const normalizedData = { dataPoints: Object.values(dataMap), metricTypes, - seriesKeys + seriesKeys, }; // Get transformed data for the right chart type @@ -53,24 +57,30 @@ const Chart = ({ metrics, compareChart }) => { if (chartType === 'line') { if (swapAxes) { // Series are metric types - return metricTypes.map(type => ({ + return metricTypes.map((type) => ({ id: type, - data: seriesKeys.map(seriesKey => { - const matchingPoint = dataPoints.find(p => p.metric === type && p[seriesKey] !== undefined); - return { - x: seriesKey, - y: matchingPoint ? matchingPoint[seriesKey] : null - }; - }).filter(point => point.y !== null) + data: seriesKeys + .map((seriesKey) => { + const matchingPoint = dataPoints.find( + (p) => p.metric === type && p[seriesKey] !== undefined, + ); + return { + x: seriesKey, + y: matchingPoint ? matchingPoint[seriesKey] : null, + }; + }) + .filter((point) => point.y !== null), })); } else { // Series are evaluators/jobs - return seriesKeys.map(seriesKey => ({ + return seriesKeys.map((seriesKey) => ({ id: seriesKey, - data: dataPoints.map(point => ({ - x: point.metric, - y: point[seriesKey] - })).filter(point => point.y !== undefined) + data: dataPoints + .map((point) => ({ + x: point.metric, + y: point[seriesKey], + })) + .filter((point) => point.y !== undefined), })); } } else if (chartType === 'bar' || chartType === 'radar') { @@ -82,8 +92,15 @@ const Chart = ({ metrics, compareChart }) => { }; return ( - <> - <Box sx={{ display: 'flex', gap: 2, mb: 2, alignItems: 'center' }}> + <Box sx={{ border: '1px solid #e0e0e0', borderRadius: 8, p: 2 }}> + <Box + sx={{ + display: 'flex', + gap: 2, + mb: 2, + alignItems: 'center', + }} + > <FormControl sx={{ width: 200 }}> <Select value={chartType} onChange={handleChartTypeChange}> <Option value="bar">Bar</Option> @@ -107,7 +124,7 @@ const Chart = ({ metrics, compareChart }) => { {chartType === 'line' && ( <ResponsiveLine data={getChartData()} - margin={{ top: 50, right: 110, bottom: 50, left: 60 }} + margin={{ top: 50, right: 200, bottom: 80, left: 60 }} xScale={{ type: 'point' }} yScale={{ type: 'linear', @@ -140,7 +157,7 @@ const Chart = ({ metrics, compareChart }) => { pointBorderColor={{ from: 'serieColor' }} legends={[ { - anchor: 'bottom-right', + anchor: 'top-right', direction: 'column', justify: false, translateX: 100, @@ -153,7 +170,7 @@ const Chart = ({ metrics, compareChart }) => { symbolSize: 12, symbolShape: 'circle', symbolBorderColor: 'rgba(0, 0, 0, .5)', - } + }, ]} /> )} @@ -165,7 +182,9 @@ const Chart = ({ metrics, compareChart }) => { indexBy="metric" margin={{ top: 50, right: 130, bottom: 50, left: 60 }} padding={0.3} - groupMode={normalizedData.seriesKeys.length > 1 ? 'grouped' : 'stacked'} + groupMode={ + normalizedData.seriesKeys.length > 1 ? 'grouped' : 'stacked' + } axisTop={null} axisRight={null} axisBottom={{ @@ -210,31 +229,35 @@ const Chart = ({ metrics, compareChart }) => { enableDotLabel={true} dotLabel="value" dotLabelYOffset={-12} - legends={normalizedData.seriesKeys.length > 1 ? [ - { - anchor: 'right', - direction: 'column', - translateX: 50, - translateY: 0, - itemWidth: 120, - itemHeight: 20, - itemTextColor: '#999', - symbolSize: 12, - symbolShape: 'circle', - effects: [ - { - on: 'hover', - style: { - itemTextColor: '#000' - } - } - ] - } - ] : []} + legends={ + normalizedData.seriesKeys.length > 0 + ? [ + { + anchor: 'top-right', + direction: 'column', + translateX: -200, + translateY: 0, + itemWidth: 120, + itemHeight: 20, + itemTextColor: '#999', + symbolSize: 12, + symbolShape: 'circle', + effects: [ + { + on: 'hover', + style: { + itemTextColor: '#000', + }, + }, + ], + }, + ] + : [] + } /> )} </div> - </> + </Box> ); }; diff --git a/src/renderer/components/Experiment/Eval/ViewPlotModal.tsx b/src/renderer/components/Experiment/Eval/ViewPlotModal.tsx index 91583c70846f8a3c253fe923231212e0a69ab93b..267812dd7978e25ee09ad033046c2bb001e7eca2 100644 --- a/src/renderer/components/Experiment/Eval/ViewPlotModal.tsx +++ b/src/renderer/components/Experiment/Eval/ViewPlotModal.tsx @@ -9,7 +9,13 @@ function parseJSON(data) { } } -export default function ViewPlotModal({ open, onClose, data, jobId, compareChart= false}) { +export default function ViewPlotModal({ + open, + onClose, + data, + jobId, + compareChart = false, +}) { if (!jobId) { return <></>; } @@ -29,9 +35,6 @@ export default function ViewPlotModal({ open, onClose, data, jobId, compareChart height: '100%', }} > - <Typography level="h4" mb={2}> - Chart - </Typography> <Box sx={{ width: '100%', @@ -42,9 +45,12 @@ export default function ViewPlotModal({ open, onClose, data, jobId, compareChart borderRadius: '8px', boxShadow: 1, p: 2, + display: 'flex', + flexDirection: 'column', + justifyContent: 'center', }} > - <Chart metrics={parseJSON(data)} compareChart={compareChart} /> + <Chart metrics={parseJSON(data)} compareChart={compareChart} /> </Box> </Box> </ModalDialog>