From f9127eb9f79592b0ad9e134935cbddfe7a7ee3bd Mon Sep 17 00:00:00 2001
From: Ali Asaria <ali.asaria@gmail.com>
Date: Tue, 7 May 2024 11:17:31 -0400
Subject: [PATCH] allow creation of custom prompts

---
 .../Interact/TemplatedCompletion.tsx          | 112 +++++++++++++-----
 .../Interact/TemplatedPromptModal.tsx         |  93 +++++++++++++++
 src/renderer/lib/transformerlab-api-sdk.ts    |   4 +-
 3 files changed, 179 insertions(+), 30 deletions(-)
 create mode 100644 src/renderer/components/Experiment/Interact/TemplatedPromptModal.tsx

diff --git a/src/renderer/components/Experiment/Interact/TemplatedCompletion.tsx b/src/renderer/components/Experiment/Interact/TemplatedCompletion.tsx
index 7c999394..8a41131d 100644
--- a/src/renderer/components/Experiment/Interact/TemplatedCompletion.tsx
+++ b/src/renderer/components/Experiment/Interact/TemplatedCompletion.tsx
@@ -16,7 +16,7 @@ import {
   TabPanel,
   LinearProgress,
 } from '@mui/joy';
-import { SendIcon, PlusCircleIcon } from 'lucide-react';
+import { SendIcon, PlusCircleIcon, X, XIcon } from 'lucide-react';
 import { useState } from 'react';
 
 import Markdown from 'react-markdown';
@@ -25,18 +25,23 @@ import remarkGfm from 'remark-gfm';
 import useSWR from 'swr';
 
 import * as chatAPI from 'renderer/lib/transformerlab-api-sdk';
