diff --git a/src/renderer/components/Experiment/Generate/Generate.tsx b/src/renderer/components/Experiment/Generate/Generate.tsx index 384583474c288a0793316722f98e848737b54ebd..cddd739140afb8bd3a8e798c230ce47b35da6817 100644 --- a/src/renderer/components/Experiment/Generate/Generate.tsx +++ b/src/renderer/components/Experiment/Generate/Generate.tsx @@ -46,9 +46,8 @@ export default function Generate({ experimentInfoMutate, }) { const [open, setOpen] = useState(false); - const [currentEvaluator, setCurrentEvaluator] = useState(''); const [currentPlugin, setCurrentPlugin] = useState(''); - const [currentEvalName, setCurrentEvalName] = useState(''); + const [currentGenerationName, setCurrentGenerationName] = useState(''); const { data: plugins, @@ -69,7 +68,7 @@ export default function Generate({ if (value) { // Use fetch to post the value to the server await fetch( - chatAPI.Endpoints.Experiment.SavePlugin(project, evalName, 'main.py'), + chatAPI.Endpoints.Experiment.SavePlugin(project, generationName, 'main.py'), { method: 'POST', body: value, @@ -104,12 +103,12 @@ export default function Generate({ open={open} onClose={() => { setOpen(false); - setCurrentEvalName(''); + setCurrentGenerationName(''); }} experimentInfo={experimentInfo} experimentInfoMutate={experimentInfoMutate} pluginId={currentPlugin} - currentEvalName={currentEvalName} + currentGenerationName={currentGenerationName} /> <Stack direction="row" @@ -165,7 +164,7 @@ export default function Generate({ experimentInfo={experimentInfo} experimentInfoMutate={experimentInfoMutate} setCurrentPlugin={setCurrentPlugin} - setCurrentEvalName={setCurrentEvalName} + setCurrentGenerationName={setCurrentGenerationName} setOpen={setOpen} /> </Sheet> diff --git a/src/renderer/components/Experiment/Generate/GenerateModal.tsx b/src/renderer/components/Experiment/Generate/GenerateModal.tsx index 126880a43decadc96a9c7e789bde7dd4266b1775..b38f361ee1fba991855910dd5a39052e1e7cf04c 100644 --- a/src/renderer/components/Experiment/Generate/GenerateModal.tsx +++ b/src/renderer/components/Experiment/Generate/GenerateModal.tsx @@ -55,7 +55,7 @@ export default function GenerateModal({ experimentInfo, experimentInfoMutate, pluginId, - currentEvalName, + currentGenerationName, }: { open: boolean; onClose: () => void; @@ -63,7 +63,7 @@ export default function GenerateModal({ experimentInfoMutate: () => void; template_id?: string; pluginId: string; - currentEvalName?: string; // Optional incase of new evaluation + currentGenerationName?: string; // Optional incase of new generation }) { // Store the current selected Dataset in this modal const [selectedDataset, setSelectedDataset] = useState(null); @@ -71,6 +71,7 @@ export default function GenerateModal({ const [hasDatasetKey, setHasDatasetKey] = useState(false); const [hasDocumentsKey, setHasDocumentsKey] = useState(false); const [hasContextKey, setHasContextKey] = useState(false); + const [selectedDocs, setSelectedDocs] = useState([]); const [nameInput, setNameInput] = useState(''); const [currentTab, setCurrentTab] = useState(0); const [contextInput, setContextInput] = useState(''); @@ -102,14 +103,10 @@ export default function GenerateModal({ return chatAPI.Endpoints.Dataset.Info(selectedDataset); }, fetcher); - // useEffect(() => { - // if (open) { // Reset the name input when the modal is opened - // setNameInput(generateFriendlyName()); - // }}, []); useEffect(() => { if (open) { - if (!currentEvalName || currentEvalName === '') { + if (!currentGenerationName || currentGenerationName === '') { setNameInput(generateFriendlyName()); } else { setNameInput(''); @@ -118,92 +115,91 @@ export default function GenerateModal({ }, [open]); useEffect(() => { + // EDIT GENERATION if (experimentInfo && pluginId) { - if (currentEvalName && currentEvalName !== '') { - const evaluationsStr = experimentInfo.config?.generations; - if (typeof evaluationsStr === 'string') { + if (currentGenerationName && currentGenerationName !== '') { + const generationsStr = experimentInfo.config?.generations; + setSelectedDocs([]); + if (typeof generationsStr === 'string') { try { - const evaluations = JSON.parse(evaluationsStr); - if (Array.isArray(evaluations)) { - const evalConfig = evaluations.find( - (evalItem: any) => - evalItem.name === currentEvalName && - evalItem.plugin === pluginId + const generations = JSON.parse(generationsStr); + if (Array.isArray(generations)) { + const generationConfig = generations.find( + (generationItem: any) => + generationItem.name === currentGenerationName && + generationItem.plugin === pluginId ); - if (evalConfig) { - setConfig(evalConfig.script_parameters); + if (generationConfig) { + setConfig(generationConfig.script_parameters); + const datasetKeyExists = Object.keys( - evalConfig.script_parameters + generationConfig.script_parameters ).some((key) => key.toLowerCase().includes('dataset')); + const docsKeyExists = Object.keys( - evalConfig.script_parameters + generationConfig.script_parameters ).some((key) => key.toLowerCase().includes('docs')); + const contextKeyExists = Object.keys( - evalConfig.script_parameters + generationConfig.script_parameters ).some((key) => key.toLowerCase().includes('context')); setHasDatasetKey(datasetKeyExists); - // setHasDocumentsKey(docsKeyExists); - // setHasContextKey(contextKeyExists); if ( docsKeyExists && - evalConfig.script_parameters.docs.length > 0 + generationConfig.script_parameters.docs.length > 0 ) { setHasContextKey(false); setHasDocumentsKey(true); + generationConfig.script_parameters.docs = generationConfig.script_parameters.docs.split(','); + setConfig(generationConfig.script_parameters); + setSelectedDocs(generationConfig.script_parameters.docs); - evalConfig.script_parameters.docs = evalConfig.script_parameters.docs.split(','); - - setConfig(evalConfig.script_parameters); } else if ( contextKeyExists && - evalConfig.script_parameters.context.length > 0 + generationConfig.script_parameters.context.length > 0 ) { setHasContextKey(true); setHasDocumentsKey(false); - const context = evalConfig.script_parameters.context; + const context = generationConfig.script_parameters.context; setContextInput(context); - delete evalConfig.script_parameters.context; - setConfig(evalConfig.script_parameters); + delete generationConfig.script_parameters.context; + setConfig(generationConfig.script_parameters); } + if ( - evalConfig.script_parameters._dataset_display_message && - evalConfig.script_parameters._dataset_display_message.length > - 0 + hasDatasetKey && + generationConfig.script_parameters.dataset_name.length > 0 ) { - setDatasetDisplayMessage( - evalConfig.script_parameters._dataset_display_message - ); + setSelectedDataset(generationConfig.script_parameters.dataset_name); } if ( - hasDatasetKey && - evalConfig.script_parameters.dataset_name.length > 0 + generationConfig.script_parameters._dataset_display_message && + generationConfig.script_parameters._dataset_display_message.length > + 0 ) { - setSelectedDataset(evalConfig.script_parameters.dataset_name); + setDatasetDisplayMessage( + generationConfig.script_parameters._dataset_display_message + ); } - if (!nameInput && evalConfig?.name.length > 0) { - setNameInput(evalConfig.name); + if (!nameInput && generationConfig?.name.length > 0) { + setNameInput(generationConfig.name); } } - // if (nameInput !== '' && evalConfig?.name) { - // setNameInput(evalConfig?.name); - // } - // setNameInput(evalConfig?.name); - // if (!nameInput && evalConfig?.script_parameters.run_name) { - // setNameInput(evalConfig.script_parameters.run_name); - // } } } catch (error) { - console.error('Failed to parse evaluations JSON string:', error); + console.error('Failed to parse generations JSON string:', error); } } } else { + // CREATE NEW GENERATION if (data) { let parsedData; try { parsedData = JSON.parse(data); //Parsing data for easy access to parameters} // Set config as a JSON object with keys of the parameters and values of the default values + setSelectedDocs([]); let tempconfig: { [key: string]: any } = {}; if (parsedData && parsedData.parameters) { tempconfig = Object.fromEntries( @@ -212,6 +208,7 @@ export default function GenerateModal({ value.default, ]) ); + // Logic to set dataset message if (parsedData && parsedData._dataset) { setHasDatasetKey(true); // Check if the dataset display message string length is greater than 0 @@ -224,41 +221,20 @@ export default function GenerateModal({ } } // Check if parsed data parameters has a key that includes 'docs' - if (parsedData && parsedData.parameters) { - const docsKeyExists = Object.keys(parsedData.parameters).some( - (key) => key.toLowerCase().includes('tflabcustomui_docs') - ); - if (docsKeyExists) { - // Delete the parameter key that includes 'docs' from the config - delete tempconfig[ - Object.keys(parsedData.parameters).find((key) => - key.toLowerCase().includes('tflabcustomui_docs') - ) - ]; - } - const contextKeyExists = Object.keys( - parsedData.parameters - ).some((key) => - key.toLowerCase().includes('tflabcustomui_context') - ); - if (contextKeyExists) { - // Delete the parameter key that includes 'context' from the config - delete tempconfig[ - Object.keys(parsedData.parameters).find((key) => - key.toLowerCase().includes('tflabcustomui_context') - ) - ]; - } - setHasContextKey(contextKeyExists); - setHasDocumentsKey(docsKeyExists); - } + const docsKeyExists = Object.keys(parsedData.parameters).some( + (key) => key.toLowerCase().includes('tflabcustomui_docs') + ); + + const contextKeyExists = Object.keys( + parsedData.parameters + ).some((key) => + key.toLowerCase().includes('tflabcustomui_context') + ); + setHasContextKey(contextKeyExists); + setHasDocumentsKey(docsKeyExists); } setConfig(tempconfig); - // Set hasDataset to true in the parsed data, the dataset key is `true` - // If tempconfig is not an empty object - // if (tempconfig && Object.keys(tempconfig).length > 0) { - // setNameInput(generateFriendlyName()); - // } + } catch (e) { console.error('Error parsing data', e); parsedData = ''; @@ -266,7 +242,7 @@ export default function GenerateModal({ } } } - }, [experimentInfo, pluginId, currentEvalName, nameInput, data]); + }, [experimentInfo, pluginId, currentGenerationName, nameInput, data]); if (!experimentInfo?.id) { return 'Select an Experiment'; @@ -276,7 +252,7 @@ export default function GenerateModal({ ? experimentInfo?.config?.foundation_filename : experimentInfo?.config?.foundation; - // Set config to the plugin config if it is available based on currentEvalName within experiment info + // Set config to the plugin config if it is available based on currentGenerationName within experiment info function TrainingModalFirstTab() { return ( @@ -328,7 +304,10 @@ export default function GenerateModal({ <PickADocumentMenu experimentInfo={experimentInfo} showFoldersOnly={false} - defaultValue={config.docs? config.docs : []} + value={selectedDocs} + onChange={setSelectedDocs} + defaultValue={config.docs ? config.docs : []} + // defaultValue={config.docs? config.docs : []} name="docs" /> <FormHelperText>Select documents to upload</FormHelperText> @@ -337,6 +316,7 @@ export default function GenerateModal({ ); } + function ContextTab({ contextInput, setContextInput }) { return ( <Stack spacing={2}> @@ -385,11 +365,11 @@ export default function GenerateModal({ console.log('formJson', formJson); - // Run when the currentEvalName is provided - if (currentEvalName && currentEvalName !== '') { + // Run when the currentGenerationName is provided + if (currentGenerationName && currentGenerationName !== '') { const result = await chatAPI.EXPERIMENT_EDIT_GENERATION( experimentInfo?.id, - currentEvalName, + currentGenerationName, formJson ); setNameInput(generateFriendlyName()); @@ -409,12 +389,6 @@ export default function GenerateModal({ experimentInfoMutate(); onClose(); - // }; - // } - // const result = await chatAPI.EXPERIMENT_EDIT_EVALUATION(experimentInfo?.id, currentEvalName, formJson) - // // alert(JSON.stringify(formJson, null, 2)); - // setNameInput(generateFriendlyName()); - // onClose(); } catch (error) { console.error('Failed to edit generation:', error); } @@ -434,7 +408,7 @@ export default function GenerateModal({ }} > <form - id="evaluation-form" + id="generation-form" style={{ display: 'flex', flexDirection: 'column', @@ -444,7 +418,7 @@ export default function GenerateModal({ onSubmit={handleSubmit} > <Tabs - aria-label="evaluation Template Tabs" + aria-label="generation Template Tabs" value={currentTab} onChange={(event, newValue) => setCurrentTab(newValue)} sx={{ borderRadius: 'lg', display: 'flex', overflow: 'hidden' }} diff --git a/src/renderer/components/Experiment/Generate/GenerateTasksTable.tsx b/src/renderer/components/Experiment/Generate/GenerateTasksTable.tsx index 77e217e2d554100b938c14bead51676bf09bdad2..64f6c9129b4803574d960c033b6dc1063420709f 100644 --- a/src/renderer/components/Experiment/Generate/GenerateTasksTable.tsx +++ b/src/renderer/components/Experiment/Generate/GenerateTasksTable.tsx @@ -4,10 +4,10 @@ import * as chatAPI from 'renderer/lib/transformerlab-api-sdk'; import { useState } from 'react'; import useSWR from 'swr'; -function listEvals(evalString) { +function listGenerations(generationString) { let result = []; - if (evalString) { - result = JSON.parse(evalString); + if (generationString) { + result = JSON.parse(generationString); } return result; } @@ -47,10 +47,10 @@ function formatTemplateConfig(script_parameters): ReactElement { ); } -async function evaluationRun( +async function generationRun( experimentId: string, plugin: string, - evaluator: string + generator: string ) { // fetch( // chatAPI.Endpoints.Experiment.RunGeneration(experimentId, plugin, evaluator) @@ -62,7 +62,7 @@ async function evaluationRun( 'QUEUED', JSON.stringify({ plugin: plugin, - generator: evaluator, + generator: generator, }) ) ); @@ -73,7 +73,7 @@ export default function GenerateTasksTable({ experimentInfo, experimentInfoMutate, setCurrentPlugin, - setCurrentEvalName, + setCurrentGenerationName, setOpen, }) { @@ -91,8 +91,8 @@ export default function GenerateTasksTable({ </tr> </thead> <tbody> - {listEvals(experimentInfo?.config?.generations) && - listEvals(experimentInfo?.config?.generations)?.map( + {listGenerations(experimentInfo?.config?.generations) && + listGenerations(experimentInfo?.config?.generations)?.map( (generations) => ( <tr key={generations.name}> <td style={{ overflow: 'hidden', paddingLeft: '1rem' }}> @@ -114,7 +114,7 @@ export default function GenerateTasksTable({ variant="soft" color="success" onClick={async () => - await evaluationRun( + await generationRun( experimentInfo.id, generations.plugin, generations.name @@ -128,7 +128,7 @@ export default function GenerateTasksTable({ onClick={() => { setOpen(true); setCurrentPlugin(generations?.plugin); - setCurrentEvalName(generations.name); + setCurrentGenerationName(generations.name); }} > Edit diff --git a/src/renderer/components/Experiment/Generate/ResultsModal.tsx b/src/renderer/components/Experiment/Generate/ResultsModal.tsx index fff5937d23b70028154e86250c9dcc68f5fc993a..da7c5d5ed6136c08508b0f30c923ae58a4d45122 100644 --- a/src/renderer/components/Experiment/Generate/ResultsModal.tsx +++ b/src/renderer/components/Experiment/Generate/ResultsModal.tsx @@ -13,18 +13,18 @@ export default function ResultsModal({ setOpen, experimentInfo, plugin, - evaluator, + generator, }) { const [resultText, setResultText] = useState(''); useEffect(() => { - if (open && experimentInfo && evaluator) { + if (open && experimentInfo && generator) { const output_file = `plugins/${plugin}/output.txt`; console.log('Fetching results from', output_file); fetch( chatAPI.Endpoints.Experiment.GetGenerationOutput( experimentInfo?.id, - evaluator + generator ) ).then((res) => { if (res.ok) { @@ -47,7 +47,7 @@ export default function ResultsModal({ }} > <ModalClose /> - <DialogTitle>Results from: {evaluator}</DialogTitle> + <DialogTitle>Results from: {generator}</DialogTitle> <DialogContent sx={{ backgroundColor: '#222', color: '#ddd', padding: 2 }} > diff --git a/src/renderer/components/Experiment/Rag/PickADocumentMenu.tsx b/src/renderer/components/Experiment/Rag/PickADocumentMenu.tsx index 7b88e1f433647a92b1341d6b7382470f8aa387e1..1bd7e0f353b34db8cb8652bf14124b0c247b713d 100644 --- a/src/renderer/components/Experiment/Rag/PickADocumentMenu.tsx +++ b/src/renderer/components/Experiment/Rag/PickADocumentMenu.tsx @@ -7,6 +7,8 @@ const fetcher = (url) => fetch(url).then((res) => res.json()); export default function PickADocumentMenu({ name, experimentInfo, + value, + onChange, defaultValue = [], showFoldersOnly = false, }) { @@ -16,20 +18,21 @@ export default function PickADocumentMenu({ mutate, } = useSWR(chatAPI.Endpoints.Documents.List(experimentInfo?.id, ''), fetcher); - const [selected, setSelected] = useState([]); useEffect(() => { - setSelected(defaultValue || []); + if (defaultValue.length > 0) { + onChange(defaultValue); + } }, [defaultValue]); function handleChange(event, newValue) { console.log(newValue); - setSelected(newValue); + onChange(newValue); } return ( - <Select multiple onChange={handleChange} value={selected} name={name}> + <Select multiple onChange={handleChange} value={value} name={name}> {rows?.map((row) => showFoldersOnly ? ( row?.type === 'folder' && (