From 83df3d9e65cf3a0133c9fc7f0aef984038cf09db Mon Sep 17 00:00:00 2001
From: deep1401 <gandhi0869@gmail.com>
Date: Mon, 24 Feb 2025 14:49:10 -0800
Subject: [PATCH] Changes for providing model providers in settings

---
 .../components/AIProvidersSettings.tsx        | 298 ++++++++++++++++++
 .../components/TransformerLabSettings.tsx     |  73 +++--
 src/renderer/lib/transformerlab-api-sdk.ts    |   3 +
 3 files changed, 343 insertions(+), 31 deletions(-)
 create mode 100644 src/renderer/components/AIProvidersSettings.tsx

diff --git a/src/renderer/components/AIProvidersSettings.tsx b/src/renderer/components/AIProvidersSettings.tsx
new file mode 100644
index 00000000..5961ff84
--- /dev/null
+++ b/src/renderer/components/AIProvidersSettings.tsx
@@ -0,0 +1,298 @@
+import * as React from 'react';
+import {
+    Button,
+    Modal,
+    ModalDialog,
+    FormControl,
+    FormLabel,
+    Input,
+    List,
+    ListItem,
+    Box,
+    Typography
+} from '@mui/joy';
+import * as chatAPI from 'renderer/lib/transformerlab-api-sdk';
+import useSWR from 'swr';
+
+const fetcher = (url: string) => fetch(url).then((res) => res.json());
+
+interface Provider {
+    name: string;
+    keyName: string;
+    setKeyEndpoint: () => string;
+    checkKeyEndpoint: () => string;
+}
+
+const providers: Provider[] = [
+    {
+        name: 'OpenAI',
+        keyName: 'OPENAI_API_KEY',
+        setKeyEndpoint: () => chatAPI.Endpoints.Models.SetOpenAIKey(),
+        checkKeyEndpoint: () => chatAPI.Endpoints.Models.CheckOpenAIAPIKey(),
+    },
+    {
+        name: 'Claude',
+        keyName: 'CLAUDE_API_KEY',
+        setKeyEndpoint: () => chatAPI.Endpoints.Models.SetClaudeKey(),
+        checkKeyEndpoint: () => chatAPI.Endpoints.Models.CheckClaudeAPIKey(),
+    },
+    {
+        name: 'Custom API',
+        keyName: 'CUSTOM_MODEL_API_KEY',
+        setKeyEndpoint: () => chatAPI.Endpoints.Models.SetCustomAPIKey(),
+        checkKeyEndpoint: () => chatAPI.Endpoints.Models.CheckCustomAPIKey(),
+    }
+];
+
+interface AIProvidersSettingsProps {
+    onBack?: () => void;
+}
+
+export default function AIProvidersSettings({ onBack }: AIProvidersSettingsProps) {
+    const { data: openaiApiKey, mutate: mutateOpenAI } = useSWR(
+        chatAPI.Endpoints.Config.Get('OPENAI_API_KEY'),
+        fetcher
+    );
+    const { data: claudeApiKey, mutate: mutateClaude } = useSWR(
+        chatAPI.Endpoints.Config.Get('CLAUDE_API_KEY'),
+        fetcher
+    );
+    const { data: customAPIStatus, mutate: mutateCustom } = useSWR(
+        chatAPI.Endpoints.Config.Get('CUSTOM_MODEL_API_KEY'),
+        fetcher
+    );
+
+    const getProviderStatus = (provider: Provider) => {
+        if (provider.name === 'OpenAI') return openaiApiKey;
+        if (provider.name === 'Claude') return claudeApiKey;
+        if (provider.name === 'Custom API') return customAPIStatus;
+        return null;
+    };
+
+    const setProviderStatus = async (provider: Provider, token: string) => {
+        await fetch(chatAPI.Endpoints.Config.Set(provider.keyName, token));
+        await fetch(provider.setKeyEndpoint());
+        const response = await fetch(provider.checkKeyEndpoint());
+        const result = await response.json();
+        return result.message === 'OK';
+    };
+
+    const [dialogOpen, setDialogOpen] = React.useState(false);
+    const [selectedProvider, setSelectedProvider] = React.useState<Provider | null>(null);
+    const [apiKey, setApiKey] = React.useState('');
+    const [hoveredProvider, setHoveredProvider] = React.useState<string | null>(null);
+    // States for custom API additional fields
+    const [customApiName, setCustomApiName] = React.useState('');
+    const [customBaseURL, setCustomBaseURL] = React.useState('');
+    const [customApiKey, setCustomApiKey] = React.useState('');
+    const [customModelName, setCustomModelName] = React.useState('');
+
+    const handleConnectClick = (provider: Provider) => {
+        setSelectedProvider(provider);
+        if (provider.name === 'Custom API') {
+            setCustomApiName('');
+            setCustomBaseURL('');
+            setCustomApiKey('');
+            setCustomModelName('');
+        } else {
+            setApiKey('');
+        }
+        setDialogOpen(true);
+    };
+
+    const handleDisconnectClick = async (provider: Provider) => {
+        await fetch(chatAPI.Endpoints.Config.Set(provider.keyName, ''));
+        if (provider.name === 'OpenAI') {
+            mutateOpenAI();
+        } else if (provider.name === 'Claude') {
+            mutateClaude();
+        } else if (provider.name === 'Custom API') {
+            mutateCustom();
+        }
+    };
+
+    const handleSave = async () => {
+        if (selectedProvider) {
+            let token = '';
+            if (selectedProvider.name === 'Custom API') {
+                const customObj = {
+                    apiName: customApiName,
+                    baseURL: customBaseURL,
+                    apiKey: customApiKey,
+                    modelName: customModelName
+                };
+                token = JSON.stringify(customObj);
+            } else if (apiKey) {
+                token = apiKey;
+            }
+            if (token) {
+                const success = await setProviderStatus(selectedProvider, token);
+                if (success) {
+                    alert(`Successfully connected to ${selectedProvider.name}`);
+                    if (selectedProvider.name === 'OpenAI') {
+                        mutateOpenAI();
+                    } else if (selectedProvider.name === 'Claude') {
+                        mutateClaude();
+                    } else if (selectedProvider.name === 'Custom API') {
+                        mutateCustom();
+                    }
+                } else {
+                    alert(`Failed to connect to ${selectedProvider.name}`);
+                }
+            }
+        }
+        setDialogOpen(false);
+    };
+
+    return (
+        <Box sx={{ p: 2 }}>
+            <Button onClick={onBack}>Back to Settings</Button>
+            <Typography level="h3" mb={2}>
+                AI Providers
+            </Typography>
+            <List sx={{ gap: 1 }}>
+                {providers.map((provider) => {
+                    const status = getProviderStatus(provider);
+                    const isConnected = status && status !== '';
+                    return (
+                        <ListItem
+                            key={provider.name}
+                            sx={{
+                                display: 'flex',
+                                justifyContent: 'space-between',
+                                alignItems: 'center',
+                                p: 1,
+                                borderRadius: '8px',
+                                bgcolor: 'neutral.softBg'
+                            }}
+                        >
+                            <Typography>{provider.name}</Typography>
+                            {isConnected ? (
+                                <Box
+                                    onMouseEnter={() => setHoveredProvider(provider.name)}
+                                    onMouseLeave={() => setHoveredProvider(null)}
+                                    onClick={() => handleDisconnectClick(provider)}
+                                    sx={{
+                                        cursor: 'pointer',
+                                        border: '1px solid',
+                                        borderColor: 'neutral.outlinedBorder',
+                                        borderRadius: '4px',
+                                        px: 1,
+                                        py: 0.5,
+                                        fontSize: '0.875rem',
+                                        color: 'success.600'
+                                    }}
+                                >
+                                    {hoveredProvider === provider.name ? 'Disconnect' : 'Set up!'}
+                                </Box>
+                            ) : (
+                                <Button variant="soft" onClick={() => handleConnectClick(provider)}>
+                                    Connect
+                                </Button>
+                            )}
+                        </ListItem>
+                    );
+                })}
+            </List>
+            {/* API Key Modal */}
+            <Modal open={dialogOpen} onClose={() => setDialogOpen(false)}>
+                <ModalDialog
+                    layout="stack"
+                    aria-labelledby="connect-dialog-title"
+                    sx={{
+                        position: 'absolute',
+                        top: '50%',
+                        left: '50%',
+                        transform: 'translate(-50%, -50%)',
+                        maxWidth: 400,
+                        width: '90%'
+                    }}
+                >
+                    <Typography id="connect-dialog-title" component="h2">
+                        Connect to {selectedProvider?.name}
+                    </Typography>
+                    {selectedProvider?.name === 'Custom API' ? (
+                        <>
+                            <FormControl sx={{ mt: 2 }}>
+                                <FormLabel>API Name</FormLabel>
+                                <Input
+                                    value={customApiName}
+                                    onChange={(e) => setCustomApiName(e.target.value)}
+                                />
+                            </FormControl>
+                            <FormControl sx={{ mt: 2 }}>
+                                <FormLabel>Base URL</FormLabel>
+                                <Input
+                                    value={customBaseURL}
+                                    onChange={(e) => setCustomBaseURL(e.target.value)}
+                                />
+                            </FormControl>
+                            <FormControl sx={{ mt: 2 }}>
+                                <FormLabel>API Key</FormLabel>
+                                <Input
+                                    type="password"
+                                    value={customApiKey}
+                                    onChange={(e) => setCustomApiKey(e.target.value)}
+                                />
+                            </FormControl>
+                            <FormControl sx={{ mt: 2 }}>
+                                <FormLabel>Model Name</FormLabel>
+                                <Input
+                                    value={customModelName}
+                                    onChange={(e) => setCustomModelName(e.target.value)}
+                                />
+                            </FormControl>
+                        </>
+                    ) : (
+                        <FormControl sx={{ mt: 2 }}>
+                            <FormLabel>{selectedProvider?.name} API Key</FormLabel>
+                            <Input
+                                type="password"
+                                value={apiKey}
+                                onChange={(e) => setApiKey(e.target.value)}
+                            />
+                        </FormControl>
+                    )}
+                    {/* Conditional help steps */}
+                    {selectedProvider?.name === 'OpenAI' && (
+                        <Box sx={{ mt: 2 }}>
+                            <Typography level="body2">
+                                Steps to get an OpenAI API Key:
+                            </Typography>
+                            <ol>
+                                <li>
+                                    Visit{' '}
+                                    <a
+                                        href="https://platform.openai.com/account/api-keys"
+                                        target="_blank"
+                                        rel="noreferrer"
+                                    >
+                                        OpenAI API Keys
+                                    </a>
+                                </li>
+                                <li>Log in to your OpenAI account.</li>
+                                <li>Create a new API key and copy it.</li>
+                            </ol>
+                        </Box>
+                    )}
+                    {selectedProvider?.name === 'Claude' && (
+                        <Box sx={{ mt: 2 }}>
+                            <Typography level="body2">
+                                Steps to get a Claude API Key:
+                            </Typography>
+                            <ol>
+                                <li>Visit your Claude provider’s website.</li>
+                                <li>Log in/create an account.</li>
+                                <li>Follow instructions to generate an API key.</li>
+                            </ol>
+                        </Box>
+                    )}
+                    <Box sx={{ display: 'flex', justifyContent: 'flex-end', gap: 1, mt: 2 }}>
+                        <Button onClick={() => setDialogOpen(false)}>Cancel</Button>
+                        <Button onClick={handleSave}>Save</Button>
+                    </Box>
+                </ModalDialog>
+            </Modal>
+        </Box>
+    );
+}
diff --git a/src/renderer/components/TransformerLabSettings.tsx b/src/renderer/components/TransformerLabSettings.tsx
index 054da534..086398b7 100644
--- a/src/renderer/components/TransformerLabSettings.tsx
+++ b/src/renderer/components/TransformerLabSettings.tsx
@@ -1,4 +1,3 @@
-/* eslint-disable jsx-a11y/anchor-is-valid */
 import * as React from 'react';
 
 import Sheet from '@mui/joy/Sheet';