+import TemplatedPromptModal from './TemplatedPromptModal';
 
 const fetcher = (url) => fetch(url).then((res) => res.json());
 
 export default function TemplatedCompletion({ experimentInfo }) {
-  const [selectedTemplate, setSelectedTemplate] = useState(null);
+  const [selectedTemplate, setSelectedTemplate] = useState<any | null>(null);
   const [showTemplate, setShowTemplate] = useState(false);
   const [isThinking, setIsThinking] = useState(false);
   const [timeTaken, setTimeTaken] = useState<number | null>(null);
   const [outputText, setOutputText] = useState('');
   const [currentTab, setCurrentTab] = useState(0);
+  const [editTemplateModalOpen, setEditTemplateModalOpen] = useState(false);
 
-  const { data: templates } = useSWR(chatAPI.Endpoints.Prompts.List(), fetcher);
+  const { data: templates, mutate: templatesMutate } = useSWR(
+    chatAPI.Endpoints.Prompts.List(),
+    fetcher
+  );
 
   const sendTemplatedCompletionToLLM = async (element, target) => {
     if (!selectedTemplate) {
@@ -45,7 +50,7 @@ export default function TemplatedCompletion({ experimentInfo }) {
 
     const text = element.value;
 
-    const template = templates.find((t) => t.id === selectedTemplate);
+    const template = selectedTemplate;
 
     if (!template) {
       alert('Template not found');
@@ -108,6 +113,11 @@ export default function TemplatedCompletion({ experimentInfo }) {
         paddingTop: '1rem',
       }}
     >
+      <TemplatedPromptModal
+        open={editTemplateModalOpen}
+        setOpen={setEditTemplateModalOpen}
+        mutate={templatesMutate}
+      />
       <div>
         {/* {JSON.stringify(templates)} */}
         <FormLabel>Prompt Template:</FormLabel>
@@ -115,14 +125,17 @@ export default function TemplatedCompletion({ experimentInfo }) {
           placeholder="Select Template"
           variant="soft"
           name="template"
-          value={selectedTemplate}
+          value={selectedTemplate?.id}
           onChange={(e, newValue) => {
             if (newValue === 'custom') {
               setSelectedTemplate(null);
-              alert('Custom template creation not implemented yet');
+              setEditTemplateModalOpen(true);
               return;
             }
-            setSelectedTemplate(newValue);
+            const newSelectedTemplate = templates?.find(
+              (t) => t.id === newValue
+            );
+            setSelectedTemplate(newSelectedTemplate);
           }}
           renderValue={(selected) => {
             const value = selected?.value;
@@ -135,7 +148,12 @@ export default function TemplatedCompletion({ experimentInfo }) {
         >
           {templates?.map((template) => (
             <Option key={template.id} value={template.id}>
-              <Chip color="warning">gallery</Chip>
+              {template?.source !== 'local' && (
+                <Chip color="warning">gallery</Chip>
+              )}
+              {template?.source == 'local' && (
+                <Chip color="success">local</Chip>
+              )}
               {template.title}
             </Option>
           ))}
@@ -146,19 +164,53 @@ export default function TemplatedCompletion({ experimentInfo }) {
       </div>
       {selectedTemplate && (
         <>
-          <Typography
-            level="body-xs"
-            onClick={() => {
-              setShowTemplate(!showTemplate);
-            }}
+          <Stack
+            direction="row"
             sx={{
-              cursor: 'pointer',
-              color: 'primary',
-              textAlign: 'right',
+              justifyContent: 'flex-end',
+              gap: '1rem',
             }}
           >
-            {showTemplate ? 'Hide Template' : 'Show Template'}
-          </Typography>
+            <Typography
+              level="body-xs"
+              onClick={() => {
+                setShowTemplate(!showTemplate);
+              }}
+              sx={{
+                cursor: 'pointer',
+                color: 'primary',
+                textAlign: 'right',
+              }}
+            >
+              {showTemplate ? 'Hide' : 'Show'}
+            </Typography>
+            {selectedTemplate?.source == 'local' && (
+              <Typography
+                color="warning"
+                level="body-xs"
+                onClick={async () => {
+                  if (!selectedTemplate) {
+                    return;
+                  }
+                  if (
+                    confirm('Are you sure you want to delete this template?')
+                  ) {
+                    await fetch(
+                      chatAPI.Endpoints.Prompts.Delete(selectedTemplate.id)
+                    );
+                    templatesMutate();
+                  }
+                }}
+                sx={{
+                  cursor: 'pointer',
+                  color: 'primary',
+                  textAlign: 'right',
+                }}
+              >
+                Delete
+              </Typography>
+            )}
+          </Stack>
           {showTemplate && (
             <>
               <Sheet
@@ -179,9 +231,7 @@ export default function TemplatedCompletion({ experimentInfo }) {
                       fontFamily: 'var(--joy-fontFamily-code)',
                     }}
                   >
-                    {selectedTemplate
-                      ? templates?.find((t) => t.id === selectedTemplate)?.text
-                      : ''}
+                    {selectedTemplate ? selectedTemplate?.text : ''}
                   </pre>
                 </Typography>
               </Sheet>
@@ -295,10 +345,12 @@ export default function TemplatedCompletion({ experimentInfo }) {
               </TabList>
               <TabPanel value={0} keepMounted>
                 <Box
-                  sx={{
-                    paddingLeft: 2,
-                    borderLeft: '2px solid var(--joy-palette-neutral-500)',
-                  }}
+                  sx={
+                    {
+                      // paddingLeft: 2,
+                      // borderLeft: '2px solid var(--joy-palette-neutral-500)',
+                    }
+                  }
                 >
                   <Textarea name="output-text" variant="plain"></Textarea>
                   {isThinking && <LinearProgress sx={{ width: '300px' }} />}
@@ -306,10 +358,12 @@ export default function TemplatedCompletion({ experimentInfo }) {
               </TabPanel>
               <TabPanel value={1} keepMounted>
                 <Box
-                  sx={{
-                    paddingLeft: 2,
-                    borderLeft: '2px solid var(--joy-palette-neutral-500)',
-                  }}
+                  sx={
+                    {
+                      // paddingLeft: 2,
+                      // borderLeft: '2px solid var(--joy-palette-neutral-500)',
+                    }
+                  }
                 >
                   {isThinking && <LinearProgress sx={{ width: '300px' }} />}
                   <Markdown
diff --git a/src/renderer/components/Experiment/Interact/TemplatedPromptModal.tsx b/src/renderer/components/Experiment/Interact/TemplatedPromptModal.tsx
new file mode 100644
index 00000000..8857e1f1
--- /dev/null
+++ b/src/renderer/components/Experiment/Interact/TemplatedPromptModal.tsx
@@ -0,0 +1,93 @@
+import {
+  Button,
+  DialogContent,
+  DialogTitle,
+  FormControl,
+  FormHelperText,
+  FormLabel,
+  Input,
+  Modal,
+  ModalClose,
+  ModalDialog,
+  Stack,
+  Textarea,
+} from '@mui/joy';
+import React, { useState } from 'react';
+
+import * as chatAPI from 'renderer/lib/transformerlab-api-sdk';
+
+export default function TemplatedPromptModal({ open, setOpen, mutate }) {
+  return (
+    <Modal open={open}>
+      <ModalDialog sx={{ minWidth: '500px' }}>
+        <DialogTitle>Create New Prompt</DialogTitle>
+        <ModalClose
+          onClick={() => {
+            setOpen(false);
+          }}
+        />
+        {/* <DialogContent>Fill in the information of the project.</DialogContent> */}
+        <form
+          onSubmit={async (event: React.FormEvent<HTMLFormElement>) => {
+            event.preventDefault();
+
+            const formData = new FormData(event.currentTarget);
+            const promptName = formData.get('name') as string;
+            const template = formData.get('template') as string;
+
+            const response = await fetch(chatAPI.Endpoints.Prompts.New(), {
+              method: 'POST',
+              headers: {
+                'Content-Type': 'application/json',
+              },
+              body: JSON.stringify({
+                title: promptName,
+                text: template,
+              }),
+            });
+
+            const responseJSON = await response.json();
+
+            if (responseJSON?.status == 'error') {
+              alert(responseJSON?.message);
+              return;
+            }
+
+            mutate();
+
+            setOpen(false);
+          }}
+        >
+          <Stack spacing={2}>
+            <FormControl>
+              <FormLabel>Name</FormLabel>
+              <Input
+                name="name"
+                autoFocus
+                required
+                placeholder="My New Prompt"
+              />
+            </FormControl>
+            <FormControl>
+              <FormLabel>Template</FormLabel>
+              <Textarea
+                name="template"
+                required
+                minRows={4}
+                placeholder="Summarize the following sentence:
+{text}
+Answer:
+"
+              />
+              <FormHelperText>
+                Use &#123;text&#125; as a placeholder for the place where the
+                provided text will be inserted
+              </FormHelperText>
+            </FormControl>
+            <Button type="submit">Submit</Button>
+          </Stack>
+        </form>
+      </ModalDialog>
+    </Modal>
+  );
+}
diff --git a/src/renderer/lib/transformerlab-api-sdk.ts b/src/renderer/lib/transformerlab-api-sdk.ts
index 271cf47a..b66dbf19 100644
--- a/src/renderer/lib/transformerlab-api-sdk.ts
+++ b/src/renderer/lib/transformerlab-api-sdk.ts
@@ -590,7 +590,7 @@ Endpoints.Models = {
     API_URL() + 'model/get_local_hfconfig?model_id=' + modelId,
   GetHFCacheModelList: (uninstalled_only: boolean = true) =>
     API_URL() + 'model/hfcache_list?uninstalled_only=' + uninstalled_only,
-  ImportFromHFCache: (modelId: string) => 
+  ImportFromHFCache: (modelId: string) =>
     API_URL() + 'model/hfcache_import?model_id=' + modelId,
   HuggingFaceLogin: () => API_URL() + 'model/login_to_huggingface',
   Delete: (modelId: string) => API_URL() + 'model/delete?model_id=' + modelId,
@@ -642,6 +642,8 @@ Endpoints.Rag = {
 
 Endpoints.Prompts = {
   List: () => API_URL() + 'prompts/list',
+  New: () => API_URL() + 'prompts/new',
+  Delete: (promptId: string) => API_URL() + 'prompts/delete/' + promptId,
 };
 
 export function GET_TRAINING_TEMPLATE_URL() {
-- 
GitLab