From f9127eb9f79592b0ad9e134935cbddfe7a7ee3bd Mon Sep 17 00:00:00 2001 From: Ali Asaria <ali.asaria@gmail.com> Date: Tue, 7 May 2024 11:17:31 -0400 Subject: [PATCH] allow creation of custom prompts --- .../Interact/TemplatedCompletion.tsx | 112 +++++++++++++----- .../Interact/TemplatedPromptModal.tsx | 93 +++++++++++++++ src/renderer/lib/transformerlab-api-sdk.ts | 4 +- 3 files changed, 179 insertions(+), 30 deletions(-) create mode 100644 src/renderer/components/Experiment/Interact/TemplatedPromptModal.tsx diff --git a/src/renderer/components/Experiment/Interact/TemplatedCompletion.tsx b/src/renderer/components/Experiment/Interact/TemplatedCompletion.tsx index 7c999394..8a41131d 100644 --- a/src/renderer/components/Experiment/Interact/TemplatedCompletion.tsx +++ b/src/renderer/components/Experiment/Interact/TemplatedCompletion.tsx @@ -16,7 +16,7 @@ import { TabPanel, LinearProgress, } from '@mui/joy'; -import { SendIcon, PlusCircleIcon } from 'lucide-react'; +import { SendIcon, PlusCircleIcon, X, XIcon } from 'lucide-react'; import { useState } from 'react'; import Markdown from 'react-markdown'; @@ -25,18 +25,23 @@ import remarkGfm from 'remark-gfm'; import useSWR from 'swr'; import * as chatAPI from 'renderer/lib/transformerlab-api-sdk'; +import TemplatedPromptModal from './TemplatedPromptModal'; const fetcher = (url) => fetch(url).then((res) => res.json()); export default function TemplatedCompletion({ experimentInfo }) { - const [selectedTemplate, setSelectedTemplate] = useState(null); + const [selectedTemplate, setSelectedTemplate] = useState<any | null>(null); const [showTemplate, setShowTemplate] = useState(false); const [isThinking, setIsThinking] = useState(false); const [timeTaken, setTimeTaken] = useState<number | null>(null); const [outputText, setOutputText] = useState(''); const [currentTab, setCurrentTab] = useState(0); + const [editTemplateModalOpen, setEditTemplateModalOpen] = useState(false); - const { data: templates } = useSWR(chatAPI.Endpoints.Prompts.List(), fetcher); + const { data: templates, mutate: templatesMutate } = useSWR( + chatAPI.Endpoints.Prompts.List(), + fetcher + ); const sendTemplatedCompletionToLLM = async (element, target) => { if (!selectedTemplate) { @@ -45,7 +50,7 @@ export default function TemplatedCompletion({ experimentInfo }) { const text = element.value; - const template = templates.find((t) => t.id === selectedTemplate); + const template = selectedTemplate; if (!template) { alert('Template not found'); @@ -108,6 +113,11 @@ export default function TemplatedCompletion({ experimentInfo }) { paddingTop: '1rem', }} > + <TemplatedPromptModal + open={editTemplateModalOpen} + setOpen={setEditTemplateModalOpen} + mutate={templatesMutate} + /> <div> {/* {JSON.stringify(templates)} */} <FormLabel>Prompt Template:</FormLabel> @@ -115,14 +125,17 @@ export default function TemplatedCompletion({ experimentInfo }) { placeholder="Select Template" variant="soft" name="template" - value={selectedTemplate} + value={selectedTemplate?.id} onChange={(e, newValue) => { if (newValue === 'custom') { setSelectedTemplate(null); - alert('Custom template creation not implemented yet'); + setEditTemplateModalOpen(true); return; } - setSelectedTemplate(newValue); + const newSelectedTemplate = templates?.find( + (t) => t.id === newValue + ); + setSelectedTemplate(newSelectedTemplate); }} renderValue={(selected) => { const value = selected?.value; @@ -135,7 +148,12 @@ export default function TemplatedCompletion({ experimentInfo }) { > {templates?.map((template) => ( <Option key={template.id} value={template.id}> - <Chip color="warning">gallery</Chip> + {template?.source !== 'local' && ( + <Chip color="warning">gallery</Chip> + )} + {template?.source == 'local' && ( + <Chip color="success">local</Chip> + )} {template.title} </Option> ))} @@ -146,19 +164,53 @@ export default function TemplatedCompletion({ experimentInfo }) { </div> {selectedTemplate && ( <> - <Typography - level="body-xs" - onClick={() => { - setShowTemplate(!showTemplate); - }} + <Stack + direction="row" sx={{ - cursor: 'pointer', - color: 'primary', - textAlign: 'right', + justifyContent: 'flex-end', + gap: '1rem', }} > - {showTemplate ? 'Hide Template' : 'Show Template'} - </Typography> + <Typography + level="body-xs" + onClick={() => { + setShowTemplate(!showTemplate); + }} + sx={{ + cursor: 'pointer', + color: 'primary', + textAlign: 'right', + }} + > + {showTemplate ? 'Hide' : 'Show'} + </Typography> + {selectedTemplate?.source == 'local' && ( + <Typography + color="warning" + level="body-xs" + onClick={async () => { + if (!selectedTemplate) { + return; + } + if ( + confirm('Are you sure you want to delete this template?') + ) { + await fetch( + chatAPI.Endpoints.Prompts.Delete(selectedTemplate.id) + ); + templatesMutate(); + } + }} + sx={{ + cursor: 'pointer', + color: 'primary', + textAlign: 'right', + }} + > + Delete + </Typography> + )} + </Stack> {showTemplate && ( <> <Sheet @@ -179,9 +231,7 @@ export default function TemplatedCompletion({ experimentInfo }) { fontFamily: 'var(--joy-fontFamily-code)', }} > - {selectedTemplate - ? templates?.find((t) => t.id === selectedTemplate)?.text - : ''} + {selectedTemplate ? selectedTemplate?.text : ''} </pre> </Typography> </Sheet> @@ -295,10 +345,12 @@ export default function TemplatedCompletion({ experimentInfo }) { </TabList> <TabPanel value={0} keepMounted> <Box - sx={{ - paddingLeft: 2, - borderLeft: '2px solid var(--joy-palette-neutral-500)', - }} + sx={ + { + // paddingLeft: 2, + // borderLeft: '2px solid var(--joy-palette-neutral-500)', + } + } > <Textarea name="output-text" variant="plain"></Textarea> {isThinking && <LinearProgress sx={{ width: '300px' }} />} @@ -306,10 +358,12 @@ export default function TemplatedCompletion({ experimentInfo }) { </TabPanel> <TabPanel value={1} keepMounted> <Box - sx={{ - paddingLeft: 2, - borderLeft: '2px solid var(--joy-palette-neutral-500)', - }} + sx={ + { + // paddingLeft: 2, + // borderLeft: '2px solid var(--joy-palette-neutral-500)', + } + } > {isThinking && <LinearProgress sx={{ width: '300px' }} />} <Markdown diff --git a/src/renderer/components/Experiment/Interact/TemplatedPromptModal.tsx b/src/renderer/components/Experiment/Interact/TemplatedPromptModal.tsx new file mode 100644 index 00000000..8857e1f1 --- /dev/null +++ b/src/renderer/components/Experiment/Interact/TemplatedPromptModal.tsx @@ -0,0 +1,93 @@ +import { + Button, + DialogContent, + DialogTitle, + FormControl, + FormHelperText, + FormLabel, + Input, + Modal, + ModalClose, + ModalDialog, + Stack, + Textarea, +} from '@mui/joy'; +import React, { useState } from 'react'; + +import * as chatAPI from 'renderer/lib/transformerlab-api-sdk'; + +export default function TemplatedPromptModal({ open, setOpen, mutate }) { + return ( + <Modal open={open}> + <ModalDialog sx={{ minWidth: '500px' }}> + <DialogTitle>Create New Prompt</DialogTitle> + <ModalClose + onClick={() => { + setOpen(false); + }} + /> + {/* <DialogContent>Fill in the information of the project.</DialogContent> */} + <form + onSubmit={async (event: React.FormEvent<HTMLFormElement>) => { + event.preventDefault(); + + const formData = new FormData(event.currentTarget); + const promptName = formData.get('name') as string; + const template = formData.get('template') as string; + + const response = await fetch(chatAPI.Endpoints.Prompts.New(), { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ + title: promptName, + text: template, + }), + }); + + const responseJSON = await response.json(); + + if (responseJSON?.status == 'error') { + alert(responseJSON?.message); + return; + } + + mutate(); + + setOpen(false); + }} + > + <Stack spacing={2}> + <FormControl> + <FormLabel>Name</FormLabel> + <Input + name="name" + autoFocus + required + placeholder="My New Prompt" + /> + </FormControl> + <FormControl> + <FormLabel>Template</FormLabel> + <Textarea + name="template" + required + minRows={4} + placeholder="Summarize the following sentence: +{text} +Answer: +" + /> + <FormHelperText> + Use {text} as a placeholder for the place where the + provided text will be inserted + </FormHelperText> + </FormControl> + <Button type="submit">Submit</Button> + </Stack> + </form> + </ModalDialog> + </Modal> + ); +} diff --git a/src/renderer/lib/transformerlab-api-sdk.ts b/src/renderer/lib/transformerlab-api-sdk.ts index 271cf47a..b66dbf19 100644 --- a/src/renderer/lib/transformerlab-api-sdk.ts +++ b/src/renderer/lib/transformerlab-api-sdk.ts @@ -590,7 +590,7 @@ Endpoints.Models = { API_URL() + 'model/get_local_hfconfig?model_id=' + modelId, GetHFCacheModelList: (uninstalled_only: boolean = true) => API_URL() + 'model/hfcache_list?uninstalled_only=' + uninstalled_only, - ImportFromHFCache: (modelId: string) => + ImportFromHFCache: (modelId: string) => API_URL() + 'model/hfcache_import?model_id=' + modelId, HuggingFaceLogin: () => API_URL() + 'model/login_to_huggingface', Delete: (modelId: string) => API_URL() + 'model/delete?model_id=' + modelId, @@ -642,6 +642,8 @@ Endpoints.Rag = { Endpoints.Prompts = { List: () => API_URL() + 'prompts/list', + New: () => API_URL() + 'prompts/new', + Delete: (promptId: string) => API_URL() + 'prompts/delete/' + promptId, }; export function GET_TRAINING_TEMPLATE_URL() { -- GitLab