From 7a815f92a58fbfb0853eb91fb5468d20a10bbde5 Mon Sep 17 00:00:00 2001 From: deep1401 <gandhi0869@gmail.com> Date: Mon, 3 Mar 2025 09:48:58 -0800 Subject: [PATCH] Refactor code to use minimum of compareChart --- .../components/Experiment/Eval/Chart.tsx | 273 +++++++----------- 1 file changed, 105 insertions(+), 168 deletions(-) diff --git a/src/renderer/components/Experiment/Eval/Chart.tsx b/src/renderer/components/Experiment/Eval/Chart.tsx index 2de89c99..8bb40e9d 100644 --- a/src/renderer/components/Experiment/Eval/Chart.tsx +++ b/src/renderer/components/Experiment/Eval/Chart.tsx @@ -5,12 +5,10 @@ import { ResponsiveRadar } from '@nivo/radar'; import { Select, Option, FormControl, Box, Button } from '@mui/joy'; import { ArrowLeftRight } from 'lucide-react'; - const Chart = ({ metrics, compareChart }) => { const [chartType, setChartType] = useState('bar'); const [swapAxes, setSwapAxes] = useState(false); - const handleChartTypeChange = (event, newValue) => { setChartType(newValue); }; @@ -23,87 +21,65 @@ const Chart = ({ metrics, compareChart }) => { return <div>No metrics available</div>; } + // Normalize data structure regardless of compare mode + // Get all unique metric types and series keys + const metricTypes = Array.from(new Set(metrics.map(m => m.type))); + const seriesKeys = Array.from( + new Set(metrics.map(m => compareChart ? `${m.evaluator}-${m.job_id}` : 'score')) + ); - let barData, lineData, radarData; - - if (compareChart) { - // For compare mode, we need to handle the axis swap - const metricTypes = Array.from(new Set(metrics.map(m => m.type))); - const seriesKeys = Array.from( - new Set(metrics.map(m => `${m.evaluator}-${m.job_id}`)) - ); - - // For Line chart: handle axis swapping - if (swapAxes && chartType === 'line') { - // Swapped axes: each series is a metric type - const seriesDataMap = {}; - metricTypes.forEach(type => { - seriesDataMap[type] = { id: type, data: [] }; - }); - - metrics.forEach(metric => { - const { type, evaluator, job_id, score } = metric; - const seriesKey = `${evaluator}-${job_id}`; - seriesDataMap[type].data.push({ x: seriesKey, y: score }); - }); - lineData = Object.values(seriesDataMap); - } else { - // Normal mode for line chart: each series is an evaluator-job - const seriesDataMap = {}; - seriesKeys.forEach(series => { - seriesDataMap[series] = { id: series, data: [] }; - }); + // Create a consistent structure for all modes + const dataMap = {}; + metrics.forEach(metric => { + const { type, evaluator, job_id, score } = metric; + const seriesKey = compareChart ? `${evaluator}-${job_id}` : 'score'; - metrics.forEach(metric => { - const { type, evaluator, job_id, score } = metric; - const seriesKey = `${evaluator}-${job_id}`; - seriesDataMap[seriesKey].data.push({ x: type, y: score }); - }); - lineData = Object.values(seriesDataMap); + if (!dataMap[type]) { + dataMap[type] = { metric: type }; } + dataMap[type][seriesKey] = score; + }); - // Bar chart: data preparation (unchanged) - const barDataMap = {}; - metrics.forEach(metric => { - const { type, evaluator, job_id, score } = metric; - const seriesKey = `${evaluator}-${job_id}`; - if (!barDataMap[type]) { - barDataMap[type] = { type }; - } - barDataMap[type][seriesKey] = score; - }); - barData = Object.values(barDataMap); - - // Radar chart: data preparation (unchanged) - const radarDataMap = {}; - metrics.forEach(metric => { - const { type, evaluator, job_id, score } = metric; - const seriesKey = `${evaluator}-${job_id}`; - if (!radarDataMap[type]) { - radarDataMap[type] = { metric: type }; - } - radarDataMap[type][seriesKey] = score; - }); - radarData = Object.values(radarDataMap); - } else { - // Original logic for non-compare mode (unchanged) - barData = metrics.map(metric => ({ - type: metric.type, - score: metric.score, - })); + const normalizedData = { + dataPoints: Object.values(dataMap), + metricTypes, + seriesKeys + }; - lineData = [ - { - id: 'metrics', - data: metrics.map(metric => ({ x: metric.type, y: metric.score })), + // Get transformed data for the right chart type + const getChartData = () => { + const { dataPoints, metricTypes, seriesKeys } = normalizedData; + + if (chartType === 'line') { + if (swapAxes) { + // Series are metric types + return metricTypes.map(type => ({ + id: type, + data: seriesKeys.map(seriesKey => { + const matchingPoint = dataPoints.find(p => p.metric === type && p[seriesKey] !== undefined); + return { + x: seriesKey, + y: matchingPoint ? matchingPoint[seriesKey] : null + }; + }).filter(point => point.y !== null) + })); + } else { + // Series are evaluators/jobs + return seriesKeys.map(seriesKey => ({ + id: seriesKey, + data: dataPoints.map(point => ({ + x: point.metric, + y: point[seriesKey] + })).filter(point => point.y !== undefined) + })); } - ]; - - radarData = metrics.map(metric => ({ - metric: metric.type, - score: metric.score, - })); - } + } else if (chartType === 'bar' || chartType === 'radar') { + // Both bar and radar can use the dataPoints directly + return dataPoints; + } else { + return []; + } + }; return ( <> @@ -116,7 +92,7 @@ const Chart = ({ metrics, compareChart }) => { </Select> </FormControl> - {compareChart && chartType === 'line' && ( + {chartType === 'line' && ( <Button variant="outlined" startDecorator={<ArrowLeftRight size={18} />} @@ -130,7 +106,7 @@ const Chart = ({ metrics, compareChart }) => { <div style={{ height: 400, width: '100%' }}> {chartType === 'line' && ( <ResponsiveLine - data={lineData} + data={getChartData()} margin={{ top: 50, right: 110, bottom: 50, left: 60 }} xScale={{ type: 'point' }} yScale={{ @@ -145,8 +121,8 @@ const Chart = ({ metrics, compareChart }) => { axisBottom={{ tickSize: 5, tickPadding: 5, - tickRotation: swapAxes && compareChart ? 45 : 0, - legend: swapAxes && compareChart ? 'experiment' : 'metric', + tickRotation: swapAxes ? 45 : 0, + legend: swapAxes ? 'experiment' : 'metric', legendOffset: 36, legendPosition: 'middle', }} @@ -182,18 +158,14 @@ const Chart = ({ metrics, compareChart }) => { /> )} -{chartType === 'bar' && ( + {chartType === 'bar' && ( <ResponsiveBar - data={barData} - keys={ - compareChart - ? Array.from(new Set(metrics.map((m) => `${m.evaluator}-${m.job_id}`))) - : ['score'] - } - indexBy="type" + data={getChartData()} + keys={normalizedData.seriesKeys} + indexBy="metric" margin={{ top: 50, right: 130, bottom: 50, left: 60 }} padding={0.3} - groupMode={compareChart ? 'grouped' : 'stacked'} // Added groupMode here. + groupMode={normalizedData.seriesKeys.length > 1 ? 'grouped' : 'stacked'} axisTop={null} axisRight={null} axisBottom={{ @@ -217,85 +189,50 @@ const Chart = ({ metrics, compareChart }) => { animate={false} /> )} -{chartType === 'radar' && compareChart && ( - <ResponsiveRadar - data={(() => { - // Group metrics by type - const metricsGroupedByType = {}; - metrics.forEach(metric => { - const { type, evaluator, job_id, score } = metric; - const seriesKey = `${evaluator}-${job_id}`; - - if (!metricsGroupedByType[type]) { - metricsGroupedByType[type] = { metric: type }; - } - metricsGroupedByType[type][seriesKey] = score; - }); - return Object.values(metricsGroupedByType); - })()} - keys={Array.from(new Set(metrics.map(m => `${m.evaluator}-${m.job_id}`)))} - indexBy="metric" - margin={{ top: 70, right: 170, bottom: 40, left: 80 }} - borderColor={{ from: 'color' }} - gridShape="circular" - gridLabelOffset={36} - dotSize={10} - dotColor={{ theme: 'background' }} - dotBorderWidth={2} - colors={{ scheme: 'nivo' }} - fillOpacity={0.25} - blendMode="multiply" - animate={true} - motionConfig="wobbly" - enableDotLabel={true} - dotLabel="value" - dotLabelYOffset={-12} - legends={[ - { - anchor: 'right', - direction: 'column', - translateX: 50, - translateY: 0, - itemWidth: 120, - itemHeight: 20, - itemTextColor: '#999', - symbolSize: 12, - symbolShape: 'circle', - effects: [ - { - on: 'hover', - style: { - itemTextColor: '#000' - } - } - ] - } - ]} - /> -)} -{chartType === 'radar' && !compareChart && ( - <ResponsiveRadar - data={radarData} - keys={['score']} - indexBy="metric" - margin={{ top: 70, right: 80, bottom: 40, left: 80 }} - gridShape="circular" - gridLabelOffset={36} - dotSize={10} - dotColor={{ from: 'color', modifiers: [] }} - dotBorderWidth={2} - dotBorderColor={{ from: 'color', modifiers: [] }} - colors={{ scheme: 'nivo' }} - fillOpacity={0.25} - blendMode="multiply" - animate={true} - motionConfig="wobbly" - enableDotLabel={true} - dotLabel="value" - dotLabelYOffset={-12} - /> -)} + {chartType === 'radar' && ( + <ResponsiveRadar + data={getChartData()} + keys={normalizedData.seriesKeys} + indexBy="metric" + margin={{ top: 70, right: 170, bottom: 40, left: 80 }} + borderColor={{ from: 'color' }} + gridShape="circular" + gridLabelOffset={36} + dotSize={10} + dotColor={{ theme: 'background' }} + dotBorderWidth={2} + colors={{ scheme: 'nivo' }} + fillOpacity={0.25} + blendMode="multiply" + animate={true} + motionConfig="wobbly" + enableDotLabel={true} + dotLabel="value" + dotLabelYOffset={-12} + legends={normalizedData.seriesKeys.length > 1 ? [ + { + anchor: 'right', + direction: 'column', + translateX: 50, + translateY: 0, + itemWidth: 120, + itemHeight: 20, + itemTextColor: '#999', + symbolSize: 12, + symbolShape: 'circle', + effects: [ + { + on: 'hover', + style: { + itemTextColor: '#000' + } + } + ] + } + ] : []} + /> + )} </div> </> ); -- GitLab