diff --git a/src/renderer/components/Experiment/Train/TraningModalDataTab.tsx b/src/renderer/components/Experiment/Train/TraningModalDataTab.tsx index 1936d24691368e0e80c5444323708c3f97f17684..36febfc7a16484704a4bfd37d1984689776cde00 100644 --- a/src/renderer/components/Experiment/Train/TraningModalDataTab.tsx +++ b/src/renderer/components/Experiment/Train/TraningModalDataTab.tsx @@ -13,6 +13,7 @@ import { } from '@mui/joy'; import useSWR from 'swr'; import * as chatAPI from 'renderer/lib/transformerlab-api-sdk'; +import { parse } from 'path'; const fetcher = (url) => fetch(url).then((res) => res.json()); @@ -39,6 +40,107 @@ export default function TrainingModalDataTab({ fetcher ); + function renderTemplate(templateType: string) { + switch (templateType) { + case 'alpaca': + return ( + <> + <FormControl> + <FormLabel>Instruction</FormLabel> + <Textarea + required + name="instruction_template" + id="instruction" + defaultValue={ + templateData + ? templateData?.config?.instruction_template + : 'Instruction: {{instruction}}' + } + rows={5} + /> + <FormHelperText> + The instruction (usually the system message) to send to the + model. For example in a summarization task, this could be + "Summarize the following text:" + </FormHelperText> + </FormControl> + <br /> + <FormControl> + <FormLabel>Input</FormLabel> + <Textarea + required + name="input_template" + id="Input" + defaultValue={ + templateData + ? templateData?.config?.input_template + : '{{input}}' + } + rows={5} + /> + </FormControl> + <FormHelperText> + The input to send to the model. For example in a summarization + task, this could be the text to summarize. + </FormHelperText> + <br /> + <FormControl> + <FormLabel>Output</FormLabel> + <Textarea + required + name="output_template" + id="output" + defaultValue={ + templateData + ? templateData?.config?.output_template + : '{{output}}' + } + rows={5} + /> + <FormHelperText> + The output to expect from the model. For example in a + summarization task this could be the expected summary of the + input text. + </FormHelperText> + </FormControl> + </> + ); + case 'none': + return <> </>; + default: + return ( + <FormControl> + <textarea + required + name="formatting_template" + id="formatting_template" + defaultValue={ + templateData + ? templateData?.config?.formatting_template + : 'Instruction: {{instruction}}\nPrompt: {{prompt}}\nGeneration: {{generation}}' + } + rows={5} + /> + <FormHelperText + sx={{ flexDirection: 'column', alignItems: 'flex-start' }} + > + This describes how the data is formatted when passed to the + trainer. Use Jinja2 Standard String Templating format. For + example: + <br /> + <span style={{}}> + Summarize the following: + <br /> + Prompt: {{prompt}} + <br /> + Generation: {{generation}} + </span> + </FormHelperText> + </FormControl> + ); + } + } + const parsedData = data ? JSON.parse(data) : null; return ( <> @@ -66,132 +168,45 @@ export default function TrainingModalDataTab({ <br /> {selectedDataset && ( <> - <FormControl> - <FormLabel>Available Fields</FormLabel> + {parsedData?.training_template_format !== 'none' && ( + <> + <FormControl> + <FormLabel>Available Fields</FormLabel> - <Box sx={{ display: 'flex', gap: '4px', flexWrap: 'wrap' }}> - {currentDatasetInfoIsLoading && <CircularProgress />} - {/* // For each key in the currentDatasetInfo.features object, + <Box sx={{ display: 'flex', gap: '4px', flexWrap: 'wrap' }}> + {currentDatasetInfoIsLoading && <CircularProgress />} + {/* // For each key in the currentDatasetInfo.features object, display it: */} - {currentDatasetInfo?.features && - Object.keys(currentDatasetInfo?.features).map((key) => ( - <> - <Chip - onClick={() => { - injectIntoTemplate(key); - }} - > - {key} - </Chip> - - </> - ))} - </Box> + {currentDatasetInfo?.features && + Object.keys(currentDatasetInfo?.features).map((key) => ( + <> + <Chip + onClick={() => { + injectIntoTemplate(key); + }} + > + {key} + </Chip> + + </> + ))} + </Box> - {selectedDataset && ( - <FormHelperText> - Use the field names above, surrounded by - {{}} in the template below - </FormHelperText> - )} - </FormControl> - <Divider sx={{ mt: '1rem', mb: '2rem' }} /> - <Typography level="title-sm" pb={2}> - Template - </Typography> - {parsedData?.training_template_format == 'alpaca' ? ( - <> - <FormControl> - <FormLabel>Instruction</FormLabel> - <Textarea - required - name="instruction_template" - id="instruction" - defaultValue={ - templateData - ? templateData?.config?.instruction_template - : 'Instruction: {{instruction}}' - } - rows={5} - /> - <FormHelperText> - The instruction (usually the system message) to send to the - model. For example in a summarization task, this could be - "Summarize the following text:" - </FormHelperText> - </FormControl> - <br /> - <FormControl> - <FormLabel>Input</FormLabel> - <Textarea - required - name="input_template" - id="Input" - defaultValue={ - templateData - ? templateData?.config?.input_template - : '{{input}}' - } - rows={5} - /> - </FormControl> - <FormHelperText> - The input to send to the model. For example in a summarization - task, this could be the text to summarize. - </FormHelperText> - <br /> - <FormControl> - <FormLabel>Output</FormLabel> - <Textarea - required - name="output_template" - id="output" - defaultValue={ - templateData - ? templateData?.config?.output_template - : '{{output}}' - } - rows={5} - /> - <FormHelperText> - The output to expect from the model. For example in a - summarization task this could be the expected summary of the - input text. - </FormHelperText> - </FormControl> - </> - ) : ( - <> - <FormControl> - <textarea - required - name="formatting_template" - id="formatting_template" - defaultValue={ - templateData - ? templateData?.config?.formatting_template - : 'Instruction: {{instruction}}\nPrompt: {{prompt}}\nGeneration: {{generation}}' - } - rows={5} - /> - <FormHelperText - sx={{ flexDirection: 'column', alignItems: 'flex-start' }} - > - This describes how the data is formatted when passed to the - trainer. Use Jinja2 Standard String Templating format. For - example: - <br /> - <span style={{}}> - Summarize the following: - <br /> - Prompt: {{prompt}} - <br /> - Generation: {{generation}} - </span> - </FormHelperText> + {selectedDataset && ( + <FormHelperText> + Use the field names above, surrounded by + {{}} in the template below + </FormHelperText> + )} </FormControl> + <Divider sx={{ mt: '1rem', mb: '2rem' }} /> + <Typography level="title-sm" pb={2}> + Template + </Typography> </> )} + + {renderTemplate(parsedData?.training_template_format)} </> )} </>