Unverified commit c22c50cc authored by Timothy Carambat, committed by GitHub

Enable chat streaming for LLMs (#354)

* [Draft] Enable chat streaming for LLMs

* stream only, move sendChat to deprecated

* Update TODO deprecation comments
update console output color for streaming disabled
parent fa29003a
Showing changed files with 618 additions and 26 deletions
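At a glance, the change swaps the single awaited chat response for a server-sent event stream: the frontend posts to a new /workspace/:slug/stream-chat endpoint, the server writes one JSON chunk per token (or per abort/close), and handleChat on the client merges those chunks into chat history. A minimal sketch of the chunk shape, with hypothetical values; the field names follow writeResponseChunk and handleChat in the diff below:

// Hypothetical example of one streamed chunk.
const exampleChunk = {
  uuid: "8b6c1a2e",            // stable id so chunks merge into a single reply
  type: "textResponseChunk",   // also "textResponse" (non-streaming) or "abort"
  textResponse: "Hel",         // the text carried by this chunk
  sources: [],                 // citation sources, populated on the closing chunk
  close: false,                // true on the final chunk of a reply
  error: false,                // false, or an error message string
};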
......@@ -12,6 +12,7 @@
"dependencies": {
"@esbuild-plugins/node-globals-polyfill": "^0.1.1",
"@metamask/jazzicon": "^2.0.0",
"@microsoft/fetch-event-source": "^2.0.1",
"@phosphor-icons/react": "^2.0.13",
"buffer": "^6.0.3",
"he": "^1.2.0",
......@@ -46,4 +47,4 @@
"tailwindcss": "^3.3.1",
"vite": "^4.3.0"
}
}
\ No newline at end of file
}
......@@ -72,7 +72,7 @@ const PromptReply = forwardRef(
role="assistant"
/>
<span
className={`whitespace-pre-line text-white font-normal text-sm md:text-sm flex flex-col gap-y-1 mt-2`}
className={`reply whitespace-pre-line text-white font-normal text-sm md:text-sm flex flex-col gap-y-1 mt-2`}
dangerouslySetInnerHTML={{ __html: renderMarkdown(reply) }}
/>
</div>
......
......@@ -53,8 +53,10 @@ export default function ChatHistory({ history = [], workspace }) {
>
{history.map((props, index) => {
const isLastMessage = index === history.length - 1;
const isLastBotReply =
index === history.length - 1 && props.role === "assistant";
if (props.role === "assistant" && props.animate) {
if (isLastBotReply && props.animate) {
return (
<PromptReply
key={props.uuid}
......
......@@ -48,19 +48,36 @@ export default function ChatContainer({ workspace, knownHistory = [] }) {
return false;
}
const chatResult = await Workspace.sendChat(
// TODO: Delete this snippet once we have streaming stable.
// const chatResult = await Workspace.sendChat(
// workspace,
// promptMessage.userMessage,
// window.localStorage.getItem(`workspace_chat_mode_${workspace.slug}`) ??
// "chat",
// )
// handleChat(
// chatResult,
// setLoadingResponse,
// setChatHistory,
// remHistory,
// _chatHistory
// )
await Workspace.streamChat(
workspace,
promptMessage.userMessage,
window.localStorage.getItem(`workspace_chat_mode_${workspace.slug}`) ??
"chat"
);
handleChat(
chatResult,
setLoadingResponse,
setChatHistory,
remHistory,
_chatHistory
"chat",
(chatResult) =>
handleChat(
chatResult,
setLoadingResponse,
setChatHistory,
remHistory,
_chatHistory
)
);
return;
}
loadingResponse === true && fetchReply();
}, [loadingResponse, chatHistory, workspace]);
......
......@@ -358,3 +358,24 @@ dialog::backdrop {
.user-reply > div:first-of-type {
border: 2px solid white;
}
.reply > *:last-child::after {
content: "|";
animation: blink 1.5s steps(1) infinite;
color: white;
font-size: 14px;
}
@keyframes blink {
0% {
opacity: 0;
}
50% {
opacity: 1;
}
100% {
opacity: 0;
}
}
import { API_BASE } from "../utils/constants";
import { baseHeaders } from "../utils/request";
import { fetchEventSource } from "@microsoft/fetch-event-source";
import { v4 } from "uuid";
const Workspace = {
new: async function (data = {}) {
......@@ -57,19 +59,44 @@ const Workspace = {
.catch(() => []);
return history;
},
sendChat: async function ({ slug }, message, mode = "query") {
const chatResult = await fetch(`${API_BASE}/workspace/${slug}/chat`, {
streamChat: async function ({ slug }, message, mode = "query", handleChat) {
const ctrl = new AbortController();
await fetchEventSource(`${API_BASE}/workspace/${slug}/stream-chat`, {
method: "POST",
body: JSON.stringify({ message, mode }),
headers: baseHeaders(),
})
.then((res) => res.json())
.catch((e) => {
console.error(e);
return null;
});
return chatResult;
signal: ctrl.signal,
async onopen(response) {
if (response.ok) {
return; // everything's good
} else if (
response.status >= 400 &&
response.status < 500 &&
response.status !== 429
) {
throw new Error("Invalid Status code response.");
} else {
throw new Error("Unknown error");
}
},
async onmessage(msg) {
try {
const chatResult = JSON.parse(msg.data);
handleChat(chatResult);
} catch {}
},
onerror(err) {
handleChat({
id: v4(),
type: "abort",
textResponse: null,
sources: [],
close: true,
error: `An error occurred while streaming response. ${err.message}`,
});
ctrl.abort();
},
});
},
all: async function () {
const workspaces = await fetch(`${API_BASE}/workspaces`, {
......@@ -111,6 +138,22 @@ const Workspace = {
const data = await response.json();
return { response, data };
},
// TODO: Deprecated and should be removed from frontend.
sendChat: async function ({ slug }, message, mode = "query") {
const chatResult = await fetch(`${API_BASE}/workspace/${slug}/chat`, {
method: "POST",
body: JSON.stringify({ message, mode }),
headers: baseHeaders(),
})
.then((res) => res.json())
.catch((e) => {
console.error(e);
return null;
});
return chatResult;
},
};
export default Workspace;
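For reference, a minimal sketch of how Workspace.streamChat is meant to be consumed, mirroring the ChatContainer change above; the workspace object and callback body are hypothetical:

// Hypothetical usage: the promise resolves when the stream closes,
// while the callback fires once per received chunk.
const workspace = { slug: "my-workspace" };
await Workspace.streamChat(workspace, "Hello there", "chat", (chatResult) => {
  // chatResult is one chunk: { uuid, type, textResponse, sources, close, error }
  console.log(chatResult.type, chatResult.textResponse);
});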
......@@ -19,7 +19,8 @@ export default function handleChat(
sources,
closed: true,
error,
animate: true,
animate: false,
pending: false,
},
]);
_chatHistory.push({
......@@ -29,7 +30,8 @@ export default function handleChat(
sources,
closed: true,
error,
animate: true,
animate: false,
pending: false,
});
} else if (type === "textResponse") {
setLoadingResponse(false);
......@@ -42,7 +44,8 @@ export default function handleChat(
sources,
closed: close,
error,
animate: true,
animate: !close,
pending: false,
},
]);
_chatHistory.push({
......@@ -52,8 +55,36 @@ export default function handleChat(
sources,
closed: close,
error,
animate: true,
animate: !close,
pending: false,
});
} else if (type === "textResponseChunk") {
const chatIdx = _chatHistory.findIndex((chat) => chat.uuid === uuid);
if (chatIdx !== -1) {
const existingHistory = { ..._chatHistory[chatIdx] };
const updatedHistory = {
...existingHistory,
content: existingHistory.content + textResponse,
sources,
error,
closed: close,
animate: !close,
pending: false,
};
_chatHistory[chatIdx] = updatedHistory;
} else {
_chatHistory.push({
uuid,
sources,
error,
content: textResponse,
role: "assistant",
closed: close,
animate: !close,
pending: false,
});
}
setChatHistory([..._chatHistory]);
}
}
......
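A streamed reply therefore reaches handleChat as a series of textResponseChunk events sharing one uuid: each event's textResponse is appended to the matching history entry, and the closing event turns the animate flag off. A hypothetical sequence and the entry it produces:

// Hypothetical chunk sequence for one assistant reply.
const chunks = [
  { uuid: "abc", type: "textResponseChunk", textResponse: "Hel", sources: [], close: false, error: false },
  { uuid: "abc", type: "textResponseChunk", textResponse: "lo!", sources: [], close: false, error: false },
  { uuid: "abc", type: "textResponseChunk", textResponse: "", sources: [], close: true, error: false },
];
// After all three pass through handleChat, _chatHistory holds a single entry:
// { uuid: "abc", role: "assistant", content: "Hello!", closed: true, animate: false, pending: false, ... }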
......@@ -426,6 +426,11 @@
color "^0.11.3"
mersenne-twister "^1.1.0"
"@microsoft/fetch-event-source@^2.0.1":
version "2.0.1"
resolved "https://registry.yarnpkg.com/@microsoft/fetch-event-source/-/fetch-event-source-2.0.1.tgz#9ceecc94b49fbaa15666e38ae8587f64acce007d"
integrity sha512-W6CLUJ2eBMw3Rec70qrsEW0jOm/3twwJv21mrmj2yORiaVmVYGS4sSS5yUwvQc1ZlDLYGPnClVWmUUMagKNsfA==
"@nodelib/fs.scandir@2.1.5":
version "2.1.5"
resolved "https://registry.npmjs.org/@nodelib/fs.scandir/-/fs.scandir-2.1.5.tgz"
......
......@@ -6,10 +6,95 @@ const { validatedRequest } = require("../utils/middleware/validatedRequest");
const { WorkspaceChats } = require("../models/workspaceChats");
const { SystemSettings } = require("../models/systemSettings");
const { Telemetry } = require("../models/telemetry");
const {
streamChatWithWorkspace,
writeResponseChunk,
} = require("../utils/chats/stream");
function chatEndpoints(app) {
if (!app) return;
app.post(
"/workspace/:slug/stream-chat",
[validatedRequest],
async (request, response) => {
try {
const user = await userFromSession(request, response);
const { slug } = request.params;
const { message, mode = "query" } = reqBody(request);
const workspace = multiUserMode(response)
? await Workspace.getWithUser(user, { slug })
: await Workspace.get({ slug });
if (!workspace) {
response.sendStatus(400).end();
return;
}
response.setHeader("Cache-Control", "no-cache");
response.setHeader("Content-Type", "text/event-stream");
response.setHeader("Access-Control-Allow-Origin", "*");
response.setHeader("Connection", "keep-alive");
response.flushHeaders();
if (multiUserMode(response) && user.role !== "admin") {
const limitMessagesSetting = await SystemSettings.get({
label: "limit_user_messages",
});
const limitMessages = limitMessagesSetting?.value === "true";
if (limitMessages) {
const messageLimitSetting = await SystemSettings.get({
label: "message_limit",
});
const systemLimit = Number(messageLimitSetting?.value);
if (!!systemLimit) {
const currentChatCount = await WorkspaceChats.count({
user_id: user.id,
createdAt: {
gte: new Date(new Date() - 24 * 60 * 60 * 1000),
},
});
if (currentChatCount >= systemLimit) {
writeResponseChunk(response, {
id: uuidv4(),
type: "abort",
textResponse: null,
sources: [],
close: true,
error: `You have met your maximum 24 hour chat quota of ${systemLimit} chats set by the instance administrators. Try again later.`,
});
return;
}
}
}
}
await streamChatWithWorkspace(response, workspace, message, mode, user);
await Telemetry.sendTelemetry("sent_chat", {
multiUserMode: multiUserMode(response),
LLMSelection: process.env.LLM_PROVIDER || "openai",
VectorDbSelection: process.env.VECTOR_DB || "pinecone",
});
response.end();
} catch (e) {
console.error(e);
writeResponseChunk(response, {
id: uuidv4(),
type: "abort",
textResponse: null,
sources: [],
close: true,
error: e.message,
});
response.end();
}
}
);
app.post(
"/workspace/:slug/chat",
[validatedRequest],
......
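Because the new endpoint emits plain server-sent events, it can also be exercised without @microsoft/fetch-event-source. A minimal sketch using the standard fetch API; the base URL and auth header are placeholders that would need to match the deployment:

// Hypothetical raw consumer for POST /workspace/:slug/stream-chat.
const API_BASE = "http://localhost:3001/api"; // placeholder; match the frontend's API_BASE constant
async function streamOnce(slug, message) {
  const res = await fetch(`${API_BASE}/workspace/${slug}/stream-chat`, {
    method: "POST",
    headers: { "Content-Type": "application/json" /*, Authorization: "Bearer <token>" */ },
    body: JSON.stringify({ message, mode: "chat" }),
  });
  const reader = res.body.getReader();
  const decoder = new TextDecoder();
  while (true) {
    const { value, done } = await reader.read();
    if (done) break;
    // Each frame looks like: data: {"uuid":"...","type":"textResponseChunk",...}\n\n
    console.log(decoder.decode(value));
  }
}
streamOnce("my-workspace", "Hello");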
......@@ -27,6 +27,10 @@ class AnthropicLLM {
this.answerKey = v4().split("-")[0];
}
streamingEnabled() {
return "streamChat" in this && "streamGetChatCompletion" in this;
}
promptWindowLimit() {
switch (this.model) {
case "claude-instant-1":
......
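streamingEnabled() is a duck-typed capability check: it returns true only when the connector also defines streamChat and streamGetChatCompletion. In this diff only the LMStudio and OpenAI connectors add those methods, so the Anthropic and Azure connectors keep reporting false and take the single-chunk fallback in utils/chats/stream.js. A standalone sketch of the pattern, with hypothetical class names:

// Hypothetical illustration of the capability check shared by the connectors.
class BaseConnector {
  streamingEnabled() {
    return "streamChat" in this && "streamGetChatCompletion" in this;
  }
}
class StreamingConnector extends BaseConnector {
  async streamChat() {/* ... */}
  async streamGetChatCompletion() {/* ... */}
}
console.log(new BaseConnector().streamingEnabled());      // false -> single-chunk fallback
console.log(new StreamingConnector().streamingEnabled()); // true  -> token streaming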
......@@ -22,6 +22,10 @@ class AzureOpenAiLLM extends AzureOpenAiEmbedder {
};
}
streamingEnabled() {
return "streamChat" in this && "streamGetChatCompletion" in this;
}
// Ensure the user selected a proper value for the token limit
// could be any of these https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models#gpt-4-models
// and if undefined - assume it is the lowest end.
......
......@@ -27,6 +27,10 @@ class LMStudioLLM {
this.embedder = embedder;
}
streamingEnabled() {
return "streamChat" in this && "streamGetChatCompletion" in this;
}
// Ensure the user set a value for the token limit
// and if undefined - assume 4096 window.
promptWindowLimit() {
......@@ -103,6 +107,32 @@ Context:
return textResponse;
}
async streamChat(chatHistory = [], prompt, workspace = {}, rawHistory = []) {
if (!this.model)
throw new Error(
`LMStudio chat: ${this.model} is not valid or defined for chat completion!`
);
const streamRequest = await this.lmstudio.createChatCompletion(
{
model: this.model,
temperature: Number(workspace?.openAiTemp ?? 0.7),
n: 1,
stream: true,
messages: await this.compressMessages(
{
systemPrompt: chatPrompt(workspace),
userPrompt: prompt,
chatHistory,
},
rawHistory
),
},
{ responseType: "stream" }
);
return streamRequest;
}
async getChatCompletion(messages = null, { temperature = 0.7 }) {
if (!this.model)
throw new Error(
......@@ -119,6 +149,24 @@ Context:
return data.choices[0].message.content;
}
async streamGetChatCompletion(messages = null, { temperature = 0.7 }) {
if (!this.model)
throw new Error(
`LMStudio chat: ${this.model} is not valid or defined model for chat completion!`
);
const streamRequest = await this.lmstudio.createChatCompletion(
{
model: this.model,
stream: true,
messages,
temperature,
},
{ responseType: "stream" }
);
return streamRequest;
}
// Simple wrapper for dynamic embedder & normalize interface for all LLM implementations
async embedTextInput(textInput) {
return await this.embedder.embedTextInput(textInput);
......
......@@ -19,6 +19,10 @@ class OpenAiLLM extends OpenAiEmbedder {
};
}
streamingEnabled() {
return "streamChat" in this && "streamGetChatCompletion" in this;
}
promptWindowLimit() {
switch (this.model) {
case "gpt-3.5-turbo":
......@@ -140,6 +144,33 @@ Context:
return textResponse;
}
async streamChat(chatHistory = [], prompt, workspace = {}, rawHistory = []) {
const model = process.env.OPEN_MODEL_PREF;
if (!(await this.isValidChatCompletionModel(model)))
throw new Error(
`OpenAI chat: ${model} is not valid for chat completion!`
);
const streamRequest = await this.openai.createChatCompletion(
{
model,
stream: true,
temperature: Number(workspace?.openAiTemp ?? 0.7),
n: 1,
messages: await this.compressMessages(
{
systemPrompt: chatPrompt(workspace),
userPrompt: prompt,
chatHistory,
},
rawHistory
),
},
{ responseType: "stream" }
);
return streamRequest;
}
async getChatCompletion(messages = null, { temperature = 0.7 }) {
if (!(await this.isValidChatCompletionModel(this.model)))
throw new Error(
......@@ -156,6 +187,24 @@ Context:
return data.choices[0].message.content;
}
async streamGetChatCompletion(messages = null, { temperature = 0.7 }) {
if (!(await this.isValidChatCompletionModel(this.model)))
throw new Error(
`OpenAI chat: ${this.model} is not valid for chat completion!`
);
const streamRequest = await this.openai.createChatCompletion(
{
model: this.model,
stream: true,
messages,
temperature,
},
{ responseType: "stream" }
);
return streamRequest;
}
async compressMessages(promptArgs = {}, rawHistory = []) {
const { messageArrayCompressor } = require("../../helpers/chat");
const messageArray = this.constructPrompt(promptArgs);
......
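Both connectors pass { responseType: "stream" } to the OpenAI SDK's createChatCompletion, so the call resolves to an axios response whose .data is a readable stream of raw "data: {...}" lines; handleStreamResponses in utils/chats/stream.js parses those lines back into tokens. A hypothetical sketch of reading such a stream directly, assuming an already-constructed connector (llm) and message array:

// Hypothetical direct consumption of the stream returned above.
const stream = await llm.streamGetChatCompletion(messages, { temperature: 0.7 });
stream.data.on("data", (buf) => {
  // Each buffer carries one or more provider SSE lines, e.g.
  //   data: {"choices":[{"delta":{"content":"Hi"},"finish_reason":null}]}
  console.log(buf.toString());
});
stream.data.on("end", () => console.log("[stream ended]"));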
......@@ -242,8 +242,11 @@ function chatPrompt(workspace) {
}
module.exports = {
recentChatHistory,
convertToPromptHistory,
convertToChatHistory,
chatWithWorkspace,
chatPrompt,
grepCommand,
VALID_COMMANDS,
};
const { v4: uuidv4 } = require("uuid");
const { WorkspaceChats } = require("../../models/workspaceChats");
const { getVectorDbClass, getLLMProvider } = require("../helpers");
const {
grepCommand,
recentChatHistory,
VALID_COMMANDS,
chatPrompt,
} = require(".");
function writeResponseChunk(response, data) {
response.write(`data: ${JSON.stringify(data)}\n\n`);
return;
}
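// One emitted frame on the wire then looks like (hypothetical values):
//   data: {"uuid":"8b6c1a2e","sources":[],"type":"textResponseChunk","textResponse":"Hel","close":false,"error":false}\n\n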
async function streamChatWithWorkspace(
response,
workspace,
message,
chatMode = "chat",
user = null
) {
const uuid = uuidv4();
const command = grepCommand(message);
if (!!command && Object.keys(VALID_COMMANDS).includes(command)) {
const data = await VALID_COMMANDS[command](workspace, message, uuid, user);
writeResponseChunk(response, data);
return;
}
const LLMConnector = getLLMProvider();
const VectorDb = getVectorDbClass();
const { safe, reasons = [] } = await LLMConnector.isSafe(message);
if (!safe) {
writeResponseChunk(response, {
id: uuid,
type: "abort",
textResponse: null,
sources: [],
close: true,
error: `This message was moderated and will not be allowed. Violations for ${reasons.join(
", "
)} found.`,
});
return;
}
const messageLimit = workspace?.openAiHistory || 20;
const hasVectorizedSpace = await VectorDb.hasNamespace(workspace.slug);
const embeddingsCount = await VectorDb.namespaceCount(workspace.slug);
if (!hasVectorizedSpace || embeddingsCount === 0) {
// If there are no embeddings - chat like a normal LLM chat interface.
return await streamEmptyEmbeddingChat({
response,
uuid,
user,
message,
workspace,
messageLimit,
LLMConnector,
});
}
let completeText;
const { rawHistory, chatHistory } = await recentChatHistory(
user,
workspace,
messageLimit,
chatMode
);
const {
contextTexts = [],
sources = [],
message: error,
} = await VectorDb.performSimilaritySearch({
namespace: workspace.slug,
input: message,
LLMConnector,
similarityThreshold: workspace?.similarityThreshold,
});
// Failed similarity search.
if (!!error) {
writeResponseChunk(response, {
id: uuid,
type: "abort",
textResponse: null,
sources: [],
close: true,
error,
});
return;
}
// Compress message to ensure prompt passes token limit with room for response
// and build system messages based on inputs and history.
const messages = await LLMConnector.compressMessages(
{
systemPrompt: chatPrompt(workspace),
userPrompt: message,
contextTexts,
chatHistory,
},
rawHistory
);
// If streaming is not explicitly enabled for the connector,
// await the full response and send it back as a single chunk.
if (LLMConnector.streamingEnabled() !== true) {
console.log(
`\x1b[31m[STREAMING DISABLED]\x1b[0m Streaming is not available for ${LLMConnector.constructor.name}. Will use regular chat method.`
);
completeText = await LLMConnector.getChatCompletion(messages, {
temperature: workspace?.openAiTemp ?? 0.7,
});
writeResponseChunk(response, {
uuid,
sources,
type: "textResponseChunk",
textResponse: completeText,
close: true,
error: false,
});
} else {
const stream = await LLMConnector.streamGetChatCompletion(messages, {
temperature: workspace?.openAiTemp ?? 0.7,
});
completeText = await handleStreamResponses(response, stream, {
uuid,
sources,
});
}
await WorkspaceChats.new({
workspaceId: workspace.id,
prompt: message,
response: { text: completeText, sources, type: chatMode },
user,
});
return;
}
async function streamEmptyEmbeddingChat({
response,
uuid,
user,
message,
workspace,
messageLimit,
LLMConnector,
}) {
let completeText;
const { rawHistory, chatHistory } = await recentChatHistory(
user,
workspace,
messageLimit
);
// If streaming is not explicitly enabled for the connector,
// await the full response and send it back as a single chunk.
if (LLMConnector.streamingEnabled() !== true) {
console.log(
`\x1b[31m[STREAMING DISABLED]\x1b[0m Streaming is not available for ${LLMConnector.constructor.name}. Will use regular chat method.`
);
completeText = await LLMConnector.sendChat(
chatHistory,
message,
workspace,
rawHistory
);
writeResponseChunk(response, {
uuid,
type: "textResponseChunk",
textResponse: completeText,
sources: [],
close: true,
error: false,
});
} else {
const stream = await LLMConnector.streamChat(
chatHistory,
message,
workspace,
rawHistory
);
completeText = await handleStreamResponses(response, stream, {
uuid,
sources: [],
});
}
await WorkspaceChats.new({
workspaceId: workspace.id,
prompt: message,
response: { text: completeText, sources: [], type: "chat" },
user,
});
return;
}
function handleStreamResponses(response, stream, responseProps) {
const { uuid = uuidv4(), sources = [] } = responseProps;
return new Promise((resolve) => {
let fullText = "";
let chunk = "";
stream.data.on("data", (data) => {
const lines = data
?.toString()
?.split("\n")
.filter((line) => line.trim() !== "");
for (const line of lines) {
const message = chunk + line.replace(/^data: /, "");
// JSON chunk is incomplete and has not ended yet
// so we need to stitch it together. You would think JSON
// chunks would only come complete - but they don't!
if (message.slice(-3) !== "}]}") {
chunk = message;
continue;
} else {
chunk = "";
}
if (message == "[DONE]") {
writeResponseChunk(response, {
uuid,
sources,
type: "textResponseChunk",
textResponse: "",
close: true,
error: false,
});
resolve(fullText);
} else {
let finishReason;
let token = "";
try {
const json = JSON.parse(message);
token = json?.choices?.[0]?.delta?.content;
finishReason = json?.choices?.[0]?.finish_reason;
} catch {
continue;
}
if (token) {
fullText += token;
writeResponseChunk(response, {
uuid,
sources: [],
type: "textResponseChunk",
textResponse: token,
close: false,
error: false,
});
}
if (finishReason !== null) {
writeResponseChunk(response, {
uuid,
sources,
type: "textResponseChunk",
textResponse: "",
close: true,
error: false,
});
resolve(fullText);
}
}
}
});
});
}
module.exports = {
streamChatWithWorkspace,
writeResponseChunk,
};
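handleStreamResponses buffers incomplete lines in chunk because a single provider payload can be split across multiple data events; only once the re-joined string ends in "}]}" is it handed to JSON.parse. A hypothetical illustration of a split payload being stitched back together:

// Hypothetical split of one OpenAI-style chunk across two "data" events.
const firstEvent = 'data: {"choices":[{"delta":{"content":"Hel';  // no "}]}" suffix -> buffered
const secondEvent = 'lo"},"finish_reason":null}]}';               // completes the object
const stitched = firstEvent.replace(/^data: /, "") + secondEvent;
console.log(JSON.parse(stitched).choices[0].delta.content); // "Hello"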