diff --git a/src/renderer/components/Experiment/Eval/Chart.tsx b/src/renderer/components/Experiment/Eval/Chart.tsx index 303e3e99194ee8a3286ab6a2f2042e605c66da4b..4800dbc25980b4fcf1aa9adc6beb2651027442b0 100644 --- a/src/renderer/components/Experiment/Eval/Chart.tsx +++ b/src/renderer/components/Experiment/Eval/Chart.tsx @@ -1,10 +1,11 @@ +//// filepath: /Users/deep.gandhi/transformerlab-repos/transformerlab-app/src/renderer/components/Experiment/Eval/Chart.tsx import React, { useState } from 'react'; import { ResponsiveLine } from '@nivo/line'; import { ResponsiveBar } from '@nivo/bar'; import { ResponsiveRadar } from '@nivo/radar'; import { Select, Option, FormControl } from '@mui/joy'; -const Chart = ({ metrics }) => { +const Chart = ({ metrics, compareChart }) => { const [chartType, setChartType] = useState('bar'); const handleChartTypeChange = (event, newValue) => { @@ -15,29 +16,71 @@ const Chart = ({ metrics }) => { return <div>No metrics available</div>; } - console.log(metrics); + console.log("METRICS", metrics); - const data = metrics.map((metric) => ({ - id: metric.type, - value: metric.score, - })); + let barData, lineData, radarData; - const lineData = [ - { - id: 'metrics', - data: metrics.map((metric) => ({ x: metric.type, y: metric.score })), - }, - ]; + if (compareChart) { + // For compare mode, multiple evaluators/jobs become separate series. + // Use evaluator-job as series key. + const seriesKeys = Array.from( + new Set(metrics.map(m => `${m.evaluator}-${m.job_id}`)) + ); - const barData = metrics.map((metric) => ({ - type: metric.type, - score: metric.score, - })); + // For Bar chart: group by metric type. + const barDataMap = {}; + metrics.forEach(metric => { + const { type, evaluator, job_id, score } = metric; + const seriesKey = `${evaluator}-${job_id}`; + if (!barDataMap[type]) { + barDataMap[type] = { type }; + } + barDataMap[type][seriesKey] = score; + }); + barData = Object.values(barDataMap); - const radarData = metrics.map((metric) => ({ - metric: metric.type, - score: metric.score, - })); + // For Line chart: each series is an evaluator-job. + const seriesDataMap = {}; + seriesKeys.forEach(series => { + seriesDataMap[series] = { id: series, data: [] }; + }); + metrics.forEach(metric => { + const { type, evaluator, job_id, score } = metric; + const seriesKey = `${evaluator}-${job_id}`; + seriesDataMap[seriesKey].data.push({ x: type, y: score }); + }); + lineData = Object.values(seriesDataMap); + + // For Radar chart: similar to bar, but keys represent evaluator-job score. + const radarDataMap = {}; + metrics.forEach(metric => { + const { type, evaluator, job_id, score } = metric; + const seriesKey = `${evaluator}-${job_id}`; + if (!radarDataMap[type]) { + radarDataMap[type] = { metric: type }; + } + radarDataMap[type][seriesKey] = score; + }); + radarData = Object.values(radarDataMap); + } else { + // Original logic: assume a single series. + barData = metrics.map(metric => ({ + type: metric.type, + score: metric.score, + })); + + lineData = [ + { + id: 'metrics', + data: metrics.map(metric => ({ x: metric.type, y: metric.score })), + } + ]; + + radarData = metrics.map(metric => ({ + metric: metric.type, + score: metric.score, + })); + } return ( <> @@ -84,10 +127,15 @@ const Chart = ({ metrics }) => { {chartType === 'bar' && ( <ResponsiveBar data={barData} - keys={['score']} + keys={ + compareChart + ? Array.from(new Set(metrics.map((m) => `${m.evaluator}-${m.job_id}`))) + : ['score'] + } indexBy="type" margin={{ top: 50, right: 130, bottom: 50, left: 60 }} padding={0.3} + groupMode={compareChart ? 'grouped' : 'stacked'} // Added groupMode here. axisTop={null} axisRight={null} axisBottom={{ @@ -107,32 +155,94 @@ const Chart = ({ metrics }) => { legendOffset: -40, }} colors={{ scheme: 'nivo' }} - colorBy="indexValue" + colorBy="id" animate={false} /> )} - {chartType === 'radar' && ( +{chartType === 'radar' && compareChart && ( + <div style={{ position: 'relative', height: '100%', width: '100%' }}> + {Array.from(new Set(metrics.map(m => `${m.evaluator}-${m.job_id}`))).map((series, index) => { + // Filter radar data for this specific evaluator-job combination + const seriesRadarData = metrics + .filter(m => `${m.evaluator}-${m.job_id}` === series) + .map(metric => ({ + metric: metric.type, + score: metric.score + })); + + + return ( + <div key={series} style={{ position: 'absolute', top: 0, left: 0, right: 0, bottom: 0 }}> <ResponsiveRadar - data={radarData} + data={seriesRadarData} keys={['score']} indexBy="metric" margin={{ top: 70, right: 80, bottom: 40, left: 80 }} gridShape="circular" gridLabelOffset={36} dotSize={10} - dotColor={{ from: 'color', modifiers: [] }} + dotColor={{ from: 'color' }} dotBorderWidth={2} - dotBorderColor={{ from: 'color', modifiers: [] }} - colors={{ scheme: 'nivo' }} - fillOpacity={0.25} - blendMode="multiply" + dotBorderColor={{ from: 'color' }} + colors={[`hsl(${index * 30 + 60}, 70%, 50%)`]} // Different color for each series + fillOpacity={0.2} + blendMode="normal" animate={true} motionConfig="wobbly" enableDotLabel={true} dotLabel="value" dotLabelYOffset={-12} + legends={[ + { + anchor: 'top-right', + direction: 'column', + translateX: 0, + translateY: -40, + itemWidth: 80, + itemHeight: 20, + itemTextColor: '#999', + symbolSize: 12, + symbolShape: 'circle', + effects: [ + { + on: 'hover', + style: { + itemTextColor: '#000' + } + } + ], + data: [{ id: series, label: series, color: `hsl(${index * 30 + 60}, 70%, 50%)` }] + } + ]} /> - )} + </div> + ); + })} + </div> +)} + +{chartType === 'radar' && !compareChart && ( + <ResponsiveRadar + data={radarData} + keys={['score']} + indexBy="metric" + margin={{ top: 70, right: 80, bottom: 40, left: 80 }} + gridShape="circular" + gridLabelOffset={36} + dotSize={10} + dotColor={{ from: 'color', modifiers: [] }} + dotBorderWidth={2} + dotBorderColor={{ from: 'color', modifiers: [] }} + colors={{ scheme: 'nivo' }} + fillOpacity={0.25} + blendMode="multiply" + animate={true} + motionConfig="wobbly" + enableDotLabel={true} + dotLabel="value" + dotLabelYOffset={-12} + /> +)} </div> </> ); diff --git a/src/renderer/components/Experiment/Eval/EvalJobsTable.tsx b/src/renderer/components/Experiment/Eval/EvalJobsTable.tsx index 481c36d42ec3b45b7fc36a9a0d3047b7bee463b4..7abddcb1b73f2b77e784731ccd80ef063ab3f48e 100644 --- a/src/renderer/components/Experiment/Eval/EvalJobsTable.tsx +++ b/src/renderer/components/Experiment/Eval/EvalJobsTable.tsx @@ -91,6 +91,57 @@ function RenderScore({ score }) { )); } +function transformMetrics( + data: Array<{ + metric_name: string; + score: number; + evaluator_name: string; + job_id: string; + [key: string]: any; + }> +) { + const grouped: { + [key: string]: { + evaluator_name: string; + job_id: string; + type: string; + sum: number; + count: number; + }; + } = {}; + + data.forEach((entry) => { + // Extract only the fields we care about. + let { metric_name, score, evaluator_name, job_id } = entry; + if (!metric_name || score === undefined || score === null || !evaluator_name || !job_id) { + return; + } + + // Use a combined key to group only entries that share evaluator_name, job_id AND metric_name. + const key = `${evaluator_name}|${job_id}|${metric_name}`; + if (grouped[key]) { + grouped[key].sum += score; + grouped[key].count += 1; + } else { + grouped[key] = { + evaluator_name, + job_id, + type: metric_name, + sum: score, + count: 1, + }; + } + }); + + // Generate deduplicated list with averaged scores rounded to 5 decimals. + return Object.values(grouped).map((item) => ({ + evaluator: item.evaluator_name, + job_id: item.job_id, + type: item.type, + score: Number((item.sum / item.count).toFixed(5)), + })); +} + const EvalJobsTable = () => { const [selected, setSelected] = useState<readonly string[]>([]); @@ -100,6 +151,7 @@ const EvalJobsTable = () => { const [currentJobId, setCurrentJobId] = useState(''); const [currentData, setCurrentData] = useState(''); const [chart, setChart] = useState(true); + const [compareChart, setCompareChart] = useState(false); const [currentTensorboardForModal, setCurrentTensorboardForModal] = useState(-1); const [fileNameForDetailedReport, setFileNameForDetailedReport] = useState(''); @@ -130,10 +182,13 @@ const EvalJobsTable = () => { throw new Error('Network response was not ok'); } const data = await response.json(); - console.log('data', data); - setCurrentData(JSON.stringify(data)); + const transformedData = transformMetrics(JSON.parse(data)); + + setCurrentData(JSON.stringify(transformedData)); + // setCurrentData(JSON.stringify(data)); setOpenPlotModal(true); - setChart(false); + setChart(true); + setCompareChart(true); setCurrentJobId('-1'); } catch (error) { console.error('Failed to fetch combined reports:', error); @@ -146,9 +201,12 @@ const EvalJobsTable = () => { setOpenCSVModal(true); }; - const handleOpenPlotModal = (score) => { + const handleOpenPlotModal = (jobId, score) => { setCurrentData(score); setOpenPlotModal(true); + setChart(true); + setCompareChart(false); + setCurrentJobId(jobId); }; useEffect(() => { @@ -169,6 +227,7 @@ const EvalJobsTable = () => { data={currentData} jobId={currentJobId} chart={chart} + compareChart={compareChart} /> <ViewOutputModalStreaming jobId={viewOutputFromJob} diff --git a/src/renderer/components/Experiment/Eval/ViewPlotModal.tsx b/src/renderer/components/Experiment/Eval/ViewPlotModal.tsx index 2623586cb05faa90589bef7572282c6bc89cbc5b..1d98e7433c0f22ea9a99e640d23417feb4adb8ed 100644 --- a/src/renderer/components/Experiment/Eval/ViewPlotModal.tsx +++ b/src/renderer/components/Experiment/Eval/ViewPlotModal.tsx @@ -1,7 +1,5 @@ -import React from 'react'; import { Modal, ModalDialog, ModalClose, Box, Typography } from '@mui/joy'; import Chart from './Chart'; -import * as chatAPI from 'renderer/lib/transformerlab-api-sdk'; function parseJSON(data) { try { @@ -11,7 +9,7 @@ function parseJSON(data) { } } -export default function ViewPlotModal({ open, onClose, data, jobId, chart = true}) { +export default function ViewPlotModal({ open, onClose, data, jobId, chart = true, compareChart= false}) { if (!jobId) { return <></>; } @@ -47,7 +45,7 @@ export default function ViewPlotModal({ open, onClose, data, jobId, chart = true }} > {chart ? ( - <Chart metrics={parseJSON(data)} /> + <Chart metrics={parseJSON(data)} compareChart={compareChart} /> ) : ( <div>{JSON.stringify(parseJSON(data))}</div> )}