Skip to content
Snippets Groups Projects
Unverified commit 36505e2f, authored by Deep Gandhi and committed by GitHub
Browse files

Merge pull request #292 from transformerlab/fix/compare-evals-fixes

Fix Compare Evals
parents 9776af4c 7a815f92
No related branches found
No related tags found
No related merge requests found
//// filepath: /Users/deep.gandhi/transformerlab-repos/transformerlab-app/src/renderer/components/Experiment/Eval/Chart.tsx
import React, { useState } from 'react';
import { ResponsiveLine } from '@nivo/line';
import { ResponsiveBar } from '@nivo/bar';
import { ResponsiveRadar } from '@nivo/radar';
import { Select, Option, FormControl } from '@mui/joy';
import { Select, Option, FormControl, Box, Button } from '@mui/joy';
import { ArrowLeftRight } from 'lucide-react';
const Chart = ({ metrics, compareChart }) => {
const [chartType, setChartType] = useState('bar');
const [swapAxes, setSwapAxes] = useState(false);
// Change handler for the chart-type <Select>: the second argument is the
// newly chosen option value, which is stored as the active chart type.
// The event object itself is not needed.
const handleChartTypeChange = (_event, selectedType) => {
  setChartType(selectedType);
};
// Click handler for the "Swap Axes" button: flips the swapAxes flag.
// Uses the functional updater form so the toggle is always computed from
// the latest state rather than the value captured by this render's closure
// (which can drop a toggle when several updates are batched together).
const handleSwapAxes = () => {
  setSwapAxes((previous) => !previous);
};
if (!metrics || metrics.length === 0) {
return <div>No metrics available</div>;
}
// Normalize data structure regardless of compare mode
// Get all unique metric types and series keys
const metricTypes = Array.from(new Set(metrics.map(m => m.type)));
const seriesKeys = Array.from(
new Set(metrics.map(m => compareChart ? `${m.evaluator}-${m.job_id}` : 'score'))
);
let barData, lineData, radarData;
if (compareChart) {
// For compare mode, multiple evaluators/jobs become separate series.
// Use evaluator-job as series key.
const seriesKeys = Array.from(
new Set(metrics.map(m => `${m.evaluator}-${m.job_id}`))
);
// Create a consistent structure for all modes
const dataMap = {};
metrics.forEach(metric => {
const { type, evaluator, job_id, score } = metric;
const seriesKey = compareChart ? `${evaluator}-${job_id}` : 'score';
if (!dataMap[type]) {
dataMap[type] = { metric: type };
}
dataMap[type][seriesKey] = score;
});
const normalizedData = {
dataPoints: Object.values(dataMap),
metricTypes,
seriesKeys
};
// For Bar chart: group by metric type.
const barDataMap = {};
metrics.forEach(metric => {
const { type, evaluator, job_id, score } = metric;
const seriesKey = `${evaluator}-${job_id}`;
if (!barDataMap[type]) {
barDataMap[type] = { type };
}
barDataMap[type][seriesKey] = score;
});
barData = Object.values(barDataMap);
// For Line chart: each series is an evaluator-job.
const seriesDataMap = {};
seriesKeys.forEach(series => {
seriesDataMap[series] = { id: series, data: [] };
});
metrics.forEach(metric => {
const { type, evaluator, job_id, score } = metric;
const seriesKey = `${evaluator}-${job_id}`;
seriesDataMap[seriesKey].data.push({ x: type, y: score });
});
lineData = Object.values(seriesDataMap);
// For Radar chart: similar to bar, but keys represent evaluator-job score.
const radarDataMap = {};
metrics.forEach(metric => {
const { type, evaluator, job_id, score } = metric;
const seriesKey = `${evaluator}-${job_id}`;
if (!radarDataMap[type]) {
radarDataMap[type] = { metric: type };
}
radarDataMap[type][seriesKey] = score;
});
radarData = Object.values(radarDataMap);
} else {
// Original logic: assume a single series.
barData = metrics.map(metric => ({
type: metric.type,
score: metric.score,
}));
lineData = [
{
id: 'metrics',
data: metrics.map(metric => ({ x: metric.type, y: metric.score })),
// Get transformed data for the right chart type
const getChartData = () => {
const { dataPoints, metricTypes, seriesKeys } = normalizedData;
if (chartType === 'line') {
if (swapAxes) {
// Series are metric types
return metricTypes.map(type => ({
id: type,
data: seriesKeys.map(seriesKey => {
const matchingPoint = dataPoints.find(p => p.metric === type && p[seriesKey] !== undefined);
return {
x: seriesKey,
y: matchingPoint ? matchingPoint[seriesKey] : null
};
}).filter(point => point.y !== null)
}));
} else {
// Series are evaluators/jobs
return seriesKeys.map(seriesKey => ({
id: seriesKey,
data: dataPoints.map(point => ({
x: point.metric,
y: point[seriesKey]
})).filter(point => point.y !== undefined)
}));
}
];
radarData = metrics.map(metric => ({
metric: metric.type,
score: metric.score,
}));
}
} else if (chartType === 'bar' || chartType === 'radar') {
// Both bar and radar can use the dataPoints directly
return dataPoints;
} else {
return [];
}
};
return (
<>
<FormControl sx={{ width: 200 }}>
<Select value={chartType} onChange={handleChartTypeChange}>
<Option value="bar">Bar</Option>
<Option value="line">Line</Option>
<Option value="radar">Radar</Option>
</Select>
</FormControl>
<Box sx={{ display: 'flex', gap: 2, mb: 2, alignItems: 'center' }}>
<FormControl sx={{ width: 200 }}>
<Select value={chartType} onChange={handleChartTypeChange}>
<Option value="bar">Bar</Option>
<Option value="line">Line</Option>
<Option value="radar">Radar</Option>
</Select>
</FormControl>
{chartType === 'line' && (
<Button
variant="outlined"
startDecorator={<ArrowLeftRight size={18} />}
onClick={handleSwapAxes}
>
Swap Axes
</Button>
)}
</Box>
<div style={{ height: 400, width: '100%' }}>
{chartType === 'line' && (
<ResponsiveLine
data={lineData}
data={getChartData()}
margin={{ top: 50, right: 110, bottom: 50, left: 60 }}
xScale={{ type: 'point' }}
yScale={{
type: 'linear',
min: 'auto',
max: 'auto',
stacked: true,
stacked: false,
reverse: false,
}}
axisTop={null}
......@@ -108,8 +121,8 @@ const Chart = ({ metrics, compareChart }) => {
axisBottom={{
tickSize: 5,
tickPadding: 5,
tickRotation: 0,
legend: 'metric',
tickRotation: swapAxes ? 45 : 0,
legend: swapAxes ? 'experiment' : 'metric',
legendOffset: 36,
legendPosition: 'middle',
}}
......@@ -121,20 +134,38 @@ const Chart = ({ metrics, compareChart }) => {
legendOffset: -40,
legendPosition: 'middle',
}}
pointSize={10}
pointColor={{ theme: 'background' }}
pointBorderWidth={2}
pointBorderColor={{ from: 'serieColor' }}
legends={[
{
anchor: 'bottom-right',
direction: 'column',
justify: false,
translateX: 100,
translateY: 0,
itemsSpacing: 0,
itemDirection: 'left-to-right',
itemWidth: 80,
itemHeight: 20,
itemOpacity: 0.75,
symbolSize: 12,
symbolShape: 'circle',
symbolBorderColor: 'rgba(0, 0, 0, .5)',
}
]}
/>
)}
{chartType === 'bar' && (
<ResponsiveBar
data={barData}
keys={
compareChart
? Array.from(new Set(metrics.map((m) => `${m.evaluator}-${m.job_id}`)))
: ['score']
}
indexBy="type"
data={getChartData()}
keys={normalizedData.seriesKeys}
indexBy="metric"
margin={{ top: 50, right: 130, bottom: 50, left: 60 }}
padding={0.3}
groupMode={compareChart ? 'grouped' : 'stacked'} // Added groupMode here.
groupMode={normalizedData.seriesKeys.length > 1 ? 'grouped' : 'stacked'}
axisTop={null}
axisRight={null}
axisBottom={{
......@@ -158,46 +189,34 @@ const Chart = ({ metrics, compareChart }) => {
animate={false}
/>
)}
{chartType === 'radar' && compareChart && (
<div style={{ position: 'relative', height: '100%', width: '100%' }}>
{Array.from(new Set(metrics.map(m => `${m.evaluator}-${m.job_id}`))).map((series, index) => {
// Filter radar data for this specific evaluator-job combination
const seriesRadarData = metrics
.filter(m => `${m.evaluator}-${m.job_id}` === series)
.map(metric => ({
metric: metric.type,
score: metric.score
}));
return (
<div key={series} style={{ position: 'absolute', top: 0, left: 0, right: 0, bottom: 0 }}>
{chartType === 'radar' && (
<ResponsiveRadar
data={seriesRadarData}
keys={['score']}
data={getChartData()}
keys={normalizedData.seriesKeys}
indexBy="metric"
margin={{ top: 70, right: 80, bottom: 40, left: 80 }}
margin={{ top: 70, right: 170, bottom: 40, left: 80 }}
borderColor={{ from: 'color' }}
gridShape="circular"
gridLabelOffset={36}
dotSize={10}
dotColor={{ from: 'color' }}
dotColor={{ theme: 'background' }}
dotBorderWidth={2}
dotBorderColor={{ from: 'color' }}
colors={[`hsl(${index * 30 + 60}, 70%, 50%)`]} // Different color for each series
fillOpacity={0.2}
blendMode="normal"
colors={{ scheme: 'nivo' }}
fillOpacity={0.25}
blendMode="multiply"
animate={true}
motionConfig="wobbly"
enableDotLabel={true}
dotLabel="value"
dotLabelYOffset={-12}
legends={[
legends={normalizedData.seriesKeys.length > 1 ? [
{
anchor: 'top-right',
anchor: 'right',
direction: 'column',
translateX: 0,
translateY: -40,
itemWidth: 80,
translateX: 50,
translateY: 0,
itemWidth: 120,
itemHeight: 20,
itemTextColor: '#999',
symbolSize: 12,
......@@ -209,39 +228,11 @@ const Chart = ({ metrics, compareChart }) => {
itemTextColor: '#000'
}
}
],
data: [{ id: series, label: series, color: `hsl(${index * 30 + 60}, 70%, 50%)` }]
]
}
]}
] : []}
/>
</div>
);
})}
</div>
)}
{chartType === 'radar' && !compareChart && (
<ResponsiveRadar
data={radarData}
keys={['score']}
indexBy="metric"
margin={{ top: 70, right: 80, bottom: 40, left: 80 }}
gridShape="circular"
gridLabelOffset={36}
dotSize={10}
dotColor={{ from: 'color', modifiers: [] }}
dotBorderWidth={2}
dotBorderColor={{ from: 'color', modifiers: [] }}
colors={{ scheme: 'nivo' }}
fillOpacity={0.25}
blendMode="multiply"
animate={true}
motionConfig="wobbly"
enableDotLabel={true}
dotLabel="value"
dotLabelYOffset={-12}
/>
)}
)}
</div>
</>
);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment