From 7a815f92a58fbfb0853eb91fb5468d20a10bbde5 Mon Sep 17 00:00:00 2001
From: deep1401 <gandhi0869@gmail.com>
Date: Mon, 3 Mar 2025 09:48:58 -0800
Subject: [PATCH] Refactor code to use minimum of compareChart

---
 .../components/Experiment/Eval/Chart.tsx      | 273 +++++++-----------
 1 file changed, 105 insertions(+), 168 deletions(-)

diff --git a/src/renderer/components/Experiment/Eval/Chart.tsx b/src/renderer/components/Experiment/Eval/Chart.tsx
index 2de89c99..8bb40e9d 100644
--- a/src/renderer/components/Experiment/Eval/Chart.tsx
+++ b/src/renderer/components/Experiment/Eval/Chart.tsx
@@ -5,12 +5,10 @@ import { ResponsiveRadar } from '@nivo/radar';
 import { Select, Option, FormControl, Box, Button } from '@mui/joy';
 import { ArrowLeftRight } from 'lucide-react';
 
-
 const Chart = ({ metrics, compareChart }) => {
   const [chartType, setChartType] = useState('bar');
   const [swapAxes, setSwapAxes] = useState(false);
 
-
   const handleChartTypeChange = (event, newValue) => {
     setChartType(newValue);
   };
@@ -23,87 +21,65 @@ const Chart = ({ metrics, compareChart }) => {
     return <div>No metrics available</div>;
   }
 
+  // Normalize data structure regardless of compare mode
+  // Get all unique metric types and series keys
+  const metricTypes = Array.from(new Set(metrics.map(m => m.type)));
+  const seriesKeys = Array.from(
+    new Set(metrics.map(m => compareChart ? `${m.evaluator}-${m.job_id}` : 'score'))
+  );
 
-  let barData, lineData, radarData;
-
-  if (compareChart) {
-    // For compare mode, we need to handle the axis swap
-    const metricTypes = Array.from(new Set(metrics.map(m => m.type)));
-    const seriesKeys = Array.from(
-      new Set(metrics.map(m => `${m.evaluator}-${m.job_id}`))
-    );
-
-    // For Line chart: handle axis swapping
-    if (swapAxes && chartType === 'line') {
-      // Swapped axes: each series is a metric type
-      const seriesDataMap = {};
-      metricTypes.forEach(type => {
-        seriesDataMap[type] = { id: type, data: [] };
-      });
-
-      metrics.forEach(metric => {
-        const { type, evaluator, job_id, score } = metric;
-        const seriesKey = `${evaluator}-${job_id}`;
-        seriesDataMap[type].data.push({ x: seriesKey, y: score });
-      });
-      lineData = Object.values(seriesDataMap);
-    } else {
-      // Normal mode for line chart: each series is an evaluator-job
-      const seriesDataMap = {};
-      seriesKeys.forEach(series => {
-        seriesDataMap[series] = { id: series, data: [] };
-      });
+  // Create a consistent structure for all modes
+  const dataMap = {};
+  metrics.forEach(metric => {
+    const { type, evaluator, job_id, score } = metric;
+    const seriesKey = compareChart ? `${evaluator}-${job_id}` : 'score';
 
-      metrics.forEach(metric => {
-        const { type, evaluator, job_id, score } = metric;
-        const seriesKey = `${evaluator}-${job_id}`;
-        seriesDataMap[seriesKey].data.push({ x: type, y: score });
-      });
-      lineData = Object.values(seriesDataMap);
+    if (!dataMap[type]) {
+      dataMap[type] = { metric: type };
     }
+    dataMap[type][seriesKey] = score;
+  });
 
-    // Bar chart: data preparation (unchanged)
-    const barDataMap = {};
-    metrics.forEach(metric => {
-      const { type, evaluator, job_id, score } = metric;
-      const seriesKey = `${evaluator}-${job_id}`;
-      if (!barDataMap[type]) {
-        barDataMap[type] = { type };
-      }
-      barDataMap[type][seriesKey] = score;
-    });
-    barData = Object.values(barDataMap);
-
-    // Radar chart: data preparation (unchanged)
-    const radarDataMap = {};
-    metrics.forEach(metric => {
-      const { type, evaluator, job_id, score } = metric;
-      const seriesKey = `${evaluator}-${job_id}`;
-      if (!radarDataMap[type]) {
-        radarDataMap[type] = { metric: type };
-      }
-      radarDataMap[type][seriesKey] = score;
-    });
-    radarData = Object.values(radarDataMap);
-  } else {
-    // Original logic for non-compare mode (unchanged)
-    barData = metrics.map(metric => ({
-      type: metric.type,
-      score: metric.score,
-    }));
+  const normalizedData = {
+    dataPoints: Object.values(dataMap),
+    metricTypes,
+    seriesKeys
+  };
 
-    lineData = [
-      {
-        id: 'metrics',
-        data: metrics.map(metric => ({ x: metric.type, y: metric.score })),
+  // Get transformed data for the right chart type
+  const getChartData = () => {
+    const { dataPoints, metricTypes, seriesKeys } = normalizedData;
+
+    if (chartType === 'line') {
+      if (swapAxes) {
+        // Series are metric types
+        return metricTypes.map(type => ({
+          id: type,
+          data: seriesKeys.map(seriesKey => {
+            const matchingPoint = dataPoints.find(p => p.metric === type && p[seriesKey] !== undefined);
+            return {
+              x: seriesKey,
+              y: matchingPoint ? matchingPoint[seriesKey] : null
+            };
+          }).filter(point => point.y !== null)
+        }));
+      } else {
+        // Series are evaluators/jobs
+        return seriesKeys.map(seriesKey => ({
+          id: seriesKey,
+          data: dataPoints.map(point => ({
+            x: point.metric,
+            y: point[seriesKey]
+          })).filter(point => point.y !== undefined)
+        }));
       }
-    ];
-
-    radarData = metrics.map(metric => ({
-      metric: metric.type,
-      score: metric.score,
-    }));
-  }
+    } else if (chartType === 'bar' || chartType === 'radar') {
+      // Both bar and radar can use the dataPoints directly
+      return dataPoints;
+    } else {
+      return [];
+    }
+  };
 
   return (
     <>
@@ -116,7 +92,7 @@ const Chart = ({ metrics, compareChart }) => {
           </Select>
         </FormControl>
 
-        {compareChart && chartType === 'line' && (
+        {chartType === 'line' && (
           <Button
             variant="outlined"
             startDecorator={<ArrowLeftRight size={18} />}
@@ -130,7 +106,7 @@ const Chart = ({ metrics, compareChart }) => {
       <div style={{ height: 400, width: '100%' }}>
         {chartType === 'line' && (
           <ResponsiveLine
-            data={lineData}
+            data={getChartData()}
             margin={{ top: 50, right: 110, bottom: 50, left: 60 }}
             xScale={{ type: 'point' }}
             yScale={{
@@ -145,8 +121,8 @@ const Chart = ({ metrics, compareChart }) => {
             axisBottom={{
               tickSize: 5,
               tickPadding: 5,
-              tickRotation: swapAxes && compareChart ? 45 : 0,
-              legend: swapAxes && compareChart ? 'experiment' : 'metric',
+              tickRotation: swapAxes ? 45 : 0,
+              legend: swapAxes ? 'experiment' : 'metric',
               legendOffset: 36,
               legendPosition: 'middle',
             }}
@@ -182,18 +158,14 @@ const Chart = ({ metrics, compareChart }) => {
           />
         )}
 
-{chartType === 'bar' && (
+        {chartType === 'bar' && (
           <ResponsiveBar
-            data={barData}
-            keys={
-              compareChart
-                ? Array.from(new Set(metrics.map((m) => `${m.evaluator}-${m.job_id}`)))
-                : ['score']
-            }
-            indexBy="type"
+            data={getChartData()}
+            keys={normalizedData.seriesKeys}
+            indexBy="metric"
             margin={{ top: 50, right: 130, bottom: 50, left: 60 }}
             padding={0.3}
-            groupMode={compareChart ? 'grouped' : 'stacked'} // Added groupMode here.
+            groupMode={normalizedData.seriesKeys.length > 1 ? 'grouped' : 'stacked'}
             axisTop={null}
             axisRight={null}
             axisBottom={{
@@ -217,85 +189,50 @@ const Chart = ({ metrics, compareChart }) => {
             animate={false}
           />
         )}
-{chartType === 'radar' && compareChart && (
-  <ResponsiveRadar
-    data={(() => {
-      // Group metrics by type
-      const metricsGroupedByType = {};
-      metrics.forEach(metric => {
-        const { type, evaluator, job_id, score } = metric;
-        const seriesKey = `${evaluator}-${job_id}`;
-
-        if (!metricsGroupedByType[type]) {
-          metricsGroupedByType[type] = { metric: type };
-        }
-        metricsGroupedByType[type][seriesKey] = score;
-      });
-      return Object.values(metricsGroupedByType);
-    })()}
-    keys={Array.from(new Set(metrics.map(m => `${m.evaluator}-${m.job_id}`)))}
-    indexBy="metric"
-    margin={{ top: 70, right: 170, bottom: 40, left: 80 }}
-    borderColor={{ from: 'color' }}
-    gridShape="circular"
-    gridLabelOffset={36}
-    dotSize={10}
-    dotColor={{ theme: 'background' }}
-    dotBorderWidth={2}
-    colors={{ scheme: 'nivo' }}
-    fillOpacity={0.25}
-    blendMode="multiply"
-    animate={true}
-    motionConfig="wobbly"
-    enableDotLabel={true}
-    dotLabel="value"
-    dotLabelYOffset={-12}
-    legends={[
-      {
-        anchor: 'right',
-        direction: 'column',
-        translateX: 50,
-        translateY: 0,
-        itemWidth: 120,
-        itemHeight: 20,
-        itemTextColor: '#999',
-        symbolSize: 12,
-        symbolShape: 'circle',
-        effects: [
-          {
-            on: 'hover',
-            style: {
-              itemTextColor: '#000'
-            }
-          }
-        ]
-      }
-    ]}
-  />
-)}
 
-{chartType === 'radar' && !compareChart && (
-  <ResponsiveRadar
-    data={radarData}
-    keys={['score']}
-    indexBy="metric"
-    margin={{ top: 70, right: 80, bottom: 40, left: 80 }}
-    gridShape="circular"
-    gridLabelOffset={36}
-    dotSize={10}
-    dotColor={{ from: 'color', modifiers: [] }}
-    dotBorderWidth={2}
-    dotBorderColor={{ from: 'color', modifiers: [] }}
-    colors={{ scheme: 'nivo' }}
-    fillOpacity={0.25}
-    blendMode="multiply"
-    animate={true}
-    motionConfig="wobbly"
-    enableDotLabel={true}
-    dotLabel="value"
-    dotLabelYOffset={-12}
-  />
-)}
+        {chartType === 'radar' && (
+          <ResponsiveRadar
+            data={getChartData()}
+            keys={normalizedData.seriesKeys}
+            indexBy="metric"
+            margin={{ top: 70, right: 170, bottom: 40, left: 80 }}
+            borderColor={{ from: 'color' }}
+            gridShape="circular"
+            gridLabelOffset={36}
+            dotSize={10}
+            dotColor={{ theme: 'background' }}
+            dotBorderWidth={2}
+            colors={{ scheme: 'nivo' }}
+            fillOpacity={0.25}
+            blendMode="multiply"
+            animate={true}
+            motionConfig="wobbly"
+            enableDotLabel={true}
+            dotLabel="value"
+            dotLabelYOffset={-12}
+            legends={normalizedData.seriesKeys.length > 1 ? [
+              {
+                anchor: 'right',
+                direction: 'column',
+                translateX: 50,
+                translateY: 0,
+                itemWidth: 120,
+                itemHeight: 20,
+                itemTextColor: '#999',
+                symbolSize: 12,
+                symbolShape: 'circle',
+                effects: [
+                  {
+                    on: 'hover',
+                    style: {
+                      itemTextColor: '#000'
+                    }
+                  }
+                ]
+              }
+            ] : []}
+          />
+        )}
       </div>
     </>
   );
-- 
GitLab