@@ -22,10 +21,12 @@ import * as chatAPI from 'renderer/lib/transformerlab-api-sdk';
 import useSWR from 'swr';
 import { EyeIcon, EyeOffIcon, RotateCcwIcon } from 'lucide-react';
 
-const fetcher = (url) => fetch(url).then((res) => res.json());
+// Import the AIProvidersSettings component.
+import AIProvidersSettings from './AIProvidersSettings';
 
+const fetcher = (url) => fetch(url).then((res) => res.json());
 
-export default function TransformerLabSettings({ }) {
+export default function TransformerLabSettings() {
   const [showPassword, setShowPassword] = React.useState(false);
   const {
     data: hftoken,
@@ -37,6 +38,7 @@ export default function TransformerLabSettings({ }) {
     fetcher
   );
   const [showJobsOfType, setShowJobsOfType] = React.useState('NONE');
+  const [showProvidersPage, setShowProvidersPage] = React.useState(false);
 
   const {
     data: jobs,
@@ -52,6 +54,16 @@ export default function TransformerLabSettings({ }) {
     mutate: canLogInToHuggingFaceMutate,
   } = useSWR(chatAPI.Endpoints.Models.HuggingFaceLogin(), fetcher);
 
+  if (showProvidersPage) {
+    return (
+      <AIProvidersSettings
+        onBack={() => {
+          setShowProvidersPage(false);
+        }}
+      />
+    );
+  }
+
   return (
     <>
       <Typography level="h1" marginBottom={3}>
@@ -62,11 +74,10 @@ export default function TransformerLabSettings({ }) {
         <Typography level="title-lg" marginBottom={2}>
           Huggingface Credentials:
         </Typography>
-        {canLogInToHuggingFace?.message == 'OK' ? (
+        {canLogInToHuggingFace?.message === 'OK' ? (
           <Alert color="success">Login to Huggingface Successful</Alert>
         ) : (
           <>
-            {' '}
             <Alert color="danger" sx={{ mb: 1 }}>
               Login to Huggingface Failed. Please set credentials below.
             </Alert>
@@ -82,12 +93,8 @@ export default function TransformerLabSettings({ }) {
                   endDecorator={
                     <IconButton
                       onClick={() => {
-                        var x = document.getElementsByName('hftoken')[0];
-                        if (x.type === 'text') {
-                          x.type = 'password';
-                        } else {
-                          x.type = 'text';
-                        }
+                        const x = document.getElementsByName('hftoken')[0];
+                        x.type = x.type === 'text' ? 'password' : 'text';
                         setShowPassword(!showPassword);
                       }}
                     >
@@ -99,13 +106,8 @@ export default function TransformerLabSettings({ }) {
               <Button
                 onClick={async () => {
                   const token = document.getElementsByName('hftoken')[0].value;
-                  await fetch(
-                    chatAPI.Endpoints.Config.Set(
-                      'HuggingfaceUserAccessToken',
-                      token
-                    )
-                  );
-                  // Now manually log in to huggingface
+                  await fetch(chatAPI.Endpoints.Config.Set('HuggingfaceUserAccessToken', token));
+                  // Now manually log in to Huggingface
                   await fetch(chatAPI.Endpoints.Models.HuggingFaceLogin());
                   hftokenmutate(token);
                   canLogInToHuggingFaceMutate();
@@ -115,21 +117,31 @@ export default function TransformerLabSettings({ }) {
                 Save
               </Button>
               <FormHelperText>
-                A Huggingface access token is required in order to access
-                certain models and datasets (those marked as "Gated").
+                A Huggingface access token is required in order to access certain
+                models and datasets (those marked as "Gated").
               </FormHelperText>
               <FormHelperText>
-                Documentation here:
+                Documentation here:{' '}
                 <a
                   href="https://huggingface.co/docs/hub/security-tokens"
                   target="_blank"
+                  rel="noreferrer"
                 >
                   https://huggingface.co/docs/hub/security-tokens
                 </a>
               </FormHelperText>
             </FormControl>
           </>
-        )}{' '}
+        )}
+        <Divider sx={{ mt: 2, mb: 2 }} />
+        <Typography level="title-lg" marginBottom={2}>
+          Providers & Models:
+        </Typography>
+        {/* Clickable list option */}
+        <Button variant="soft" onClick={() => setShowProvidersPage(true)}>
+          AI Providers and Models
+        </Button>
+        {/* <Divider sx={{ mt: 2, mb: 2 }} />
         <FormControl sx={{ maxWidth: '500px', mt: 2 }}>
           <FormLabel>OpenAI API Key</FormLabel>
           <Input name="openaiKey" type="password" />
@@ -140,11 +152,10 @@ export default function TransformerLabSettings({ }) {
               await fetch(chatAPI.Endpoints.Models.SetOpenAIKey());
               const response = await fetch(chatAPI.Endpoints.Models.CheckOpenAIAPIKey());
               const result = await response.json();
-              if (result.message === "OK") {
-                alert("Successfully set OpenAI API Key");
+              if (result.message === 'OK') {
+                alert('Successfully set OpenAI API Key');
               }
             }}
-
             sx={{ marginTop: 1, width: '100px', alignSelf: 'flex-end' }}
           >
             Save
@@ -160,16 +171,16 @@ export default function TransformerLabSettings({ }) {
               await fetch(chatAPI.Endpoints.Models.SetAnthropicKey());
               const response = await fetch(chatAPI.Endpoints.Models.CheckAnthropicAPIKey());
               const result = await response.json();
-              if (result.message === "OK") {
-                alert("Successfully set Anthropic API Key");
+              if (result.message === 'OK') {
+                alert('Successfully set Anthropic API Key');
               }
             }}
             sx={{ marginTop: 1, width: '100px', alignSelf: 'flex-end' }}
           >
             Save
           </Button>
-        </FormControl>
-        <Divider sx={{ mt: 2, mb: 2 }} />{' '}
+        </FormControl> */}
+        <Divider sx={{ mt: 2, mb: 2 }} />
         <Typography level="title-lg" marginBottom={2}>
           Application:
         </Typography>
@@ -177,7 +188,7 @@ export default function TransformerLabSettings({ }) {
           variant="soft"
           onClick={() => {
             // find and delete all items in local storage that begin with oneTimePopup:
-            for (var key in localStorage) {
+            for (const key in localStorage) {
               if (key.startsWith('oneTimePopup')) {
                 localStorage.removeItem(key);
               }
@@ -186,7 +197,7 @@ export default function TransformerLabSettings({ }) {
         >
           Reset all Tutorial Popup Screens
         </Button>
-        <Divider sx={{ mt: 2, mb: 2 }} />{' '}
+        <Divider sx={{ mt: 2, mb: 2 }} />
         <Typography level="title-lg" marginBottom={2}>
           View Jobs (debug):{' '}
           <IconButton onClick={() => jobsMutate()}>
diff --git a/src/renderer/lib/transformerlab-api-sdk.ts b/src/renderer/lib/transformerlab-api-sdk.ts
index cbf52369..0bc2ddac 100644
--- a/src/renderer/lib/transformerlab-api-sdk.ts
+++ b/src/renderer/lib/transformerlab-api-sdk.ts
@@ -1089,6 +1089,7 @@ Endpoints.Models = {
     modelSource +
     '&model_id=' +
     modelId,
+
   ImportFromLocalPath: (modelPath: string) =>
     API_URL() + 'model/import_from_local_path?model_path=' + modelPath,
   HuggingFaceLogin: () => API_URL() + 'model/login_to_huggingface',
@@ -1097,6 +1098,8 @@ Endpoints.Models = {
   SetAnthropicKey: () => API_URL() + 'model/set_anthropic_api_key',
   CheckOpenAIAPIKey: () => API_URL() + 'model/check_openai_api_key',
   CheckAnthropicAPIKey: () => API_URL() + 'model/check_anthropic_api_key',
+  SetCustomAPIKey: () => API_URL() + 'model/set_custom_api_key',
+  CheckCustomAPIKey: () => API_URL() + 'model/check_custom_api_key',
 };
 
 Endpoints.Plugins = {
-- 
GitLab