Skip to content
Snippets Groups Projects
Commit 00e41abd authored by deep1401's avatar deep1401
Browse files

Add Swap Axes logic for line chart

parent eac3333c
No related branches found
No related tags found
No related merge requests found
//// filepath: /Users/deep.gandhi/transformerlab-repos/transformerlab-app/src/renderer/components/Experiment/Eval/Chart.tsx
import React, { useState } from 'react'; import React, { useState } from 'react';
import { ResponsiveLine } from '@nivo/line'; import { ResponsiveLine } from '@nivo/line';
import { ResponsiveBar } from '@nivo/bar'; import { ResponsiveBar } from '@nivo/bar';
import { ResponsiveRadar } from '@nivo/radar'; import { ResponsiveRadar } from '@nivo/radar';
import { Select, Option, FormControl } from '@mui/joy'; import { Select, Option, FormControl, Box, Button } from '@mui/joy';
import { ArrowLeftRight } from 'lucide-react';
const Chart = ({ metrics, compareChart }) => { const Chart = ({ metrics, compareChart }) => {
const [chartType, setChartType] = useState('bar'); const [chartType, setChartType] = useState('bar');
const [swapAxes, setSwapAxes] = useState(false);
const handleChartTypeChange = (event, newValue) => { const handleChartTypeChange = (event, newValue) => {
setChartType(newValue); setChartType(newValue);
}; };
const handleSwapAxes = () => {
setSwapAxes(!swapAxes);
};
if (!metrics || metrics.length === 0) { if (!metrics || metrics.length === 0) {
return <div>No metrics available</div>; return <div>No metrics available</div>;
} }
...@@ -20,13 +27,42 @@ const Chart = ({ metrics, compareChart }) => { ...@@ -20,13 +27,42 @@ const Chart = ({ metrics, compareChart }) => {
let barData, lineData, radarData; let barData, lineData, radarData;
if (compareChart) { if (compareChart) {
// For compare mode, multiple evaluators/jobs become separate series. // For compare mode, we need to handle the axis swap
// Use evaluator-job as series key. const metricTypes = Array.from(new Set(metrics.map(m => m.type)));
const seriesKeys = Array.from( const seriesKeys = Array.from(
new Set(metrics.map(m => `${m.evaluator}-${m.job_id}`)) new Set(metrics.map(m => `${m.evaluator}-${m.job_id}`))
); );
// For Bar chart: group by metric type. // For Line chart: handle axis swapping
if (swapAxes && chartType === 'line') {
// Swapped axes: each series is a metric type
const seriesDataMap = {};
metricTypes.forEach(type => {
seriesDataMap[type] = { id: type, data: [] };
});
metrics.forEach(metric => {
const { type, evaluator, job_id, score } = metric;
const seriesKey = `${evaluator}-${job_id}`;
seriesDataMap[type].data.push({ x: seriesKey, y: score });
});
lineData = Object.values(seriesDataMap);
} else {
// Normal mode for line chart: each series is an evaluator-job
const seriesDataMap = {};
seriesKeys.forEach(series => {
seriesDataMap[series] = { id: series, data: [] };
});
metrics.forEach(metric => {
const { type, evaluator, job_id, score } = metric;
const seriesKey = `${evaluator}-${job_id}`;
seriesDataMap[seriesKey].data.push({ x: type, y: score });
});
lineData = Object.values(seriesDataMap);
}
// Bar chart: data preparation (unchanged)
const barDataMap = {}; const barDataMap = {};
metrics.forEach(metric => { metrics.forEach(metric => {
const { type, evaluator, job_id, score } = metric; const { type, evaluator, job_id, score } = metric;
...@@ -38,19 +74,7 @@ const Chart = ({ metrics, compareChart }) => { ...@@ -38,19 +74,7 @@ const Chart = ({ metrics, compareChart }) => {
}); });
barData = Object.values(barDataMap); barData = Object.values(barDataMap);
// For Line chart: each series is an evaluator-job. // Radar chart: data preparation (unchanged)
const seriesDataMap = {};
seriesKeys.forEach(series => {
seriesDataMap[series] = { id: series, data: [] };
});
metrics.forEach(metric => {
const { type, evaluator, job_id, score } = metric;
const seriesKey = `${evaluator}-${job_id}`;
seriesDataMap[seriesKey].data.push({ x: type, y: score });
});
lineData = Object.values(seriesDataMap);
// For Radar chart: similar to bar, but keys represent evaluator-job score.
const radarDataMap = {}; const radarDataMap = {};
metrics.forEach(metric => { metrics.forEach(metric => {
const { type, evaluator, job_id, score } = metric; const { type, evaluator, job_id, score } = metric;
...@@ -62,7 +86,7 @@ const Chart = ({ metrics, compareChart }) => { ...@@ -62,7 +86,7 @@ const Chart = ({ metrics, compareChart }) => {
}); });
radarData = Object.values(radarDataMap); radarData = Object.values(radarDataMap);
} else { } else {
// Original logic: assume a single series. // Original logic for non-compare mode (unchanged)
barData = metrics.map(metric => ({ barData = metrics.map(metric => ({
type: metric.type, type: metric.type,
score: metric.score, score: metric.score,
...@@ -83,13 +107,26 @@ const Chart = ({ metrics, compareChart }) => { ...@@ -83,13 +107,26 @@ const Chart = ({ metrics, compareChart }) => {
return ( return (
<> <>
<FormControl sx={{ width: 200 }}> <Box sx={{ display: 'flex', gap: 2, mb: 2, alignItems: 'center' }}>
<Select value={chartType} onChange={handleChartTypeChange}> <FormControl sx={{ width: 200 }}>
<Option value="bar">Bar</Option> <Select value={chartType} onChange={handleChartTypeChange}>
<Option value="line">Line</Option> <Option value="bar">Bar</Option>
<Option value="radar">Radar</Option> <Option value="line">Line</Option>
</Select> <Option value="radar">Radar</Option>
</FormControl> </Select>
</FormControl>
{compareChart && chartType === 'line' && (
<Button
variant="outlined"
startDecorator={<ArrowLeftRight size={18} />}
onClick={handleSwapAxes}
>
Swap Axes
</Button>
)}
</Box>
<div style={{ height: 400, width: '100%' }}> <div style={{ height: 400, width: '100%' }}>
{chartType === 'line' && ( {chartType === 'line' && (
<ResponsiveLine <ResponsiveLine
...@@ -100,7 +137,7 @@ const Chart = ({ metrics, compareChart }) => { ...@@ -100,7 +137,7 @@ const Chart = ({ metrics, compareChart }) => {
type: 'linear', type: 'linear',
min: 'auto', min: 'auto',
max: 'auto', max: 'auto',
stacked: true, stacked: false,
reverse: false, reverse: false,
}} }}
axisTop={null} axisTop={null}
...@@ -108,8 +145,8 @@ const Chart = ({ metrics, compareChart }) => { ...@@ -108,8 +145,8 @@ const Chart = ({ metrics, compareChart }) => {
axisBottom={{ axisBottom={{
tickSize: 5, tickSize: 5,
tickPadding: 5, tickPadding: 5,
tickRotation: 0, tickRotation: swapAxes && compareChart ? 45 : 0,
legend: 'metric', legend: swapAxes && compareChart ? 'experiment' : 'metric',
legendOffset: 36, legendOffset: 36,
legendPosition: 'middle', legendPosition: 'middle',
}} }}
...@@ -121,9 +158,31 @@ const Chart = ({ metrics, compareChart }) => { ...@@ -121,9 +158,31 @@ const Chart = ({ metrics, compareChart }) => {
legendOffset: -40, legendOffset: -40,
legendPosition: 'middle', legendPosition: 'middle',
}} }}
pointSize={10}
pointColor={{ theme: 'background' }}
pointBorderWidth={2}
pointBorderColor={{ from: 'serieColor' }}
legends={[
{
anchor: 'bottom-right',
direction: 'column',
justify: false,
translateX: 100,
translateY: 0,
itemsSpacing: 0,
itemDirection: 'left-to-right',
itemWidth: 80,
itemHeight: 20,
itemOpacity: 0.75,
symbolSize: 12,
symbolShape: 'circle',
symbolBorderColor: 'rgba(0, 0, 0, .5)',
}
]}
/> />
)} )}
{chartType === 'bar' && (
{chartType === 'bar' && (
<ResponsiveBar <ResponsiveBar
data={barData} data={barData}
keys={ keys={
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment