/** Raw per-test-case metric record consumed by `transformMetrics`. */
type MetricEntry = {
  test_case_id: string;
  metric_name: string;
  score: number;
  evaluator_name: string;
  job_id: string;
  // Extra columns are passed through untouched in 'detailed' mode, so their
  // type is intentionally left open (kept as `any` for caller compatibility).
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  [key: string]: any;
};

/** One averaged score per (evaluator, job, metric) combination. */
type MetricSummaryRow = {
  evaluator: string;
  job_id: string;
  type: string;
  score: number;
};

/** Tabular form of the raw entries: a header row plus one body row per entry. */
type DetailedMetricsTable = {
  header: string[];
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  body: any[][];
};

/**
 * Averages scores over entries sharing the same evaluator, job and metric.
 *
 * Entries with an empty/missing `metric_name`, `evaluator_name` or `job_id`,
 * or with a `null`/`undefined` score, are skipped. A score of `0` is kept.
 */
function summarizeMetrics(data: readonly MetricEntry[]): MetricSummaryRow[] {
  const groups = new Map<
    string,
    { evaluator_name: string; job_id: string; type: string; sum: number; count: number }
  >();

  for (const entry of data) {
    const { metric_name, score, evaluator_name, job_id } = entry;
    if (!metric_name || score == null || !evaluator_name || !job_id) {
      continue;
    }

    // A JSON-encoded tuple is an unambiguous composite key: unlike joining with
    // a delimiter, two different triples can never produce the same key even
    // when the values themselves contain the delimiter character.
    const key = JSON.stringify([
      String(evaluator_name),
      String(job_id),
      String(metric_name),
    ]);

    const group = groups.get(key);
    if (group) {
      group.sum += score;
      group.count += 1;
    } else {
      groups.set(key, {
        evaluator_name,
        job_id,
        type: metric_name,
        sum: score,
        count: 1,
      });
    }
  }

  // Map preserves insertion order, so rows come out in first-seen order.
  // Averages are rounded to 5 decimals.
  return Array.from(groups.values(), (group) => ({
    evaluator: group.evaluator_name,
    job_id: group.job_id,
    type: group.type,
    score: Number((group.sum / group.count).toFixed(5)),
  }));
}

/**
 * Builds a header/body table with one row per entry (no averaging).
 *
 * Column order: test_case_id, metric_name, job_id, evaluator_name,
 * metric_name, score, then every non-core key found on any entry, sorted
 * alphabetically. `metric_name` appears twice on purpose; the column layout
 * is kept exactly as the original implementation produced it.
 */
function buildDetailedMetricsTable(
  data: readonly MetricEntry[],
): DetailedMetricsTable {
  const coreKeys = new Set([
    'test_case_id',
    'metric_name',
    'job_id',
    'evaluator_name',
    'score',
  ]);

  // Collect the union of extra keys across all entries so every row has the
  // same number of columns (missing values become `undefined`).
  const extraKeySet = new Set<string>();
  for (const entry of data) {
    for (const key of Object.keys(entry)) {
      if (!coreKeys.has(key)) {
        extraKeySet.add(key);
      }
    }
  }
  const extraKeys = Array.from(extraKeySet).sort();

  const header = [
    'test_case_id',
    'metric_name',
    'job_id',
    'evaluator_name',
    'metric_name',
    'score',
    ...extraKeys,
  ];

  const body = data.map((entry) => [
    entry.test_case_id,
    entry.metric_name,
    entry.job_id,
    entry.evaluator_name,
    entry.metric_name,
    entry.score,
    ...extraKeys.map((key) => entry[key]),
  ]);

  return { header, body };
}

/**
 * Transforms raw metric entries for display.
 *
 * @param data - Raw metric entries.
 * @param type - `'summary'` (default) returns one averaged row per
 *   evaluator/job/metric; `'detailed'` returns a `{ header, body }` table with
 *   one row per entry.
 * @returns An array of summary rows, or a detailed table, depending on `type`.
 * @throws Error if `type` is neither `'summary'` nor `'detailed'` (previously
 *   this fell through and silently returned `undefined`).
 */
