diff --git a/src/renderer/components/AIProvidersSettings.tsx b/src/renderer/components/AIProvidersSettings.tsx new file mode 100644 index 0000000000000000000000000000000000000000..5961ff84f87c61811ddff2d8efafe4983c75a60d --- /dev/null +++ b/src/renderer/components/AIProvidersSettings.tsx @@ -0,0 +1,298 @@ +import * as React from 'react'; +import { + Button, + Modal, + ModalDialog, + FormControl, + FormLabel, + Input, + List, + ListItem, + Box, + Typography +} from '@mui/joy'; +import * as chatAPI from 'renderer/lib/transformerlab-api-sdk'; +import useSWR from 'swr'; + +const fetcher = (url: string) => fetch(url).then((res) => res.json()); + +interface Provider { + name: string; + keyName: string; + setKeyEndpoint: () => string; + checkKeyEndpoint: () => string; +} + +const providers: Provider[] = [ + { + name: 'OpenAI', + keyName: 'OPENAI_API_KEY', + setKeyEndpoint: () => chatAPI.Endpoints.Models.SetOpenAIKey(), + checkKeyEndpoint: () => chatAPI.Endpoints.Models.CheckOpenAIAPIKey(), + }, + { + name: 'Claude', + keyName: 'CLAUDE_API_KEY', + setKeyEndpoint: () => chatAPI.Endpoints.Models.SetClaudeKey(), + checkKeyEndpoint: () => chatAPI.Endpoints.Models.CheckClaudeAPIKey(), + }, + { + name: 'Custom API', + keyName: 'CUSTOM_MODEL_API_KEY', + setKeyEndpoint: () => chatAPI.Endpoints.Models.SetCustomAPIKey(), + checkKeyEndpoint: () => chatAPI.Endpoints.Models.CheckCustomAPIKey(), + } +]; + +interface AIProvidersSettingsProps { + onBack?: () => void; +} + +export default function AIProvidersSettings({ onBack }: AIProvidersSettingsProps) { + const { data: openaiApiKey, mutate: mutateOpenAI } = useSWR( + chatAPI.Endpoints.Config.Get('OPENAI_API_KEY'), + fetcher + ); + const { data: claudeApiKey, mutate: mutateClaude } = useSWR( + chatAPI.Endpoints.Config.Get('CLAUDE_API_KEY'), + fetcher + ); + const { data: customAPIStatus, mutate: mutateCustom } = useSWR( + chatAPI.Endpoints.Config.Get('CUSTOM_MODEL_API_KEY'), + fetcher + ); + + const getProviderStatus = (provider: Provider) => { + if (provider.name === 'OpenAI') return openaiApiKey; + if (provider.name === 'Claude') return claudeApiKey; + if (provider.name === 'Custom API') return customAPIStatus; + return null; + }; + + const setProviderStatus = async (provider: Provider, token: string) => { + await fetch(chatAPI.Endpoints.Config.Set(provider.keyName, token)); + await fetch(provider.setKeyEndpoint()); + const response = await fetch(provider.checkKeyEndpoint()); + const result = await response.json(); + return result.message === 'OK'; + }; + + const [dialogOpen, setDialogOpen] = React.useState(false); + const [selectedProvider, setSelectedProvider] = React.useState<Provider | null>(null); + const [apiKey, setApiKey] = React.useState(''); + const [hoveredProvider, setHoveredProvider] = React.useState<string | null>(null); + // States for custom API additional fields + const [customApiName, setCustomApiName] = React.useState(''); + const [customBaseURL, setCustomBaseURL] = React.useState(''); + const [customApiKey, setCustomApiKey] = React.useState(''); + const [customModelName, setCustomModelName] = React.useState(''); + + const handleConnectClick = (provider: Provider) => { + setSelectedProvider(provider); + if (provider.name === 'Custom API') { + setCustomApiName(''); + setCustomBaseURL(''); + setCustomApiKey(''); + setCustomModelName(''); + } else { + setApiKey(''); + } + setDialogOpen(true); + }; + + const handleDisconnectClick = async (provider: Provider) => { + await fetch(chatAPI.Endpoints.Config.Set(provider.keyName, '')); + if (provider.name === 'OpenAI') { + mutateOpenAI(); + } else if (provider.name === 'Claude') { + mutateClaude(); + } else if (provider.name === 'Custom API') { + mutateCustom(); + } + }; + + const handleSave = async () => { + if (selectedProvider) { + let token = ''; + if (selectedProvider.name === 'Custom API') { + const customObj = { + apiName: customApiName, + baseURL: customBaseURL, + apiKey: customApiKey, + modelName: customModelName + }; + token = JSON.stringify(customObj); + } else if (apiKey) { + token = apiKey; + } + if (token) { + const success = await setProviderStatus(selectedProvider, token); + if (success) { + alert(`Successfully connected to ${selectedProvider.name}`); + if (selectedProvider.name === 'OpenAI') { + mutateOpenAI(); + } else if (selectedProvider.name === 'Claude') { + mutateClaude(); + } else if (selectedProvider.name === 'Custom API') { + mutateCustom(); + } + } else { + alert(`Failed to connect to ${selectedProvider.name}`); + } + } + } + setDialogOpen(false); + }; + + return ( + <Box sx={{ p: 2 }}> + <Button onClick={onBack}>Back to Settings</Button> + <Typography level="h3" mb={2}> + AI Providers + </Typography> + <List sx={{ gap: 1 }}> + {providers.map((provider) => { + const status = getProviderStatus(provider); + const isConnected = status && status !== ''; + return ( + <ListItem + key={provider.name} + sx={{ + display: 'flex', + justifyContent: 'space-between', + alignItems: 'center', + p: 1, + borderRadius: '8px', + bgcolor: 'neutral.softBg' + }} + > + <Typography>{provider.name}</Typography> + {isConnected ? ( + <Box + onMouseEnter={() => setHoveredProvider(provider.name)} + onMouseLeave={() => setHoveredProvider(null)} + onClick={() => handleDisconnectClick(provider)} + sx={{ + cursor: 'pointer', + border: '1px solid', + borderColor: 'neutral.outlinedBorder', + borderRadius: '4px', + px: 1, + py: 0.5, + fontSize: '0.875rem', + color: 'success.600' + }} + > + {hoveredProvider === provider.name ? 'Disconnect' : 'Set up!'} + </Box> + ) : ( + <Button variant="soft" onClick={() => handleConnectClick(provider)}> + Connect + </Button> + )} + </ListItem> + ); + })} + </List> + {/* API Key Modal */} + <Modal open={dialogOpen} onClose={() => setDialogOpen(false)}> + <ModalDialog + layout="stack" + aria-labelledby="connect-dialog-title" + sx={{ + position: 'absolute', + top: '50%', + left: '50%', + transform: 'translate(-50%, -50%)', + maxWidth: 400, + width: '90%' + }} + > + <Typography id="connect-dialog-title" component="h2"> + Connect to {selectedProvider?.name} + </Typography> + {selectedProvider?.name === 'Custom API' ? ( + <> + <FormControl sx={{ mt: 2 }}> + <FormLabel>API Name</FormLabel> + <Input + value={customApiName} + onChange={(e) => setCustomApiName(e.target.value)} + /> + </FormControl> + <FormControl sx={{ mt: 2 }}> + <FormLabel>Base URL</FormLabel> + <Input + value={customBaseURL} + onChange={(e) => setCustomBaseURL(e.target.value)} + /> + </FormControl> + <FormControl sx={{ mt: 2 }}> + <FormLabel>API Key</FormLabel> + <Input + type="password" + value={customApiKey} + onChange={(e) => setCustomApiKey(e.target.value)} + /> + </FormControl> + <FormControl sx={{ mt: 2 }}> + <FormLabel>Model Name</FormLabel> + <Input + value={customModelName} + onChange={(e) => setCustomModelName(e.target.value)} + /> + </FormControl> + </> + ) : ( + <FormControl sx={{ mt: 2 }}> + <FormLabel>{selectedProvider?.name} API Key</FormLabel> + <Input + type="password" + value={apiKey} + onChange={(e) => setApiKey(e.target.value)} + /> + </FormControl> + )} + {/* Conditional help steps */} + {selectedProvider?.name === 'OpenAI' && ( + <Box sx={{ mt: 2 }}> + <Typography level="body2"> + Steps to get an OpenAI API Key: + </Typography> + <ol> + <li> + Visit{' '} + <a + href="https://platform.openai.com/account/api-keys" + target="_blank" + rel="noreferrer" + > + OpenAI API Keys + </a> + </li> + <li>Log in to your OpenAI account.</li> + <li>Create a new API key and copy it.</li> + </ol> + </Box> + )} + {selectedProvider?.name === 'Claude' && ( + <Box sx={{ mt: 2 }}> + <Typography level="body2"> + Steps to get a Claude API Key: + </Typography> + <ol> + <li>Visit your Claude provider’s website.</li> + <li>Log in/create an account.</li> + <li>Follow instructions to generate an API key.</li> + </ol> + </Box> + )} + <Box sx={{ display: 'flex', justifyContent: 'flex-end', gap: 1, mt: 2 }}> + <Button onClick={() => setDialogOpen(false)}>Cancel</Button> + <Button onClick={handleSave}>Save</Button> + </Box> + </ModalDialog> + </Modal> + </Box> + ); +} diff --git a/src/renderer/components/TransformerLabSettings.tsx b/src/renderer/components/TransformerLabSettings.tsx index 054da53480eb795ce73fd6a15c3056f8ae116440..086398b78578f88961dc7e11b52f4bca341272ce 100644 --- a/src/renderer/components/TransformerLabSettings.tsx +++ b/src/renderer/components/TransformerLabSettings.tsx @@ -1,4 +1,3 @@ -/* eslint-disable jsx-a11y/anchor-is-valid */ import * as React from 'react'; import Sheet from '@mui/joy/Sheet'; @@ -22,10 +21,12 @@ import * as chatAPI from 'renderer/lib/transformerlab-api-sdk'; import useSWR from 'swr'; import { EyeIcon, EyeOffIcon, RotateCcwIcon } from 'lucide-react'; -const fetcher = (url) => fetch(url).then((res) => res.json()); +// Import the AIProvidersSettings component. +import AIProvidersSettings from './AIProvidersSettings'; +const fetcher = (url) => fetch(url).then((res) => res.json()); -export default function TransformerLabSettings({ }) { +export default function TransformerLabSettings() { const [showPassword, setShowPassword] = React.useState(false); const { data: hftoken, @@ -37,6 +38,7 @@ export default function TransformerLabSettings({ }) { fetcher ); const [showJobsOfType, setShowJobsOfType] = React.useState('NONE'); + const [showProvidersPage, setShowProvidersPage] = React.useState(false); const { data: jobs, @@ -52,6 +54,16 @@ export default function TransformerLabSettings({ }) { mutate: canLogInToHuggingFaceMutate, } = useSWR(chatAPI.Endpoints.Models.HuggingFaceLogin(), fetcher); + if (showProvidersPage) { + return ( + <AIProvidersSettings + onBack={() => { + setShowProvidersPage(false); + }} + /> + ); + } + return ( <> <Typography level="h1" marginBottom={3}> @@ -62,11 +74,10 @@ export default function TransformerLabSettings({ }) { <Typography level="title-lg" marginBottom={2}> Huggingface Credentials: </Typography> - {canLogInToHuggingFace?.message == 'OK' ? ( + {canLogInToHuggingFace?.message === 'OK' ? ( <Alert color="success">Login to Huggingface Successful</Alert> ) : ( <> - {' '} <Alert color="danger" sx={{ mb: 1 }}> Login to Huggingface Failed. Please set credentials below. </Alert> @@ -82,12 +93,8 @@ export default function TransformerLabSettings({ }) { endDecorator={ <IconButton onClick={() => { - var x = document.getElementsByName('hftoken')[0]; - if (x.type === 'text') { - x.type = 'password'; - } else { - x.type = 'text'; - } + const x = document.getElementsByName('hftoken')[0]; + x.type = x.type === 'text' ? 'password' : 'text'; setShowPassword(!showPassword); }} > @@ -99,13 +106,8 @@ export default function TransformerLabSettings({ }) { <Button onClick={async () => { const token = document.getElementsByName('hftoken')[0].value; - await fetch( - chatAPI.Endpoints.Config.Set( - 'HuggingfaceUserAccessToken', - token - ) - ); - // Now manually log in to huggingface + await fetch(chatAPI.Endpoints.Config.Set('HuggingfaceUserAccessToken', token)); + // Now manually log in to Huggingface await fetch(chatAPI.Endpoints.Models.HuggingFaceLogin()); hftokenmutate(token); canLogInToHuggingFaceMutate(); @@ -115,21 +117,31 @@ export default function TransformerLabSettings({ }) { Save </Button> <FormHelperText> - A Huggingface access token is required in order to access - certain models and datasets (those marked as "Gated"). + A Huggingface access token is required in order to access certain + models and datasets (those marked as "Gated"). </FormHelperText> <FormHelperText> - Documentation here: + Documentation here:{' '} <a href="https://huggingface.co/docs/hub/security-tokens" target="_blank" + rel="noreferrer" > https://huggingface.co/docs/hub/security-tokens </a> </FormHelperText> </FormControl> </> - )}{' '} + )} + <Divider sx={{ mt: 2, mb: 2 }} /> + <Typography level="title-lg" marginBottom={2}> + Providers & Models: + </Typography> + {/* Clickable list option */} + <Button variant="soft" onClick={() => setShowProvidersPage(true)}> + AI Providers and Models + </Button> + {/* <Divider sx={{ mt: 2, mb: 2 }} /> <FormControl sx={{ maxWidth: '500px', mt: 2 }}> <FormLabel>OpenAI API Key</FormLabel> <Input name="openaiKey" type="password" /> @@ -140,11 +152,10 @@ export default function TransformerLabSettings({ }) { await fetch(chatAPI.Endpoints.Models.SetOpenAIKey()); const response = await fetch(chatAPI.Endpoints.Models.CheckOpenAIAPIKey()); const result = await response.json(); - if (result.message === "OK") { - alert("Successfully set OpenAI API Key"); + if (result.message === 'OK') { + alert('Successfully set OpenAI API Key'); } }} - sx={{ marginTop: 1, width: '100px', alignSelf: 'flex-end' }} > Save @@ -160,16 +171,16 @@ export default function TransformerLabSettings({ }) { await fetch(chatAPI.Endpoints.Models.SetAnthropicKey()); const response = await fetch(chatAPI.Endpoints.Models.CheckAnthropicAPIKey()); const result = await response.json(); - if (result.message === "OK") { - alert("Successfully set Anthropic API Key"); + if (result.message === 'OK') { + alert('Successfully set Anthropic API Key'); } }} sx={{ marginTop: 1, width: '100px', alignSelf: 'flex-end' }} > Save </Button> - </FormControl> - <Divider sx={{ mt: 2, mb: 2 }} />{' '} + </FormControl> */} + <Divider sx={{ mt: 2, mb: 2 }} /> <Typography level="title-lg" marginBottom={2}> Application: </Typography> @@ -177,7 +188,7 @@ export default function TransformerLabSettings({ }) { variant="soft" onClick={() => { // find and delete all items in local storage that begin with oneTimePopup: - for (var key in localStorage) { + for (const key in localStorage) { if (key.startsWith('oneTimePopup')) { localStorage.removeItem(key); } @@ -186,7 +197,7 @@ export default function TransformerLabSettings({ }) { > Reset all Tutorial Popup Screens </Button> - <Divider sx={{ mt: 2, mb: 2 }} />{' '} + <Divider sx={{ mt: 2, mb: 2 }} /> <Typography level="title-lg" marginBottom={2}> View Jobs (debug):{' '} <IconButton onClick={() => jobsMutate()}> diff --git a/src/renderer/lib/transformerlab-api-sdk.ts b/src/renderer/lib/transformerlab-api-sdk.ts index cbf523690bf68daa602222c7ccb62cd4682d1823..0bc2ddac6857b240841553ceb2d65547ea1422f5 100644 --- a/src/renderer/lib/transformerlab-api-sdk.ts +++ b/src/renderer/lib/transformerlab-api-sdk.ts @@ -1089,6 +1089,7 @@ Endpoints.Models = { modelSource + '&model_id=' + modelId, + ImportFromLocalPath: (modelPath: string) => API_URL() + 'model/import_from_local_path?model_path=' + modelPath, HuggingFaceLogin: () => API_URL() + 'model/login_to_huggingface', @@ -1097,6 +1098,8 @@ Endpoints.Models = { SetAnthropicKey: () => API_URL() + 'model/set_anthropic_api_key', CheckOpenAIAPIKey: () => API_URL() + 'model/check_openai_api_key', CheckAnthropicAPIKey: () => API_URL() + 'model/check_anthropic_api_key', + SetCustomAPIKey: () => API_URL() + 'model/set_custom_api_key', + CheckCustomAPIKey: () => API_URL() + 'model/check_custom_api_key', }; Endpoints.Plugins = {