diff --git a/src/renderer/components/Experiment/Interact/Batched.tsx b/src/renderer/components/Experiment/Interact/Batched.tsx index 9ab6d5eb793db67a5f0605234f50f1181dab2f53..aa70a56a07d2c8561fbd9e200f421697d4a14ed0 100644 --- a/src/renderer/components/Experiment/Interact/Batched.tsx +++ b/src/renderer/components/Experiment/Interact/Batched.tsx @@ -19,11 +19,10 @@ import { Alert, ListDivider, ListItemContent, + LinearProgress, } from '@mui/joy'; import * as chatAPI from '../../../lib/transformerlab-api-sdk'; -import { useEffect, useState } from 'react'; -import { useDebounce } from 'use-debounce'; -import useSWR from 'swr'; +import { useState } from 'react'; import { ConstructionIcon, FileStackIcon, @@ -37,27 +36,63 @@ import MainGenerationConfigKnobs from './MainGenerationConfigKnobs'; const fetcher = (url) => fetch(url).then((res) => res.json()); export default function Batched({ - chats, - setChats, - experimentInfo, - isThinking, - sendNewMessageToLLM, - stopStreaming, - experimentInfoMutate, tokenCount, - text, - debouncedText, defaultPromptConfigForModel = {}, - enableTools = false, - currentModelArchitecture, generationParameters, setGenerationParameters, - conversations, - conversationsIsLoading, - conversationsMutate, - setConversationId, - conversationId, + sendCompletionToLLM, + experimentInfo, }) { + const [isThinking, setIsThinking] = useState(false); + async function sendBatchOfQueries(key) { + const text = batchedQueriesList.find((query) => query.key === key).prompts; + + const currentModel = experimentInfo?.config?.foundation; + const adaptor = experimentInfo?.config?.adaptor; + + setIsThinking(true); + + var inferenceParams = ''; + + if (experimentInfo?.config?.inferenceParams) { + inferenceParams = experimentInfo?.config?.inferenceParams; + inferenceParams = JSON.parse(inferenceParams); + } + + const generationParamsJSON = experimentInfo?.config?.generationParams; + const generationParameters = JSON.parse(generationParamsJSON); + + try { + generationParameters.stop_str = JSON.parse( + generationParameters?.stop_str + ); + } catch (e) { + console.log('Error parsing stop strings as JSON'); + } + + const targetElement = document.getElementById('completion-textarea'); + targetElement.value = ''; + + const result = await chatAPI.sendBatchedCompletion( + currentModel, + adaptor, + text, + generationParameters?.temperature, + generationParameters?.maxTokens, + generationParameters?.topP, + false, + generationParameters?.stop_str + ); + + setIsThinking(false); + + console.log('result', result); + + targetElement.value = JSON.stringify(result?.choices, null, 2); + + return result?.text; + } + return ( <Sheet sx={{ @@ -94,7 +129,6 @@ export default function Batched({ /> </FormControl>{' '} <Typography - id="decorated-list-demo" level="body-xs" sx={{ textTransform: 'uppercase', fontWeight: 'lg', mb: 1 }} > @@ -104,12 +138,12 @@ export default function Batched({ sx={{ display: 'flex', border: '1px solid #ccc', - padding: 2, + padding: 1, flexDirection: 'column', height: '100%', }} > - <ListOfBatchedQueries /> + <ListOfBatchedQueries sendBatchOfQueries={sendBatchOfQueries} /> </Box> </Box> <Sheet @@ -147,9 +181,15 @@ export default function Batched({ gap: 1, }} > - <FormControl> + <FormControl sx={{ height: '100%' }}> <FormLabel>Result:</FormLabel> - <Textarea minRows={20} /> + {isThinking && <LinearProgress />} + <textarea + name="completion-textarea" + id="completion-textarea" + rows={20} + style={{ overflow: 'auto', height: '100%' }} + /> </FormControl> </Sheet> </Sheet> @@ -223,7 +263,7 @@ const batchedQueriesList = [ }, ]; -function ListOfBatchedQueries({}) { +function ListOfBatchedQueries({ sendBatchOfQueries }) { const [newQueryModalOpen, setNewQueryModalOpen] = useState(false); return ( @@ -238,7 +278,10 @@ function ListOfBatchedQueries({}) { <FileStackIcon /> </ListItemDecorator> <ListItemContent>{query.name}</ListItemContent> - <PlayIcon size="20px" onClick={() => alert('hi')} /> + <PlayIcon + size="20px" + onClick={() => sendBatchOfQueries(query.key)} + /> <PencilIcon size="20px" /> </ListItem> ))} diff --git a/src/renderer/components/Experiment/Interact/Interact.tsx b/src/renderer/components/Experiment/Interact/Interact.tsx index 1c8f2c9f44d5de23803402f09854e8a7521305c8..6d740d0e7dbe4ef1a4ad8e7b636e282b880c49c9 100644 --- a/src/renderer/components/Experiment/Interact/Interact.tsx +++ b/src/renderer/components/Experiment/Interact/Interact.tsx @@ -863,26 +863,12 @@ export default function Chat({ )} {mode == 'batched' && ( <Batched - key={conversationId} - chats={chats} - setChats={setChats} - experimentInfo={experimentInfo} - isThinking={isThinking} - sendNewMessageToLLM={sendNewMessageToLLM} - stopStreaming={stopStreaming} - experimentInfoMutate={experimentInfoMutate} tokenCount={tokenCount} - text={textToDebounce} - debouncedText={debouncedText} defaultPromptConfigForModel={defaultPromptConfigForModel} - currentModelArchitecture={currentModelArchitecture} generationParameters={generationParameters} setGenerationParameters={setGenerationParameters} - conversations={conversations} - conversationsIsLoading={conversationsIsLoading} - conversationsMutate={conversationsMutate} - setConversationId={setConversationId} - conversationId={conversationId} + sendCompletionToLLM={sendCompletionToLLM} + experimentInfo={experimentInfo} /> )} </Sheet> diff --git a/src/renderer/lib/transformerlab-api-sdk.ts b/src/renderer/lib/transformerlab-api-sdk.ts index 16b345b3ee591d4e50d1c3522e532bce4d88ee64..b899a9b379311a547270af3084bb18dc11f042a2 100644 --- a/src/renderer/lib/transformerlab-api-sdk.ts +++ b/src/renderer/lib/transformerlab-api-sdk.ts @@ -416,6 +416,62 @@ export async function sendCompletion( return null; } +export async function sendBatchedCompletion( + currentModel: string, + adaptor: string, + text: string[], + temperature: number = 0.7, + maxTokens: number = 256, + topP: number = 1.0, + useLongModelName = true, + stopString = null +) { + let model = ''; + if (useLongModelName) { + model = currentModel; + } else { + model = currentModel.split('/').slice(-1)[0]; + } + + if (adaptor && adaptor !== '') { + model = adaptor; + } + + const data = { + model: model, + stream: false, + prompt: text, + temperature, + max_tokens: maxTokens, + top_p: topP, + }; + + if (stopString) { + data.stop = stopString; + } + + let result; + var id = Math.random() * 1000; + + let response; + try { + response = await fetch(`${INFERENCE_SERVER_URL()}v1/completions`, { + method: 'POST', // or 'PUT' + headers: { + 'Content-Type': 'application/json', + accept: 'application/json', + }, + body: JSON.stringify(data), + }); + } catch (error) { + console.log('There was an error', error); + return null; + } + + result = await response.json(); + return result; +} + export async function callTool( function_name: String, function_args: Object = {}