diff --git a/src/renderer/components/Experiment/DynamicPluginForm.tsx b/src/renderer/components/Experiment/DynamicPluginForm.tsx index e11e8e31c7b342a6c72e47e4f2a26b701a1f1c3a..d564f26440f8e0f1d7f19317b141088c169966b8 100644 --- a/src/renderer/components/Experiment/DynamicPluginForm.tsx +++ b/src/renderer/components/Experiment/DynamicPluginForm.tsx @@ -27,6 +27,7 @@ import { Textarea, } from '@mui/joy'; import { useMemo } from 'react'; +import ModelProviderWidget from 'renderer/components/Experiment/Widgets/ModelProviderWidget'; import { RegistryWidgetsType, @@ -389,6 +390,7 @@ function CustomSelectSimple< ); } + function CustomAutocompleteWidget<T = any, S extends StrictRJSFSchema = RJSFSchema, F extends FormContextType = any>( props: WidgetProps<T, S, F> ) { @@ -406,7 +408,6 @@ function CustomAutocompleteWidget<T = any, S extends StrictRJSFSchema = RJSFSche } = props; const { enumOptions } = options; - // Default multiple is true. // const _multiple = typeof multiple === 'undefined' ? true : !!multiple; // Check both multiple and options.multiple; default is true. @@ -596,6 +597,7 @@ const widgets: RegistryWidgetsType = { SelectWidget: CustomSelectSimple, AutoCompleteWidget: CustomAutocompleteWidget, EvaluationWidget: CustomEvaluationWidget, + ModelProviderWidget: ModelProviderWidget }; const fetcher = (url) => fetch(url).then((res) => res.json()); diff --git a/src/renderer/components/Experiment/Generate/Generate.tsx b/src/renderer/components/Experiment/Generate/Generate.tsx index cddd739140afb8bd3a8e798c230ce47b35da6817..aa3f3769f681b9bfa835814c01fb942414a16b26 100644 --- a/src/renderer/components/Experiment/Generate/Generate.tsx +++ b/src/renderer/components/Experiment/Generate/Generate.tsx @@ -18,7 +18,7 @@ import { } from '@mui/joy'; import { PlusCircleIcon } from 'lucide-react'; -import GenerateJobsTable from './GenerateJobsTable.tsx'; +import GenerateJobsTable from './GenerateJobsTable'; import GenerateTasksTable from './GenerateTasksTable'; import GenerateModal from './GenerateModal'; @@ -68,7 +68,11 @@ export default function Generate({ if (value) { // Use fetch to post the value to the server await fetch( - chatAPI.Endpoints.Experiment.SavePlugin(project, generationName, 'main.py'), + chatAPI.Endpoints.Experiment.SavePlugin( + project, + generationName, + 'main.py' + ), { method: 'POST', body: value, @@ -85,7 +89,6 @@ export default function Generate({ if (!experimentInfo) { return 'No experiment selected'; } - console.log('Experiment Info:', experimentInfo); return ( <> @@ -122,8 +125,7 @@ export default function Generate({ </Typography> {plugins?.length === 0 ? ( <Alert color="danger"> - No Generator Scripts available, please install a generator - plugin. + No Generator Scripts available, please install a generator plugin. </Alert> ) : ( <Dropdown> diff --git a/src/renderer/components/Experiment/Generate/GenerateModal.tsx b/src/renderer/components/Experiment/Generate/GenerateModal.tsx index d7469eda5ddc7353c6b90b6c523500f41555f3a7..2a1c25c5c3c6dc6cf3cf431c442c0be1615a9ec6 100644 --- a/src/renderer/components/Experiment/Generate/GenerateModal.tsx +++ b/src/renderer/components/Experiment/Generate/GenerateModal.tsx @@ -49,6 +49,30 @@ function PluginIntroduction({ experimentInfo, pluginId }) { ); } +/* This function looks at all the generations that are stored in the experiment JSON +and returns the generation that matches the generationName */ +function getGenerationFromGenerationsArray(generationsStr, generationName) { + let thisGeneration = null; + console.log(generationName); + + if (typeof generationsStr === 'string') { + try { + const generations = JSON.parse(generationsStr); + console.log('generations:', generations); + + if (Array.isArray(generations)) { + thisGeneration = generations.find( + (generation) => generation.name === generationName + ); + } + } catch (error) { + console.error('Failed to parse generations JSON string:', error); + } + } + console.log('thisGeneration', thisGeneration); + return thisGeneration; +} + export default function GenerateModal({ open, onClose, @@ -119,78 +143,71 @@ export default function GenerateModal({ if (currentGenerationName && currentGenerationName !== '') { const generationsStr = experimentInfo.config?.generations; setSelectedDocs([]); - if (typeof generationsStr === 'string') { - try { - const generations = JSON.parse(generationsStr); - if (Array.isArray(generations)) { - const generationConfig = generations.find( - (generationItem: any) => - generationItem.name === currentGenerationName && - generationItem.plugin === pluginId - ); - if (generationConfig) { - setConfig(generationConfig.script_parameters); - const datasetKeyExists = Object.keys( - generationConfig.script_parameters - ).some((key) => key.toLowerCase().includes('dataset')); + if (generationsStr) { + const generationConfig = getGenerationFromGenerationsArray( + generationsStr, + currentGenerationName + ); + if (generationConfig) { + setConfig(generationConfig.script_parameters); - const docsKeyExists = Object.keys( - generationConfig.script_parameters - ).some((key) => key.toLowerCase().includes('docs')); + const datasetKeyExists = Object.keys( + generationConfig.script_parameters + ).some((key) => key.toLowerCase().includes('dataset')); - const contextKeyExists = Object.keys( - generationConfig.script_parameters - ).some((key) => key.toLowerCase().includes('context')); + const docsKeyExists = Object.keys( + generationConfig.script_parameters + ).some((key) => key.toLowerCase().includes('docs')); - setHasDatasetKey(datasetKeyExists); + const contextKeyExists = Object.keys( + generationConfig.script_parameters + ).some((key) => key.toLowerCase().includes('context')); - if ( - docsKeyExists && - generationConfig.script_parameters.docs.length > 0 - ) { - setHasContextKey(false); - setHasDocumentsKey(true); - generationConfig.script_parameters.docs = - generationConfig.script_parameters.docs.split(','); - setConfig(generationConfig.script_parameters); - setSelectedDocs(generationConfig.script_parameters.docs); - } else if ( - contextKeyExists && - generationConfig.script_parameters.context.length > 0 - ) { - setHasContextKey(true); - setHasDocumentsKey(false); - const context = generationConfig.script_parameters.context; - setContextInput(context); - delete generationConfig.script_parameters.context; - setConfig(generationConfig.script_parameters); - } + setHasDatasetKey(datasetKeyExists); - if ( - hasDatasetKey && - generationConfig.script_parameters.dataset_name.length > 0 - ) { - setSelectedDataset( - generationConfig.script_parameters.dataset_name - ); - } - if ( - generationConfig.script_parameters._dataset_display_message && - generationConfig.script_parameters._dataset_display_message - .length > 0 - ) { - setDatasetDisplayMessage( - generationConfig.script_parameters._dataset_display_message - ); - } - if (!nameInput && generationConfig?.name.length > 0) { - setNameInput(generationConfig.name); - } - } + if ( + docsKeyExists && + generationConfig.script_parameters.docs.length > 0 + ) { + setHasContextKey(false); + setHasDocumentsKey(true); + generationConfig.script_parameters.docs = + generationConfig.script_parameters.docs.split(','); + setConfig(generationConfig.script_parameters); + setSelectedDocs(generationConfig.script_parameters.docs); + } else if ( + contextKeyExists && + generationConfig.script_parameters.context.length > 0 + ) { + setHasContextKey(true); + setHasDocumentsKey(false); + const context = generationConfig.script_parameters.context; + setContextInput(context); + delete generationConfig.script_parameters.context; + setConfig(generationConfig.script_parameters); + } + + if ( + hasDatasetKey && + generationConfig.script_parameters.dataset_name.length > 0 + ) { + setSelectedDataset( + generationConfig.script_parameters.dataset_name + ); + } + if ( + generationConfig.script_parameters._dataset_display_message && + generationConfig.script_parameters._dataset_display_message + .length > 0 + ) { + setDatasetDisplayMessage( + generationConfig.script_parameters._dataset_display_message + ); + } + if (!nameInput && generationConfig?.name.length > 0) { + setNameInput(generationConfig.name); } - } catch (error) { - console.error('Failed to parse generations JSON string:', error); } } } else { diff --git a/src/renderer/components/Experiment/Widgets/ModelProviderWidget.tsx b/src/renderer/components/Experiment/Widgets/ModelProviderWidget.tsx new file mode 100644 index 0000000000000000000000000000000000000000..48003b621255bec53d23d67c3954c2766ce47656 --- /dev/null +++ b/src/renderer/components/Experiment/Widgets/ModelProviderWidget.tsx @@ -0,0 +1,156 @@ +import * as React from 'react'; +import useSWR from 'swr'; +import * as chatAPI from 'renderer/lib/transformerlab-api-sdk'; +import { Autocomplete } from '@mui/joy'; +import { + WidgetProps, + RJSFSchema, + StrictRJSFSchema, + FormContextType, +} from '@rjsf/utils'; + +// Simple fetcher for useSWR. +const fetcher = (url: string) => fetch(url).then((res) => res.json()); + +function ModelProviderWidget< + T = any, + S extends StrictRJSFSchema = RJSFSchema, + F extends FormContextType = any +>(props: WidgetProps<T, S, F>) { + const { + id, + value, + required, + disabled, + readonly, + autofocus, + onChange, + options, + schema, + multiple, + } = props; + + // Determine multiple, defaulting to true. + const _multiple = + typeof multiple !== 'undefined' + ? Boolean(multiple) + : typeof options.multiple !== 'undefined' + ? Boolean(options.multiple) + : true; + + // Disabled API key mapping. + const isDisabledFilter = true; + const disabledEnvMap = { + claude: 'ANTHROPIC_API_KEY', + openai: 'OPENAI_API_KEY', + custom: 'CUSTOM_MODEL_API_KEY', + }; + const configKeysInOrder = Object.values(disabledEnvMap); + const configResults = configKeysInOrder.map((key) => + useSWR(chatAPI.Endpoints.Config.Get(key), fetcher) + ); + const configValues = React.useMemo(() => { + const values: Record<string, any> = {}; + configKeysInOrder.forEach((key, idx) => { + values[key] = configResults[idx]?.data; + }); + return values; + }, [configKeysInOrder, configResults]); + + // Map: label => stored value. + const labelToCustomValue: Record<string, string> = { + 'Claude 3.7 Sonnet': 'claude-3-7-sonnet-latest', + 'Claude 3.5 Haiku': 'claude-3-5-haiku-latest', + 'OpenAI GPT 4o': 'gpt-4o', + 'OpenAI GPT 4o Mini': 'gpt-4o-mini', + 'Custom Model API': 'custom-model-api', + 'Local': 'local', + }; + + // Options coming from mapping keys. + const optionsList = Object.keys(labelToCustomValue); + + // Inverse mapping: stored value => label. + const customValueToLabel = Object.entries(labelToCustomValue).reduce( + (acc, [label, custom]) => { + acc[custom] = label; + return acc; + }, + {} as Record<string, string> + ); + + // Set default/current value. + const defaultValue = _multiple ? [] : ''; + const currentValue = value !== undefined ? value : defaultValue; + + // Convert stored value(s) to display labels. + const displayValue = _multiple + ? Array.isArray(currentValue) + ? currentValue.map((val) => customValueToLabel[val] || val) + : [] + : customValueToLabel[currentValue] || currentValue; + + // Build disabled mapping for options. + const combinedOptions = optionsList.reduce( + (acc: Record<string, { disabled: boolean; info?: string }>, opt) => { + const lower = opt.toLowerCase(); + let optDisabled = false; + let infoMessage = ''; + if (isDisabledFilter) { + for (const envKey in disabledEnvMap) { + if (lower.startsWith(envKey)) { + const configKey = disabledEnvMap[envKey]; + const configVal = configValues[configKey]; + optDisabled = !configVal || configVal === ''; + if (optDisabled) { + infoMessage = `Please set ${configKey} in settings`; + } + break; + } + } + } + acc[opt] = { disabled: optDisabled, info: optDisabled ? infoMessage : '' }; + return acc; + }, + {} + ); + + return ( + <> + <Autocomplete + multiple={_multiple} + id={id} + placeholder={schema.title || ''} + options={optionsList} + getOptionLabel={(option) => + option + + (combinedOptions[option]?.disabled ? ' - ' + combinedOptions[option].info : '') + } + getOptionDisabled={(option) => combinedOptions[option]?.disabled ?? false} + value={displayValue} + onChange={(event, newValue) => { + const storedValue = _multiple + ? newValue.map((label) => labelToCustomValue[label] || label) + : (labelToCustomValue[newValue] || newValue); + onChange(storedValue); + }} + disabled={disabled || readonly} + autoFocus={autofocus} + /> + {/* Hidden input to capture the stored value on form submission */} + <input + type="hidden" + name={id} + value={ + _multiple + ? Array.isArray(currentValue) + ? currentValue.join(',') + : currentValue + : currentValue + } + /> + </> + ); +} + +export default ModelProviderWidget; diff --git a/src/renderer/components/MainAppPanel.tsx b/src/renderer/components/MainAppPanel.tsx index dfc01ca85a692de1655f0c1c6c9029bb1e4a7324..29c8469ce937323b062342070d700251d2ba03b2 100644 --- a/src/renderer/components/MainAppPanel.tsx +++ b/src/renderer/components/MainAppPanel.tsx @@ -25,7 +25,7 @@ import Tokenize from './Experiment/Interact/Tokenize'; import * as chatAPI from 'renderer/lib/transformerlab-api-sdk'; import ExperimentNotes from './Experiment/ExperimentNotes'; -import TransformerLabSettings from './TransformerLabSettings'; +import TransformerLabSettings from './Settings/TransformerLabSettings'; import Logs from './Logs'; import FoundationHome from './Experiment/Foundation'; import { useState } from 'react'; diff --git a/src/renderer/components/Settings/AIProvidersSettings.tsx b/src/renderer/components/Settings/AIProvidersSettings.tsx new file mode 100644 index 0000000000000000000000000000000000000000..f52d351a960ae4abbfac5881f2f32fa0a0590961 --- /dev/null +++ b/src/renderer/components/Settings/AIProvidersSettings.tsx @@ -0,0 +1,305 @@ +import * as React from 'react'; +import { + Button, + Modal, + ModalDialog, + FormControl, + FormLabel, + Input, + List, + ListItem, + Box, + Typography +} from '@mui/joy'; +import * as chatAPI from 'renderer/lib/transformerlab-api-sdk'; +import useSWR from 'swr'; + +const fetcher = (url: string) => fetch(url).then((res) => res.json()); + +interface Provider { + name: string; + keyName: string; + setKeyEndpoint: () => string; + checkKeyEndpoint: () => string; +} + +const providers: Provider[] = [ + { + name: 'OpenAI', + keyName: 'OPENAI_API_KEY', + setKeyEndpoint: () => chatAPI.Endpoints.Models.SetOpenAIKey(), + checkKeyEndpoint: () => chatAPI.Endpoints.Models.CheckOpenAIAPIKey(), + }, + { + name: 'Anthropic', + keyName: 'ANTHROPIC_API_KEY', + setKeyEndpoint: () => chatAPI.Endpoints.Models.SetAnthropicKey(), + checkKeyEndpoint: () => chatAPI.Endpoints.Models.CheckAnthropicAPIKey(), + }, + { + name: 'Custom API', + keyName: 'CUSTOM_MODEL_API_KEY', + setKeyEndpoint: () => chatAPI.Endpoints.Models.SetCustomAPIKey(), + checkKeyEndpoint: () => chatAPI.Endpoints.Models.CheckCustomAPIKey(), + } +]; + +interface AIProvidersSettingsProps { + onBack?: () => void; +} + +export default function AIProvidersSettings({ onBack }: AIProvidersSettingsProps) { + const { data: openaiApiKey, mutate: mutateOpenAI } = useSWR( + chatAPI.Endpoints.Config.Get('OPENAI_API_KEY'), + fetcher + ); + const { data: claudeApiKey, mutate: mutateClaude } = useSWR( + chatAPI.Endpoints.Config.Get('ANTHROPIC_API_KEY'), + fetcher + ); + const { data: customAPIStatus, mutate: mutateCustom } = useSWR( + chatAPI.Endpoints.Config.Get('CUSTOM_MODEL_API_KEY'), + fetcher + ); + + const getProviderStatus = (provider: Provider) => { + if (provider.name === 'OpenAI') return openaiApiKey; + if (provider.name === 'Anthropic') return claudeApiKey; + if (provider.name === 'Custom API') return customAPIStatus; + return null; + }; + + const setProviderStatus = async (provider: Provider, token: string) => { + await fetch(chatAPI.Endpoints.Config.Set(provider.keyName, token)); + await fetch(provider.setKeyEndpoint()); + const response = await fetch(provider.checkKeyEndpoint()); + const result = await response.json(); + return result.message === 'OK'; + }; + + const [dialogOpen, setDialogOpen] = React.useState(false); + const [selectedProvider, setSelectedProvider] = React.useState<Provider | null>(null); + const [apiKey, setApiKey] = React.useState(''); + const [hoveredProvider, setHoveredProvider] = React.useState<string | null>(null); + // States for custom API additional fields + const [customApiName, setCustomApiName] = React.useState(''); + const [customBaseURL, setCustomBaseURL] = React.useState(''); + const [customApiKey, setCustomApiKey] = React.useState(''); + const [customModelName, setCustomModelName] = React.useState(''); + + const handleConnectClick = (provider: Provider) => { + setSelectedProvider(provider); + if (provider.name === 'Custom API') { + setCustomApiName(''); + setCustomBaseURL(''); + setCustomApiKey(''); + setCustomModelName(''); + } else { + setApiKey(''); + } + setDialogOpen(true); + }; + + const handleDisconnectClick = async (provider: Provider) => { + await fetch(chatAPI.Endpoints.Config.Set(provider.keyName, '')); + if (provider.name === 'OpenAI') { + mutateOpenAI(); + } else if (provider.name === 'Anthropic') { + mutateClaude(); + } else if (provider.name === 'Custom API') { + mutateCustom(); + } + }; + + const handleSave = async () => { + if (selectedProvider) { + let token = ''; + if (selectedProvider.name === 'Custom API') { + const customObj = { + apiName: customApiName, + baseURL: customBaseURL, + apiKey: customApiKey, + modelName: customModelName + }; + token = JSON.stringify(customObj); + } else if (apiKey) { + token = apiKey; + } + if (token) { + const success = await setProviderStatus(selectedProvider, token); + if (success) { + alert(`Successfully connected to ${selectedProvider.name}`); + if (selectedProvider.name === 'OpenAI') { + mutateOpenAI(); + } else if (selectedProvider.name === 'Anthropic') { + mutateClaude(); + } else if (selectedProvider.name === 'Custom API') { + mutateCustom(); + } + } else { + alert(`Failed to connect to ${selectedProvider.name}`); + } + } + } + setDialogOpen(false); + }; + + return ( + <Box sx={{ p: 2 }}> + <Button onClick={onBack}>Back to Settings</Button> + <Typography level="h3" mb={2}> + AI Providers + </Typography> + <List sx={{ gap: 1 }}> + {providers.map((provider) => { + const status = getProviderStatus(provider); + const isConnected = status && status !== ''; + return ( + <ListItem + key={provider.name} + sx={{ + display: 'flex', + justifyContent: 'space-between', + alignItems: 'center', + p: 1, + borderRadius: '8px', + bgcolor: 'neutral.softBg' + }} + > + <Typography>{provider.name}</Typography> + {isConnected ? ( + <Box + onMouseEnter={() => setHoveredProvider(provider.name)} + onMouseLeave={() => setHoveredProvider(null)} + onClick={() => handleDisconnectClick(provider)} + sx={{ + cursor: 'pointer', + border: '1px solid', + borderColor: 'neutral.outlinedBorder', + borderRadius: '4px', + px: 1, + py: 0.5, + fontSize: '0.875rem', + color: 'success.600' + }} + > + {hoveredProvider === provider.name ? 'Disconnect' : 'Connected'} + </Box> + ) : ( + <Button variant="soft" onClick={() => handleConnectClick(provider)}> + Connect + </Button> + )} + </ListItem> + ); + })} + </List> + {/* API Key Modal */} + <Modal open={dialogOpen} onClose={() => setDialogOpen(false)}> + <ModalDialog + layout="stack" + aria-labelledby="connect-dialog-title" + sx={{ + position: 'absolute', + top: '50%', + left: '50%', + transform: 'translate(-50%, -50%)', + maxWidth: 400, + width: '90%' + }} + > + <Typography id="connect-dialog-title" component="h2"> + Connect to {selectedProvider?.name} + </Typography> + {selectedProvider?.name === 'Custom API' ? ( + <> + <FormControl sx={{ mt: 2 }}> + <FormLabel>API Name</FormLabel> + <Input + value={customApiName} + onChange={(e) => setCustomApiName(e.target.value)} + /> + </FormControl> + <FormControl sx={{ mt: 2 }}> + <FormLabel>Base URL</FormLabel> + <Input + value={customBaseURL} + onChange={(e) => setCustomBaseURL(e.target.value)} + /> + </FormControl> + <FormControl sx={{ mt: 2 }}> + <FormLabel>API Key</FormLabel> + <Input + type="password" + value={customApiKey} + onChange={(e) => setCustomApiKey(e.target.value)} + /> + </FormControl> + <FormControl sx={{ mt: 2 }}> + <FormLabel>Model Name</FormLabel> + <Input + value={customModelName} + onChange={(e) => setCustomModelName(e.target.value)} + /> + </FormControl> + </> + ) : ( + <FormControl sx={{ mt: 2 }}> + <FormLabel>{selectedProvider?.name} API Key</FormLabel> + <Input + type="password" + value={apiKey} + onChange={(e) => setApiKey(e.target.value)} + /> + </FormControl> + )} + {/* Conditional help steps */} + {selectedProvider?.name === 'OpenAI' && ( + <Box sx={{ mt: 2 }}> + <Typography level="body2"> + Steps to get an OpenAI API Key: + </Typography> + <ol> + <li> + Visit{' '} + <a + href="https://platform.openai.com/account/api-keys" + target="_blank" + rel="noreferrer" + > + OpenAI API website + </a> + </li> + <li>Log in to your OpenAI account.</li> + <li>Create a new API key and copy it.</li> + </ol> + </Box> + )} + {selectedProvider?.name === 'Anthropic' && ( + <Box sx={{ mt: 2 }}> + <Typography level="body2"> + Steps to get a Anthropic API Key: + </Typography> + <ol> + <li>Visit{' '} + <a + href="https://console.anthropic.com/settings/keys" + target="_blank" + rel="noreferrer" + > + Anthropic API Keys Console + </a></li> + <li>Log in/Create an account.</li> + <li>Follow instructions to generate an API key.</li> + </ol> + </Box> + )} + <Box sx={{ display: 'flex', justifyContent: 'flex-end', gap: 1, mt: 2 }}> + <Button onClick={() => setDialogOpen(false)}>Cancel</Button> + <Button onClick={handleSave}>Save</Button> + </Box> + </ModalDialog> + </Modal> + </Box> + ); +} diff --git a/src/renderer/components/TransformerLabSettings.tsx b/src/renderer/components/Settings/TransformerLabSettings.tsx similarity index 70% rename from src/renderer/components/TransformerLabSettings.tsx rename to src/renderer/components/Settings/TransformerLabSettings.tsx index af116b035759b5a95cf021d9c0d79e48fcdf616e..0e2654a2e5b653fe649f626c1682a45597417b06 100644 --- a/src/renderer/components/TransformerLabSettings.tsx +++ b/src/renderer/components/Settings/TransformerLabSettings.tsx @@ -1,4 +1,3 @@ -/* eslint-disable jsx-a11y/anchor-is-valid */ import * as React from 'react'; import Sheet from '@mui/joy/Sheet'; @@ -22,10 +21,12 @@ import * as chatAPI from 'renderer/lib/transformerlab-api-sdk'; import useSWR from 'swr'; import { EyeIcon, EyeOffIcon, RotateCcwIcon } from 'lucide-react'; -const fetcher = (url) => fetch(url).then((res) => res.json()); +// Import the AIProvidersSettings component. +import AIProvidersSettings from './AIProvidersSettings'; +const fetcher = (url) => fetch(url).then((res) => res.json()); -export default function TransformerLabSettings({ }) { +export default function TransformerLabSettings() { const [showPassword, setShowPassword] = React.useState(false); const { data: hftoken, @@ -37,6 +38,7 @@ export default function TransformerLabSettings({ }) { fetcher ); const [showJobsOfType, setShowJobsOfType] = React.useState('NONE'); + const [showProvidersPage, setShowProvidersPage] = React.useState(false); const { data: jobs, @@ -59,6 +61,18 @@ export default function TransformerLabSettings({ }) { mutate: wandbLoginMutate, } = useSWR(chatAPI.Endpoints.Models.testWandbLogin(), fetcher); + if (showProvidersPage) { + return ( + <AIProvidersSettings + onBack={() => { + setShowProvidersPage(false); + }} + /> + ); + } + + + return ( <> <Typography level="h1" marginBottom={3}> @@ -69,11 +83,10 @@ export default function TransformerLabSettings({ }) { <Typography level="title-lg" marginBottom={2}> Huggingface Credentials: </Typography> - {canLogInToHuggingFace?.message == 'OK' ? ( + {canLogInToHuggingFace?.message === 'OK' ? ( <Alert color="success">Login to Huggingface Successful</Alert> ) : ( <> - {' '} <Alert color="danger" sx={{ mb: 1 }}> Login to Huggingface Failed. Please set credentials below. </Alert> @@ -89,12 +102,8 @@ export default function TransformerLabSettings({ }) { endDecorator={ <IconButton onClick={() => { - var x = document.getElementsByName('hftoken')[0]; - if (x.type === 'text') { - x.type = 'password'; - } else { - x.type = 'text'; - } + const x = document.getElementsByName('hftoken')[0]; + x.type = x.type === 'text' ? 'password' : 'text'; setShowPassword(!showPassword); }} > @@ -106,13 +115,8 @@ export default function TransformerLabSettings({ }) { <Button onClick={async () => { const token = document.getElementsByName('hftoken')[0].value; - await fetch( - chatAPI.Endpoints.Config.Set( - 'HuggingfaceUserAccessToken', - token - ) - ); - // Now manually log in to huggingface + await fetch(chatAPI.Endpoints.Config.Set('HuggingfaceUserAccessToken', token)); + // Now manually log in to Huggingface await fetch(chatAPI.Endpoints.Models.HuggingFaceLogin()); hftokenmutate(token); canLogInToHuggingFaceMutate(); @@ -122,22 +126,23 @@ export default function TransformerLabSettings({ }) { Save </Button> <FormHelperText> - A Huggingface access token is required in order to access - certain models and datasets (those marked as "Gated"). + A Huggingface access token is required in order to access certain + models and datasets (those marked as "Gated"). </FormHelperText> <FormHelperText> - Documentation here: + Documentation here:{' '} <a href="https://huggingface.co/docs/hub/security-tokens" target="_blank" + rel="noreferrer" > https://huggingface.co/docs/hub/security-tokens </a> </FormHelperText> </FormControl> </> - )}{' '} - {wandbLoginStatus?.message === 'OK' ? ( + )} + {wandbLoginStatus?.message === 'OK' ? ( <Alert color="success">Login to Weights & Biases Successful</Alert> ) : ( <FormControl sx={{ maxWidth: '500px', mt: 2 }}> @@ -156,46 +161,16 @@ export default function TransformerLabSettings({ }) { </Button> </FormControl> )} - <FormControl sx={{ maxWidth: '500px', mt: 2 }}> - <FormLabel>OpenAI API Key</FormLabel> - <Input name="openaiKey" type="password" /> - <Button - onClick={async () => { - const token = document.getElementsByName('openaiKey')[0].value; - await fetch(chatAPI.Endpoints.Config.Set('OPENAI_API_KEY', token)); - await fetch(chatAPI.Endpoints.Models.SetOpenAIKey()); - const response = await fetch(chatAPI.Endpoints.Models.CheckOpenAIAPIKey()); - const result = await response.json(); - if (result.message === "OK") { - alert("Successfully set OpenAI API Key"); - } - }} + <Divider sx={{ mt: 2, mb: 2 }} /> + <Typography level="title-lg" marginBottom={2}> + AI Providers & Models: + </Typography> + {/* Clickable list option */} + <Button variant="soft" onClick={() => setShowProvidersPage(true)}> + Set API Keys for AI Providers + </Button> - sx={{ marginTop: 1, width: '100px', alignSelf: 'flex-end' }} - > - Save - </Button> - </FormControl> - <FormControl sx={{ maxWidth: '500px', mt: 2 }}> - <FormLabel>Anthropic API Key</FormLabel> - <Input name="anthropicKey" type="password" /> - <Button - onClick={async () => { - const token = document.getElementsByName('anthropicKey')[0].value; - await fetch(chatAPI.Endpoints.Config.Set('ANTHROPIC_API_KEY', token)); - await fetch(chatAPI.Endpoints.Models.SetAnthropicKey()); - const response = await fetch(chatAPI.Endpoints.Models.CheckAnthropicAPIKey()); - const result = await response.json(); - if (result.message === "OK") { - alert("Successfully set Anthropic API Key"); - } - }} - sx={{ marginTop: 1, width: '100px', alignSelf: 'flex-end' }} - > - Save - </Button> - </FormControl> - <Divider sx={{ mt: 2, mb: 2 }} />{' '} + <Divider sx={{ mt: 2, mb: 2 }} /> <Typography level="title-lg" marginBottom={2}> Application: </Typography> @@ -203,7 +178,7 @@ export default function TransformerLabSettings({ }) { variant="soft" onClick={() => { // find and delete all items in local storage that begin with oneTimePopup: - for (var key in localStorage) { + for (const key in localStorage) { if (key.startsWith('oneTimePopup')) { localStorage.removeItem(key); } @@ -212,7 +187,7 @@ export default function TransformerLabSettings({ }) { > Reset all Tutorial Popup Screens </Button> - <Divider sx={{ mt: 2, mb: 2 }} />{' '} + <Divider sx={{ mt: 2, mb: 2 }} /> <Typography level="title-lg" marginBottom={2}> View Jobs (debug):{' '} <IconButton onClick={() => jobsMutate()}> diff --git a/src/renderer/lib/transformerlab-api-sdk.ts b/src/renderer/lib/transformerlab-api-sdk.ts index 6ec0a82b46cfa540ce391a50f13bfdbf93ea3428..c89db1941a13ef295be56a8b4a0c62059ddd8953 100644 --- a/src/renderer/lib/transformerlab-api-sdk.ts +++ b/src/renderer/lib/transformerlab-api-sdk.ts @@ -1089,6 +1089,7 @@ Endpoints.Models = { modelSource + '&model_id=' + modelId, + ImportFromLocalPath: (modelPath: string) => API_URL() + 'model/import_from_local_path?model_path=' + modelPath, HuggingFaceLogin: () => API_URL() + 'model/login_to_huggingface', @@ -1099,6 +1100,8 @@ Endpoints.Models = { SetAnthropicKey: () => API_URL() + 'model/set_anthropic_api_key', CheckOpenAIAPIKey: () => API_URL() + 'model/check_openai_api_key', CheckAnthropicAPIKey: () => API_URL() + 'model/check_anthropic_api_key', + SetCustomAPIKey: () => API_URL() + 'model/set_custom_api_key', + CheckCustomAPIKey: () => API_URL() + 'model/check_custom_api_key', }; Endpoints.Plugins = {