diff --git a/src/renderer/components/Experiment/Eval/Chart.tsx b/src/renderer/components/Experiment/Eval/Chart.tsx index 303e3e99194ee8a3286ab6a2f2042e605c66da4b..baea21212da7728457629a6455fde8b3d4ab4955 100644 --- a/src/renderer/components/Experiment/Eval/Chart.tsx +++ b/src/renderer/components/Experiment/Eval/Chart.tsx @@ -1,10 +1,11 @@ +//// filepath: /Users/deep.gandhi/transformerlab-repos/transformerlab-app/src/renderer/components/Experiment/Eval/Chart.tsx import React, { useState } from 'react'; import { ResponsiveLine } from '@nivo/line'; import { ResponsiveBar } from '@nivo/bar'; import { ResponsiveRadar } from '@nivo/radar'; import { Select, Option, FormControl } from '@mui/joy'; -const Chart = ({ metrics }) => { +const Chart = ({ metrics, compareChart }) => { const [chartType, setChartType] = useState('bar'); const handleChartTypeChange = (event, newValue) => { @@ -15,29 +16,70 @@ const Chart = ({ metrics }) => { return <div>No metrics available</div>; } - console.log(metrics); - const data = metrics.map((metric) => ({ - id: metric.type, - value: metric.score, - })); + let barData, lineData, radarData; - const lineData = [ - { - id: 'metrics', - data: metrics.map((metric) => ({ x: metric.type, y: metric.score })), - }, - ]; + if (compareChart) { + // For compare mode, multiple evaluators/jobs become separate series. + // Use evaluator-job as series key. + const seriesKeys = Array.from( + new Set(metrics.map(m => `${m.evaluator}-${m.job_id}`)) + ); - const barData = metrics.map((metric) => ({ - type: metric.type, - score: metric.score, - })); + // For Bar chart: group by metric type. + const barDataMap = {}; + metrics.forEach(metric => { + const { type, evaluator, job_id, score } = metric; + const seriesKey = `${evaluator}-${job_id}`; + if (!barDataMap[type]) { + barDataMap[type] = { type }; + } + barDataMap[type][seriesKey] = score; + }); + barData = Object.values(barDataMap); - const radarData = metrics.map((metric) => ({ - metric: metric.type, - score: metric.score, - })); + // For Line chart: each series is an evaluator-job. + const seriesDataMap = {}; + seriesKeys.forEach(series => { + seriesDataMap[series] = { id: series, data: [] }; + }); + metrics.forEach(metric => { + const { type, evaluator, job_id, score } = metric; + const seriesKey = `${evaluator}-${job_id}`; + seriesDataMap[seriesKey].data.push({ x: type, y: score }); + }); + lineData = Object.values(seriesDataMap); + + // For Radar chart: similar to bar, but keys represent evaluator-job score. + const radarDataMap = {}; + metrics.forEach(metric => { + const { type, evaluator, job_id, score } = metric; + const seriesKey = `${evaluator}-${job_id}`; + if (!radarDataMap[type]) { + radarDataMap[type] = { metric: type }; + } + radarDataMap[type][seriesKey] = score; + }); + radarData = Object.values(radarDataMap); + } else { + // Original logic: assume a single series. + barData = metrics.map(metric => ({ + type: metric.type, + score: metric.score, + })); + + lineData = [ + { + id: 'metrics', + data: metrics.map(metric => ({ x: metric.type, y: metric.score })), + } + ]; + + radarData = metrics.map(metric => ({ + metric: metric.type, + score: metric.score, + })); + } return ( <> @@ -84,10 +126,15 @@ const Chart = ({ metrics }) => { {chartType === 'bar' && ( <ResponsiveBar data={barData} - keys={['score']} + keys={ + compareChart + ? Array.from(new Set(metrics.map((m) => `${m.evaluator}-${m.job_id}`))) + : ['score'] + } indexBy="type" margin={{ top: 50, right: 130, bottom: 50, left: 60 }} padding={0.3} + groupMode={compareChart ? 'grouped' : 'stacked'} // Added groupMode here. axisTop={null} axisRight={null} axisBottom={{ @@ -107,32 +154,94 @@ const Chart = ({ metrics }) => { legendOffset: -40, }} colors={{ scheme: 'nivo' }} - colorBy="indexValue" + colorBy="id" animate={false} /> )} - {chartType === 'radar' && ( +{chartType === 'radar' && compareChart && ( + <div style={{ position: 'relative', height: '100%', width: '100%' }}> + {Array.from(new Set(metrics.map(m => `${m.evaluator}-${m.job_id}`))).map((series, index) => { + // Filter radar data for this specific evaluator-job combination + const seriesRadarData = metrics + .filter(m => `${m.evaluator}-${m.job_id}` === series) + .map(metric => ({ + metric: metric.type, + score: metric.score + })); + + + return ( + <div key={series} style={{ position: 'absolute', top: 0, left: 0, right: 0, bottom: 0 }}> <ResponsiveRadar - data={radarData} + data={seriesRadarData} keys={['score']} indexBy="metric" margin={{ top: 70, right: 80, bottom: 40, left: 80 }} gridShape="circular" gridLabelOffset={36} dotSize={10} - dotColor={{ from: 'color', modifiers: [] }} + dotColor={{ from: 'color' }} dotBorderWidth={2} - dotBorderColor={{ from: 'color', modifiers: [] }} - colors={{ scheme: 'nivo' }} - fillOpacity={0.25} - blendMode="multiply" + dotBorderColor={{ from: 'color' }} + colors={[`hsl(${index * 30 + 60}, 70%, 50%)`]} // Different color for each series + fillOpacity={0.2} + blendMode="normal" animate={true} motionConfig="wobbly" enableDotLabel={true} dotLabel="value" dotLabelYOffset={-12} + legends={[ + { + anchor: 'top-right', + direction: 'column', + translateX: 0, + translateY: -40, + itemWidth: 80, + itemHeight: 20, + itemTextColor: '#999', + symbolSize: 12, + symbolShape: 'circle', + effects: [ + { + on: 'hover', + style: { + itemTextColor: '#000' + } + } + ], + data: [{ id: series, label: series, color: `hsl(${index * 30 + 60}, 70%, 50%)` }] + } + ]} /> - )} + </div> + ); + })} + </div> +)} + +{chartType === 'radar' && !compareChart && ( + <ResponsiveRadar + data={radarData} + keys={['score']} + indexBy="metric" + margin={{ top: 70, right: 80, bottom: 40, left: 80 }} + gridShape="circular" + gridLabelOffset={36} + dotSize={10} + dotColor={{ from: 'color', modifiers: [] }} + dotBorderWidth={2} + dotBorderColor={{ from: 'color', modifiers: [] }} + colors={{ scheme: 'nivo' }} + fillOpacity={0.25} + blendMode="multiply" + animate={true} + motionConfig="wobbly" + enableDotLabel={true} + dotLabel="value" + dotLabelYOffset={-12} + /> +)} </div> </> ); diff --git a/src/renderer/components/Experiment/Eval/EvalJobsTable.tsx b/src/renderer/components/Experiment/Eval/EvalJobsTable.tsx index 481c36d42ec3b45b7fc36a9a0d3047b7bee463b4..9734bf3d677a15b335030ef6a21c1dfd82f5b747 100644 --- a/src/renderer/components/Experiment/Eval/EvalJobsTable.tsx +++ b/src/renderer/components/Experiment/Eval/EvalJobsTable.tsx @@ -91,15 +91,101 @@ function RenderScore({ score }) { )); } +function transformMetrics( + data: Array<{ + test_case_id: string; + metric_name: string; + score: number; + evaluator_name: string; + job_id: string; + [key: string]: any; + }>, + type: 'summary' | 'detailed' = 'summary' +) { + if (type === 'summary') { + const grouped: { + [key: string]: { + evaluator_name: string; + job_id: string; + type: string; + sum: number; + count: number; + }; + } = {}; + + data.forEach((entry) => { + // Extract only the fields we care about. + let { metric_name, score, evaluator_name, job_id } = entry; + if (!metric_name || score === undefined || score === null || !evaluator_name || !job_id) { + return; + } + + // Use a combined key to group only entries that share evaluator_name, job_id AND metric_name. + const key = `${evaluator_name}|${job_id}|${metric_name}`; + if (grouped[key]) { + grouped[key].sum += score; + grouped[key].count += 1; + } else { + grouped[key] = { + evaluator_name, + job_id, + type: metric_name, + sum: score, + count: 1, + }; + } + }); + + // Generate deduplicated list with averaged scores rounded to 5 decimals. + return Object.values(grouped).map((item) => ({ + evaluator: item.evaluator_name, + job_id: item.job_id, + type: item.type, + score: Number((item.sum / item.count).toFixed(5)), + })); + } else if (type === 'detailed') { + // For detailed output we are not averaging. + // Expected header sequence: test_case_id, metric_name, job_id, evaluator_name, metric_name, score, ...extra + // Determine extra keys from the entry (excluding core ones). + const extraKeysSet = new Set<string>(); + data.forEach((entry) => { + Object.keys(entry).forEach((k) => { + if (!['test_case_id', 'metric_name', 'job_id', 'evaluator_name', 'score'].includes(k)) { + extraKeysSet.add(k); + } + }); + }); + const extraKeys = Array.from(extraKeysSet).sort(); + + const header = ['test_case_id', 'metric_name', 'job_id', 'evaluator_name', 'metric_name', 'score', ...extraKeys]; + + const body = data.map((entry) => { + const extraValues = extraKeys.map((key) => entry[key]); + return [ + entry.test_case_id, // using test_case_id instead of job_id + entry.metric_name, + entry.job_id, + entry.evaluator_name, + entry.metric_name, + entry.score, + ...extraValues, + ]; + }); + + return { header, body }; + } +} + const EvalJobsTable = () => { const [selected, setSelected] = useState<readonly string[]>([]); const [viewOutputFromJob, setViewOutputFromJob] = useState(-1); const [openCSVModal, setOpenCSVModal] = useState(false); + const [compareData, setCompareData] = useState(null); const [openPlotModal, setOpenPlotModal] = useState(false); const [currentJobId, setCurrentJobId] = useState(''); const [currentData, setCurrentData] = useState(''); - const [chart, setChart] = useState(true); + const [compareChart, setCompareChart] = useState(false); const [currentTensorboardForModal, setCurrentTensorboardForModal] = useState(-1); const [fileNameForDetailedReport, setFileNameForDetailedReport] = useState(''); @@ -121,34 +207,46 @@ const EvalJobsTable = () => { fallbackData: [], }); - const handleCombinedReports = async () => { - try { - const jobIdsParam = selected.join(','); - const compareEvalsUrl = chatAPI.Endpoints.Charts.CompareEvals(jobIdsParam); - const response = await fetch(compareEvalsUrl, { method: 'GET' }); - if (!response.ok) { - throw new Error('Network response was not ok'); - } - const data = await response.json(); - console.log('data', data); - setCurrentData(JSON.stringify(data)); + const handleCombinedReports = async (comparisonType: 'summary' | 'detailed' = 'summary') => { + try { + const jobIdsParam = selected.join(','); + const compareEvalsUrl = chatAPI.Endpoints.Charts.CompareEvals(jobIdsParam); + const response = await fetch(compareEvalsUrl, { method: 'GET' }); + if (!response.ok) { + throw new Error('Network response was not ok'); + } + const data = await response.json(); + if (comparisonType === 'summary') { + const transformedData = transformMetrics(JSON.parse(data), "summary"); + + setCurrentData(JSON.stringify(transformedData)); setOpenPlotModal(true); - setChart(false); + setCompareChart(true); setCurrentJobId('-1'); - } catch (error) { - console.error('Failed to fetch combined reports:', error); + } else if (comparisonType === 'detailed') { + const transformedData = transformMetrics(JSON.parse(data), "detailed"); + + setCompareData(transformedData); + handleOpenCSVModal('-1'); + } - }; + } catch (error) { + console.error('Failed to fetch combined reports:', error); + } + }; const handleOpenCSVModal = (jobId) => { setCurrentJobId(jobId); setOpenCSVModal(true); + }; - const handleOpenPlotModal = (score) => { + const handleOpenPlotModal = (jobId, score) => { setCurrentData(score); setOpenPlotModal(true); + setCompareChart(false); + setCurrentJobId(jobId); }; useEffect(() => { @@ -162,13 +260,14 @@ const EvalJobsTable = () => { onClose={() => setOpenCSVModal(false)} jobId={currentJobId} fetchCSV={fetchCSV} + compareData={compareData} /> <ViewPlotModal open={openPlotModal} onClose={() => setOpenPlotModal(false)} data={currentData} jobId={currentJobId} - chart={chart} + compareChart={compareChart} /> <ViewOutputModalStreaming jobId={viewOutputFromJob} @@ -189,18 +288,24 @@ const EvalJobsTable = () => { > <Typography level="h3">Executions</Typography> {selected.length > 1 && ( - <Typography - level="body-sm" - startDecorator={<ChartColumnIncreasingIcon size="20px" />} - // Uncomment this line to enable the combined reports feature - onClick={handleCombinedReports} - // onClick={() => { - // alert('this feature coming soon'); - // }} - sx={{ cursor: 'pointer' }} - > - <>Compare Selected Evals</> - </Typography> + <Box sx={{ display: 'flex', gap: 2 }}> + <Typography + level="body-sm" + startDecorator={<ChartColumnIncreasingIcon size="20px" />} + onClick={() => handleCombinedReports('summary')} + sx={{ cursor: 'pointer' }} + > + Compare Selected Evals + </Typography> + <Typography + level="body-sm" + startDecorator={<Grid3X3Icon size="20px" />} + onClick={() => handleCombinedReports('detailed')} + sx={{ cursor: 'pointer' }} + > + Detailed Comparison + </Typography> + </Box> )} </Box> diff --git a/src/renderer/components/Experiment/Eval/ViewCSVModal.tsx b/src/renderer/components/Experiment/Eval/ViewCSVModal.tsx index e933f1c9f63d27f72699b673da332baefce2f65e..b6bccb7e11f3b642768301000260678a371c8226 100644 --- a/src/renderer/components/Experiment/Eval/ViewCSVModal.tsx +++ b/src/renderer/components/Experiment/Eval/ViewCSVModal.tsx @@ -26,7 +26,7 @@ function heatedColor(value) { // This function formats the eval data to combine rows that have the same name // based on the first column -function formatEvalData(data) { +function formatEvalData(data, compareEvals = false) { let header = data?.header; let body = data?.body; const formattedData: any[] = []; @@ -41,36 +41,77 @@ function formatEvalData(data) { return data; } - // remove the header named "metric_name" - if (header[1] === 'metric_name') { - header = header.slice(1); - } - const seen = new Set(); - body.forEach((row) => { - if (!seen.has(row[0])) { - seen.add(row[0]); - const newRow = [row[0]]; - newRow.push({ [row[1]]: row[2] }); - // now push the rest of the columns: - for (let i = 3; i < row.length; i++) { - newRow.push(row[i]); - } - formattedData.push(newRow); - } else { - const index = formattedData.findIndex((r) => r[0] === row[0]); - let newScore = []; - // if formattedData[index][1] is an array, then we need to push to it - if (Array.isArray(formattedData[index][1])) { - newScore = formattedData[index][1]; + + if (compareEvals) { + // Ensure the header has at least the expected columns: + // test_case_id, metric_name, job_id, evaluator_name, metric_name, score, ... + if (header.length < 6) { + return data; + } + // Remove columns: drop the first metric_name (index 1), job_id (index 2), evaluator_name (index 3) + // and the metric_name/score pair (indices 4 and 5) will be combined + // New header: test_case_id, combined_scores, then any extra columns (starting from index 6) + header = [header[0], 'score', ...header.slice(6)]; + + body.forEach((row) => { + // Sanity check row length + if (row.length < 6) return; + const testCaseId = row[0]; + const jobId = row[2]; + const evaluatorName = row[3]; + const metricName = row[4]; + const scoreVal = row[5]; + const combinedScore = { [`${evaluatorName}-${jobId}-${metricName}`]: scoreVal }; + + // Append additional columns after the 6th column, if any + const extraColumns = row.slice(6); + + if (!seen.has(testCaseId)) { + seen.add(testCaseId); + // newRow: [test_case_id, combinedScore, extra columns...] + formattedData.push([testCaseId, combinedScore, ...extraColumns]); } else { - newScore.push(formattedData[index][1]); + const index = formattedData.findIndex((r) => r[0] === testCaseId); + let newScore = []; + if (Array.isArray(formattedData[index][1])) { + newScore = formattedData[index][1]; + } else { + newScore.push(formattedData[index][1]); + } + + newScore.push(combinedScore); + formattedData[index][1] = newScore; } - newScore.push({ [row[1]]: row[2] }); - formattedData[index][1] = newScore; + }); + } else { + // original processing: remove "metric_name" if it is header[1] + if (header[1] === 'metric_name') { + header = header.slice(1); } - }); - + body.forEach((row) => { + if (!seen.has(row[0])) { + seen.add(row[0]); + const newRow = [row[0]]; + newRow.push({ [row[1]]: row[2] }); + for (let i = 3; i < row.length; i++) { + newRow.push(row[i]); + } + console.log("NEW ROW", newRow); + formattedData.push(newRow); + } else { + const index = formattedData.findIndex((r) => r[0] === row[0]); + let newScore = []; + if (Array.isArray(formattedData[index][1])) { + newScore = formattedData[index][1]; + } else { + newScore.push(formattedData[index][1]); + } + newScore.push({ [row[1]]: row[2] }); + formattedData[index][1] = newScore; + } + }); + } return { header: header, body: formattedData }; } @@ -125,10 +166,36 @@ function formatScore(score) { } } -const ViewCSVModal = ({ open, onClose, jobId, fetchCSV }) => { +const convertReportToCSV = (report: { header: any[]; body: any[] }) => { + if (!report?.header || !report?.body) return ''; + const csvRows = []; + csvRows.push(report.header.join(',')); + report.body.forEach((row) => { + const csvRow = row + .map((cell) => { + let cellText = ''; + if (typeof cell === 'object') { + // Convert objects to a JSON string and escape inner quotes + cellText = JSON.stringify(cell).replace(/"/g, '""'); + } else { + cellText = cell; + } + return `"${cellText}"`; + }) + .join(','); + csvRows.push(csvRow); + }); + return csvRows.join('\n'); +}; + +const ViewCSVModal = ({ open, onClose, jobId, fetchCSV, compareData = null }) => { const [report, setReport] = useState({}); + useEffect(() => { + + if (!compareData) { + if (open && jobId) { fetchCSV(jobId).then((data) => { try { @@ -138,9 +205,22 @@ const ViewCSVModal = ({ open, onClose, jobId, fetchCSV }) => { } }); } + + } else { + try { + setReport(formatEvalData(compareData, true)); + } catch (e) { + setReport({ header: ['Error'], body: [[compareData]] }); + } + + } + }, [open, jobId, fetchCSV]); + const handleDownload = async () => { + + if (!compareData) { const response = await fetch( chatAPI.Endpoints.Experiment.GetAdditionalDetails(jobId, 'download') ); @@ -153,7 +233,20 @@ const ViewCSVModal = ({ open, onClose, jobId, fetchCSV }) => { link.click(); document.body.removeChild(link); URL.revokeObjectURL(url); - }; + } else { + const csvContent = convertReportToCSV(report); + const blob = new Blob([csvContent], { type: 'text/csv' }); + const url = URL.createObjectURL(blob); + const link = document.createElement('a'); + link.href = url; + link.download = `detailed_report.csv`; // Adjust extension if necessary + document.body.appendChild(link); + link.click(); + document.body.removeChild(link); + URL.revokeObjectURL(url); + } + +}; return ( <Modal open={open} onClose={onClose}> diff --git a/src/renderer/components/Experiment/Eval/ViewPlotModal.tsx b/src/renderer/components/Experiment/Eval/ViewPlotModal.tsx index 2623586cb05faa90589bef7572282c6bc89cbc5b..91583c70846f8a3c253fe923231212e0a69ab93b 100644 --- a/src/renderer/components/Experiment/Eval/ViewPlotModal.tsx +++ b/src/renderer/components/Experiment/Eval/ViewPlotModal.tsx @@ -1,7 +1,5 @@ -import React from 'react'; import { Modal, ModalDialog, ModalClose, Box, Typography } from '@mui/joy'; import Chart from './Chart'; -import * as chatAPI from 'renderer/lib/transformerlab-api-sdk'; function parseJSON(data) { try { @@ -11,7 +9,7 @@ function parseJSON(data) { } } -export default function ViewPlotModal({ open, onClose, data, jobId, chart = true}) { +export default function ViewPlotModal({ open, onClose, data, jobId, compareChart= false}) { if (!jobId) { return <></>; } @@ -46,11 +44,7 @@ export default function ViewPlotModal({ open, onClose, data, jobId, chart = true p: 2, }} > - {chart ? ( - <Chart metrics={parseJSON(data)} /> - ) : ( - <div>{JSON.stringify(parseJSON(data))}</div> - )} + <Chart metrics={parseJSON(data)} compareChart={compareChart} /> </Box> </Box> </ModalDialog>