function transformMetrics(
  data: readonly MetricEntry[],
  type?: 'summary',
): MetricSummaryRow[];
function transformMetrics(
  data: readonly MetricEntry[],
  type: 'detailed',
): DetailedMetricsTable;
function transformMetrics(
  data: readonly MetricEntry[],
  type?: 'summary' | 'detailed',
): MetricSummaryRow[] | DetailedMetricsTable;
function transformMetrics(
  data: readonly MetricEntry[],
  type: 'summary' | 'detailed' = 'summary',
): MetricSummaryRow[] | DetailedMetricsTable {
  switch (type) {
    case 'summary':
      return summarizeMetrics(data);
    case 'detailed':
      return buildDetailedMetricsTable(data);
    default: {
      // Exhaustiveness check: compile error if a new mode is added but not handled.
      const unreachable: never = type;
      throw new Error(`Unknown comparison type: ${String(unreachable)}`);
    }
  }
}
setCurrentData(JSON.stringify(transformedData)); - // setCurrentData(JSON.stringify(data)); setOpenPlotModal(true); setChart(true); setCompareChart(true); setCurrentJobId('-1'); - } catch (error) { - console.error('Failed to fetch combined reports:', error); + } else if (comparisonType === 'detailed') { + const transformedData = transformMetrics(JSON.parse(data), "detailed"); + + setCompareData(transformedData); + handleOpenCSVModal('-1'); + } - }; + } catch (error) { + console.error('Failed to fetch combined reports:', error); + } + }; const handleOpenCSVModal = (jobId) => { setCurrentJobId(jobId); setOpenCSVModal(true); + }; const handleOpenPlotModal = (jobId, score) => { @@ -220,6 +263,7 @@ const EvalJobsTable = () => { onClose={() => setOpenCSVModal(false)} jobId={currentJobId} fetchCSV={fetchCSV} + compareData={compareData} /> <ViewPlotModal open={openPlotModal} @@ -248,18 +292,24 @@ const EvalJobsTable = () => { > <Typography level="h3">Executions</Typography> {selected.length > 1 && ( - <Typography - level="body-sm" - startDecorator={<ChartColumnIncreasingIcon size="20px" />} - // Uncomment this line to enable the combined reports feature - onClick={handleCombinedReports} - // onClick={() => { - // alert('this feature coming soon'); - // }} - sx={{ cursor: 'pointer' }} - > - <>Compare Selected Evals</> - </Typography> + <Box sx={{ display: 'flex', gap: 2 }}> + <Typography + level="body-sm" + startDecorator={<ChartColumnIncreasingIcon size="20px" />} + onClick={() => handleCombinedReports('summary')} + sx={{ cursor: 'pointer' }} + > + Compare Selected Evals + </Typography> + <Typography + level="body-sm" + startDecorator={<Grid3X3Icon size="20px" />} + onClick={() => handleCombinedReports('detailed')} + sx={{ cursor: 'pointer' }} + > + Detailed Comparison + </Typography> + </Box> )} </Box> diff --git a/src/renderer/components/Experiment/Eval/ViewCSVModal.tsx b/src/renderer/components/Experiment/Eval/ViewCSVModal.tsx index 
e933f1c9f63d27f72699b673da332baefce2f65e..bfd604f4bbcba7632eb306908c47fd22d5be616f 100644 --- a/src/renderer/components/Experiment/Eval/ViewCSVModal.tsx +++ b/src/renderer/components/Experiment/Eval/ViewCSVModal.tsx @@ -26,7 +26,7 @@ function heatedColor(value) { // This function formats the eval data to combine rows that have the same name // based on the first column -function formatEvalData(data) { +function formatEvalData(data, compareEvals = false) { let header = data?.header; let body = data?.body; const formattedData: any[] = []; @@ -41,36 +41,77 @@ function formatEvalData(data) { return data; } - // remove the header named "metric_name" - if (header[1] === 'metric_name') { - header = header.slice(1); - } - const seen = new Set(); - body.forEach((row) => { - if (!seen.has(row[0])) { - seen.add(row[0]); - const newRow = [row[0]]; - newRow.push({ [row[1]]: row[2] }); - // now push the rest of the columns: - for (let i = 3; i < row.length; i++) { - newRow.push(row[i]); - } - formattedData.push(newRow); - } else { - const index = formattedData.findIndex((r) => r[0] === row[0]); - let newScore = []; - // if formattedData[index][1] is an array, then we need to push to it - if (Array.isArray(formattedData[index][1])) { - newScore = formattedData[index][1]; + + if (compareEvals) { + // Ensure the header has at least the expected columns: + // test_case_id, metric_name, job_id, evaluator_name, metric_name, score, ... 
+ if (header.length < 6) { + return data; + } + // Remove columns: drop the first metric_name (index 1), job_id (index 2), evaluator_name (index 3) + // and the metric_name/score pair (indices 4 and 5) will be combined + // New header: test_case_id, combined_scores, then any extra columns (starting from index 6) + header = [header[0], 'score', ...header.slice(6)]; + + body.forEach((row) => { + // Sanity check row length + if (row.length < 6) return; + const testCaseId = row[0]; + const jobId = row[2]; + const evaluatorName = row[3]; + const metricName = row[4]; + const scoreVal = row[5]; + const combinedScore = { [`${evaluatorName}-${jobId}-${metricName}`]: scoreVal }; + + // Append additional columns after the 6th column, if any + const extraColumns = row.slice(6); + + if (!seen.has(testCaseId)) { + seen.add(testCaseId); + // newRow: [test_case_id, combinedScore, extra columns...] + formattedData.push([testCaseId, combinedScore, ...extraColumns]); } else { - newScore.push(formattedData[index][1]); + const index = formattedData.findIndex((r) => r[0] === testCaseId); + let newScore = []; + if (Array.isArray(formattedData[index][1])) { + newScore = formattedData[index][1]; + } else { + newScore.push(formattedData[index][1]); + } + + newScore.push(combinedScore); + formattedData[index][1] = newScore; } - newScore.push({ [row[1]]: row[2] }); - formattedData[index][1] = newScore; + }); + } else { + // original processing: remove "metric_name" if it is header[1] + if (header[1] === 'metric_name') { + header = header.slice(1); } - }); - + body.forEach((row) => { + if (!seen.has(row[0])) { + seen.add(row[0]); + const newRow = [row[0]]; + newRow.push({ [row[1]]: row[2] }); + for (let i = 3; i < row.length; i++) { + newRow.push(row[i]); + } + console.log("NEW ROW", newRow); + formattedData.push(newRow); + } else { + const index = formattedData.findIndex((r) => r[0] === row[0]); + let newScore = []; + if (Array.isArray(formattedData[index][1])) { + newScore = 
formattedData[index][1]; + } else { + newScore.push(formattedData[index][1]); + } + newScore.push({ [row[1]]: row[2] }); + formattedData[index][1] = newScore; + } + }); + } return { header: header, body: formattedData }; } @@ -125,10 +166,14 @@ function formatScore(score) { } } -const ViewCSVModal = ({ open, onClose, jobId, fetchCSV }) => { +const ViewCSVModal = ({ open, onClose, jobId, fetchCSV, compareData = null }) => { const [report, setReport] = useState({}); + useEffect(() => { + + if (!compareData) { + if (open && jobId) { fetchCSV(jobId).then((data) => { try { @@ -138,8 +183,20 @@ const ViewCSVModal = ({ open, onClose, jobId, fetchCSV }) => { } }); } + + } else { + try { + console.log('compareData', compareData); + setReport(formatEvalData(compareData, true)); + } catch (e) { + setReport({ header: ['Error'], body: [[compareData]] }); + } + + } + }, [open, jobId, fetchCSV]); + const handleDownload = async () => { const response = await fetch( chatAPI.Endpoints.Experiment.GetAdditionalDetails(jobId, 'download')