Skip to content
Snippets Groups Projects
Commit 7a815f92 authored by deep1401's avatar deep1401
Browse files

Refactor code to use minimum of compareChart

parent 00e41abd
No related branches found
No related tags found
No related merge requests found
......@@ -5,12 +5,10 @@ import { ResponsiveRadar } from '@nivo/radar';
import { Select, Option, FormControl, Box, Button } from '@mui/joy';
import { ArrowLeftRight } from 'lucide-react';
const Chart = ({ metrics, compareChart }) => {
const [chartType, setChartType] = useState('bar');
const [swapAxes, setSwapAxes] = useState(false);
const handleChartTypeChange = (event, newValue) => {
setChartType(newValue);
};
......@@ -23,87 +21,65 @@ const Chart = ({ metrics, compareChart }) => {
return <div>No metrics available</div>;
}
// Normalize data structure regardless of compare mode
// Get all unique metric types and series keys
const metricTypes = Array.from(new Set(metrics.map(m => m.type)));
const seriesKeys = Array.from(
new Set(metrics.map(m => compareChart ? `${m.evaluator}-${m.job_id}` : 'score'))
);
let barData, lineData, radarData;
if (compareChart) {
// For compare mode, we need to handle the axis swap
const metricTypes = Array.from(new Set(metrics.map(m => m.type)));
const seriesKeys = Array.from(
new Set(metrics.map(m => `${m.evaluator}-${m.job_id}`))
);
// For Line chart: handle axis swapping
if (swapAxes && chartType === 'line') {
// Swapped axes: each series is a metric type
const seriesDataMap = {};
metricTypes.forEach(type => {
seriesDataMap[type] = { id: type, data: [] };
});
metrics.forEach(metric => {
const { type, evaluator, job_id, score } = metric;
const seriesKey = `${evaluator}-${job_id}`;
seriesDataMap[type].data.push({ x: seriesKey, y: score });
});
lineData = Object.values(seriesDataMap);
} else {
// Normal mode for line chart: each series is an evaluator-job
const seriesDataMap = {};
seriesKeys.forEach(series => {
seriesDataMap[series] = { id: series, data: [] };
});
// Create a consistent structure for all modes
const dataMap = {};
metrics.forEach(metric => {
const { type, evaluator, job_id, score } = metric;
const seriesKey = compareChart ? `${evaluator}-${job_id}` : 'score';
metrics.forEach(metric => {
const { type, evaluator, job_id, score } = metric;
const seriesKey = `${evaluator}-${job_id}`;
seriesDataMap[seriesKey].data.push({ x: type, y: score });
});
lineData = Object.values(seriesDataMap);
if (!dataMap[type]) {
dataMap[type] = { metric: type };
}
dataMap[type][seriesKey] = score;
});
// Bar chart: data preparation (unchanged)
const barDataMap = {};
metrics.forEach(metric => {
const { type, evaluator, job_id, score } = metric;
const seriesKey = `${evaluator}-${job_id}`;
if (!barDataMap[type]) {
barDataMap[type] = { type };
}
barDataMap[type][seriesKey] = score;
});
barData = Object.values(barDataMap);
// Radar chart: data preparation (unchanged)
const radarDataMap = {};
metrics.forEach(metric => {
const { type, evaluator, job_id, score } = metric;
const seriesKey = `${evaluator}-${job_id}`;
if (!radarDataMap[type]) {
radarDataMap[type] = { metric: type };
}
radarDataMap[type][seriesKey] = score;
});
radarData = Object.values(radarDataMap);
} else {
// Original logic for non-compare mode (unchanged)
barData = metrics.map(metric => ({
type: metric.type,
score: metric.score,
}));
const normalizedData = {
dataPoints: Object.values(dataMap),
metricTypes,
seriesKeys
};
lineData = [
{
id: 'metrics',
data: metrics.map(metric => ({ x: metric.type, y: metric.score })),
// Get transformed data for the right chart type
const getChartData = () => {
const { dataPoints, metricTypes, seriesKeys } = normalizedData;
if (chartType === 'line') {
if (swapAxes) {
// Series are metric types
return metricTypes.map(type => ({
id: type,
data: seriesKeys.map(seriesKey => {
const matchingPoint = dataPoints.find(p => p.metric === type && p[seriesKey] !== undefined);
return {
x: seriesKey,
y: matchingPoint ? matchingPoint[seriesKey] : null
};
}).filter(point => point.y !== null)
}));
} else {
// Series are evaluators/jobs
return seriesKeys.map(seriesKey => ({
id: seriesKey,
data: dataPoints.map(point => ({
x: point.metric,
y: point[seriesKey]
})).filter(point => point.y !== undefined)
}));
}
];
radarData = metrics.map(metric => ({
metric: metric.type,
score: metric.score,
}));
}
} else if (chartType === 'bar' || chartType === 'radar') {
// Both bar and radar can use the dataPoints directly
return dataPoints;
} else {
return [];
}
};
return (
<>
......@@ -116,7 +92,7 @@ const Chart = ({ metrics, compareChart }) => {
</Select>
</FormControl>
{compareChart && chartType === 'line' && (
{chartType === 'line' && (
<Button
variant="outlined"
startDecorator={<ArrowLeftRight size={18} />}
......@@ -130,7 +106,7 @@ const Chart = ({ metrics, compareChart }) => {
<div style={{ height: 400, width: '100%' }}>
{chartType === 'line' && (
<ResponsiveLine
data={lineData}
data={getChartData()}
margin={{ top: 50, right: 110, bottom: 50, left: 60 }}
xScale={{ type: 'point' }}
yScale={{
......@@ -145,8 +121,8 @@ const Chart = ({ metrics, compareChart }) => {
axisBottom={{
tickSize: 5,
tickPadding: 5,
tickRotation: swapAxes && compareChart ? 45 : 0,
legend: swapAxes && compareChart ? 'experiment' : 'metric',
tickRotation: swapAxes ? 45 : 0,
legend: swapAxes ? 'experiment' : 'metric',
legendOffset: 36,
legendPosition: 'middle',
}}
......@@ -182,18 +158,14 @@ const Chart = ({ metrics, compareChart }) => {
/>
)}
{chartType === 'bar' && (
{chartType === 'bar' && (
<ResponsiveBar
data={barData}
keys={
compareChart
? Array.from(new Set(metrics.map((m) => `${m.evaluator}-${m.job_id}`)))
: ['score']
}
indexBy="type"
data={getChartData()}
keys={normalizedData.seriesKeys}
indexBy="metric"
margin={{ top: 50, right: 130, bottom: 50, left: 60 }}
padding={0.3}
groupMode={compareChart ? 'grouped' : 'stacked'} // Added groupMode here.
groupMode={normalizedData.seriesKeys.length > 1 ? 'grouped' : 'stacked'}
axisTop={null}
axisRight={null}
axisBottom={{
......@@ -217,85 +189,50 @@ const Chart = ({ metrics, compareChart }) => {
animate={false}
/>
)}
{chartType === 'radar' && compareChart && (
<ResponsiveRadar
data={(() => {
// Group metrics by type
const metricsGroupedByType = {};
metrics.forEach(metric => {
const { type, evaluator, job_id, score } = metric;
const seriesKey = `${evaluator}-${job_id}`;
if (!metricsGroupedByType[type]) {
metricsGroupedByType[type] = { metric: type };
}
metricsGroupedByType[type][seriesKey] = score;
});
return Object.values(metricsGroupedByType);
})()}
keys={Array.from(new Set(metrics.map(m => `${m.evaluator}-${m.job_id}`)))}
indexBy="metric"
margin={{ top: 70, right: 170, bottom: 40, left: 80 }}
borderColor={{ from: 'color' }}
gridShape="circular"
gridLabelOffset={36}
dotSize={10}
dotColor={{ theme: 'background' }}
dotBorderWidth={2}
colors={{ scheme: 'nivo' }}
fillOpacity={0.25}
blendMode="multiply"
animate={true}
motionConfig="wobbly"
enableDotLabel={true}
dotLabel="value"
dotLabelYOffset={-12}
legends={[
{
anchor: 'right',
direction: 'column',
translateX: 50,
translateY: 0,
itemWidth: 120,
itemHeight: 20,
itemTextColor: '#999',
symbolSize: 12,
symbolShape: 'circle',
effects: [
{
on: 'hover',
style: {
itemTextColor: '#000'
}
}
]
}
]}
/>
)}
{chartType === 'radar' && !compareChart && (
<ResponsiveRadar
data={radarData}
keys={['score']}
indexBy="metric"
margin={{ top: 70, right: 80, bottom: 40, left: 80 }}
gridShape="circular"
gridLabelOffset={36}
dotSize={10}
dotColor={{ from: 'color', modifiers: [] }}
dotBorderWidth={2}
dotBorderColor={{ from: 'color', modifiers: [] }}
colors={{ scheme: 'nivo' }}
fillOpacity={0.25}
blendMode="multiply"
animate={true}
motionConfig="wobbly"
enableDotLabel={true}
dotLabel="value"
dotLabelYOffset={-12}
/>
)}
{chartType === 'radar' && (
<ResponsiveRadar
data={getChartData()}
keys={normalizedData.seriesKeys}
indexBy="metric"
margin={{ top: 70, right: 170, bottom: 40, left: 80 }}
borderColor={{ from: 'color' }}
gridShape="circular"
gridLabelOffset={36}
dotSize={10}
dotColor={{ theme: 'background' }}
dotBorderWidth={2}
colors={{ scheme: 'nivo' }}
fillOpacity={0.25}
blendMode="multiply"
animate={true}
motionConfig="wobbly"
enableDotLabel={true}
dotLabel="value"
dotLabelYOffset={-12}
legends={normalizedData.seriesKeys.length > 1 ? [
{
anchor: 'right',
direction: 'column',
translateX: 50,
translateY: 0,
itemWidth: 120,
itemHeight: 20,
itemTextColor: '#999',
symbolSize: 12,
symbolShape: 'circle',
effects: [
{
on: 'hover',
style: {
itemTextColor: '#000'
}
}
]
}
] : []}
/>
)}
</div>
</>
);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment