Skip to content
Snippets Groups Projects
Unverified Commit 76a84ae3 authored by ali asaria's avatar ali asaria Committed by GitHub
Browse files

Merge pull request #295 from transformerlab/fix/improve-charts

Fix/improve charts
parents c33f2e40 93d5891f
Branches master
No related tags found
No related merge requests found
......@@ -23,14 +23,18 @@ const Chart = ({ metrics, compareChart }) => {
// Normalize data structure regardless of compare mode
// Get all unique metric types and series keys
const metricTypes = Array.from(new Set(metrics.map(m => m.type)));
const metricTypes = Array.from(new Set(metrics.map((m) => m.type)));
const seriesKeys = Array.from(
new Set(metrics.map(m => compareChart ? `${m.evaluator}-${m.job_id}` : 'score'))
new Set(
metrics.map((m) =>
compareChart ? `${m.evaluator}-${m.job_id}` : 'score',
),
),
);
// Create a consistent structure for all modes
const dataMap = {};
metrics.forEach(metric => {
metrics.forEach((metric) => {
const { type, evaluator, job_id, score } = metric;
const seriesKey = compareChart ? `${evaluator}-${job_id}` : 'score';
......@@ -43,7 +47,7 @@ const Chart = ({ metrics, compareChart }) => {
const normalizedData = {
dataPoints: Object.values(dataMap),
metricTypes,
seriesKeys
seriesKeys,
};
// Get transformed data for the right chart type
......@@ -53,24 +57,30 @@ const Chart = ({ metrics, compareChart }) => {
if (chartType === 'line') {
if (swapAxes) {
// Series are metric types
return metricTypes.map(type => ({
return metricTypes.map((type) => ({
id: type,
data: seriesKeys.map(seriesKey => {
const matchingPoint = dataPoints.find(p => p.metric === type && p[seriesKey] !== undefined);
return {
x: seriesKey,
y: matchingPoint ? matchingPoint[seriesKey] : null
};
}).filter(point => point.y !== null)
data: seriesKeys
.map((seriesKey) => {
const matchingPoint = dataPoints.find(
(p) => p.metric === type && p[seriesKey] !== undefined,
);
return {
x: seriesKey,
y: matchingPoint ? matchingPoint[seriesKey] : null,
};
})
.filter((point) => point.y !== null),
}));
} else {
// Series are evaluators/jobs
return seriesKeys.map(seriesKey => ({
return seriesKeys.map((seriesKey) => ({
id: seriesKey,
data: dataPoints.map(point => ({
x: point.metric,
y: point[seriesKey]
})).filter(point => point.y !== undefined)
data: dataPoints
.map((point) => ({
x: point.metric,
y: point[seriesKey],
}))
.filter((point) => point.y !== undefined),
}));
}
} else if (chartType === 'bar' || chartType === 'radar') {
......@@ -82,8 +92,15 @@ const Chart = ({ metrics, compareChart }) => {
};
return (
<>
<Box sx={{ display: 'flex', gap: 2, mb: 2, alignItems: 'center' }}>
<Box sx={{ border: '1px solid #e0e0e0', borderRadius: 8, p: 2 }}>
<Box
sx={{
display: 'flex',
gap: 2,
mb: 2,
alignItems: 'center',
}}
>
<FormControl sx={{ width: 200 }}>
<Select value={chartType} onChange={handleChartTypeChange}>
<Option value="bar">Bar</Option>
......@@ -107,7 +124,7 @@ const Chart = ({ metrics, compareChart }) => {
{chartType === 'line' && (
<ResponsiveLine
data={getChartData()}
margin={{ top: 50, right: 110, bottom: 50, left: 60 }}
margin={{ top: 50, right: 200, bottom: 80, left: 60 }}
xScale={{ type: 'point' }}
yScale={{
type: 'linear',
......@@ -140,7 +157,7 @@ const Chart = ({ metrics, compareChart }) => {
pointBorderColor={{ from: 'serieColor' }}
legends={[
{
anchor: 'bottom-right',
anchor: 'top-right',
direction: 'column',
justify: false,
translateX: 100,
......@@ -153,7 +170,7 @@ const Chart = ({ metrics, compareChart }) => {
symbolSize: 12,
symbolShape: 'circle',
symbolBorderColor: 'rgba(0, 0, 0, .5)',
}
},
]}
/>
)}
......@@ -165,7 +182,9 @@ const Chart = ({ metrics, compareChart }) => {
indexBy="metric"
margin={{ top: 50, right: 130, bottom: 50, left: 60 }}
padding={0.3}
groupMode={normalizedData.seriesKeys.length > 1 ? 'grouped' : 'stacked'}
groupMode={
normalizedData.seriesKeys.length > 1 ? 'grouped' : 'stacked'
}
axisTop={null}
axisRight={null}
axisBottom={{
......@@ -210,31 +229,35 @@ const Chart = ({ metrics, compareChart }) => {
enableDotLabel={true}
dotLabel="value"
dotLabelYOffset={-12}
legends={normalizedData.seriesKeys.length > 1 ? [
{
anchor: 'right',
direction: 'column',
translateX: 50,
translateY: 0,
itemWidth: 120,
itemHeight: 20,
itemTextColor: '#999',
symbolSize: 12,
symbolShape: 'circle',
effects: [
{
on: 'hover',
style: {
itemTextColor: '#000'
}
}
]
}
] : []}
legends={
normalizedData.seriesKeys.length > 0
? [
{
anchor: 'top-right',
direction: 'column',
translateX: -200,
translateY: 0,
itemWidth: 120,
itemHeight: 20,
itemTextColor: '#999',
symbolSize: 12,
symbolShape: 'circle',
effects: [
{
on: 'hover',
style: {
itemTextColor: '#000',
},
},
],
},
]
: []
}
/>
)}
</div>
</>
</Box>
);
};
......
......@@ -9,7 +9,13 @@ function parseJSON(data) {
}
}
export default function ViewPlotModal({ open, onClose, data, jobId, compareChart= false}) {
export default function ViewPlotModal({
open,
onClose,
data,
jobId,
compareChart = false,
}) {
if (!jobId) {
return <></>;
}
......@@ -29,9 +35,6 @@ export default function ViewPlotModal({ open, onClose, data, jobId, compareChart
height: '100%',
}}
>
<Typography level="h4" mb={2}>
Chart
</Typography>
<Box
sx={{
width: '100%',
......@@ -42,9 +45,12 @@ export default function ViewPlotModal({ open, onClose, data, jobId, compareChart
borderRadius: '8px',
boxShadow: 1,
p: 2,
display: 'flex',
flexDirection: 'column',
justifyContent: 'center',
}}
>
<Chart metrics={parseJSON(data)} compareChart={compareChart} />
<Chart metrics={parseJSON(data)} compareChart={compareChart} />
</Box>
</Box>
</ModalDialog>
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment