diff --git a/src/renderer/components/Experiment/DynamicPluginForm.tsx b/src/renderer/components/Experiment/DynamicPluginForm.tsx index bf43e7c4943122f123be668bff6614d5dffb852b..9a098c6a722b807ea0e515f5f7fbc9a15b617053 100644 --- a/src/renderer/components/Experiment/DynamicPluginForm.tsx +++ b/src/renderer/components/Experiment/DynamicPluginForm.tsx @@ -402,87 +402,89 @@ function ModelProviderWidget<T = any, S extends StrictRJSFSchema = RJSFSchema, F schema, multiple, } = props; - // const { enumOptions } = options; - - const _multiple = + // Determine multiple, defaulting to true. + const _multiple = typeof multiple !== 'undefined' ? Boolean(multiple) : typeof options.multiple !== 'undefined' ? Boolean(options.multiple) : true; - + // Disabled API key mapping as before. const isDisabledFilter = true; + const disabledEnvMap = { + claude: "ANTHROPIC_API_KEY", + openai: "OPENAI_API_KEY", + custom: "CUSTOM_API_KEY", + }; + const configKeysInOrder = Object.values(disabledEnvMap); + const configResults = configKeysInOrder.map((key) => + useSWR(chatAPI.Endpoints.Config.Get(key), fetcher) + ); + const configValues = React.useMemo(() => { + const values: Record<string, any> = {}; + configKeysInOrder.forEach((key, idx) => { + values[key] = configResults[idx]?.data; + }); + return values; + }, [configKeysInOrder, configResults]); + + // Define your custom mapping: label => stored value. + const labelToCustomValue: Record<string, string> = { + "Claude 3.7 Sonnet": "claude-3-7-sonnet-latest", + "Claude 3.5 Haiku": "claude-3-5-haiku-latest", + "OpenAI GPT 4o": "gpt-4o", + "OpenAI GPT 4o Mini": "gpt-4o-mini", + "Custom Model API": "custom-model-api", + "Local": "local" + }; + // Options list coming from the keys of your mapping. + const optionsList = Object.keys(labelToCustomValue); + // Inverse mapping: stored value => label. + const customValueToLabel = Object.entries(labelToCustomValue).reduce( + (acc, [label, custom]) => { + acc[custom] = label; + return acc; + }, + {} as Record<string, string> + ); + // Set default/current value. + const defaultValue = _multiple ? [] : ''; + const currentValue = value !== undefined ? value : defaultValue; - const enumOptions = [ - {label: 'Claude 3.5 Haiku', value: 'Claude 3.5 Haiku'}, - {label: 'Claude 3.5 Sonnet', value: 'Claude 3.5 Sonnet'}, - {label: 'OpenAI GPT 4o', value: 'OpenAI GPT 4o'}, - {label: 'OpenAI GPT 4o Mini', value: 'OpenAI GPT 4o Mini'}, - {label: 'Custom Model API', value: 'Custom Model API'}, - {label: 'Local', value: 'Local'} - ] - -// Keys here are the words in lower case with which the model names start and values are the ENV variables associated to those models -const disabledEnvMap = { - "claude": "ANTHROPIC_API_KEY", - "openai": "OPENAI_API_KEY", - "custom": "CUSTOM_API_KEY", -} -const configKeysInOrder = disabledEnvMap ? Object.values(disabledEnvMap) : []; - - -// Determine default value. -const defaultValue = _multiple ? [] : ''; -// Use the provided value or fallback to default. -const currentValue = value !== undefined ? value : defaultValue; - -// For each config key (in order received), call useSWR. -const configResults = configKeysInOrder.map((key) => - useSWR(chatAPI.Endpoints.Config.Get(key), fetcher) -); - -// Build a mapping of config key to its fetched value. -const configValues = React.useMemo(() => { - const values: Record<string, any> = {}; - configKeysInOrder.forEach((key, idx) => { - values[key] = configResults[idx]?.data; - }); - return values; -}, [configKeysInOrder, configResults]); - -// Map enumOptions into string values. -const processedOptionsValues = enumOptions.map((opt) => - typeof opt === 'object' ? opt.value : opt -); - -// Create a dictionary mapping each option to its disabled flag and message. -const combinedOptions = processedOptionsValues.reduce( - (acc: Record<string, { disabled: boolean; info?: string }>, opt) => { - const lower = opt.toLowerCase(); - let disabled = false; - let infoMessage = ""; - if (isDisabledFilter) { - // Loop through disabledEnvMap in insertion order. - for (const envKey in disabledEnvMap) { - if (lower.startsWith(envKey)) { - const configKey = disabledEnvMap[envKey]; - const configVal = configValues[configKey]; - disabled = !configVal || configVal === ''; - if (disabled) { - infoMessage = `Please set ${configKey} in settings`; + // Convert stored value(s) to display (labels). + const displayValue = _multiple + ? Array.isArray(currentValue) + ? currentValue.map(val => customValueToLabel[val] || val) + : [] + : (customValueToLabel[currentValue] || currentValue); + + // Build disabled mapping for options. + const combinedOptions = optionsList.reduce( + (acc: Record<string, { disabled: boolean; info?: string }>, opt) => { + const lower = opt.toLowerCase(); + let optDisabled = false; + let infoMessage = ""; + if (isDisabledFilter) { + for (const envKey in disabledEnvMap) { + if (lower.startsWith(envKey)) { + const configKey = disabledEnvMap[envKey]; + const configVal = configValues[configKey]; + optDisabled = !configVal || configVal === ''; + if (optDisabled) { + infoMessage = `Please set ${configKey} in settings`; + } + break; } - break; } } - } - acc[opt] = { disabled, info: disabled ? infoMessage : "" }; - return acc; - }, - {} -); + acc[opt] = { disabled: optDisabled, info: optDisabled ? infoMessage : "" }; + return acc; + }, + {} + ); return ( <> @@ -490,19 +492,23 @@ const combinedOptions = processedOptionsValues.reduce( multiple={_multiple} id={id} placeholder={schema.title || ''} - options={processedOptionsValues} + options={optionsList} getOptionLabel={(option) => - option + (combinedOptions[option]?.disabled ? " - " + combinedOptions[option].info : "") + option + + (combinedOptions[option]?.disabled ? " - " + combinedOptions[option].info : "") } getOptionDisabled={(option) => combinedOptions[option]?.disabled ?? false} - value={currentValue} + value={displayValue} onChange={(event, newValue) => { - onChange(newValue); + const storedValue = _multiple + ? newValue.map(label => labelToCustomValue[label] || label) + : (labelToCustomValue[newValue] || newValue); + onChange(storedValue); }} disabled={disabled || readonly} autoFocus={autofocus} /> - {/* Hidden input to capture the value on form submission */} + {/* Hidden input to capture the stored value on form submission */} <input type="hidden" name={id} @@ -622,7 +628,7 @@ const widgets: RegistryWidgetsType = { RangeWidget: CustomRange, SelectWidget: CustomSelectSimple, AutoCompleteWidget: CustomAutocompleteWidget, - ModelProviderWidget: ModelProviderWidget, + ModelProviderWidget: ModelProviderWidget }; const fetcher = (url) => fetch(url).then((res) => res.json());