Skip to content
Snippets Groups Projects
Unverified Commit 108634b9 authored by sweep-ai[bot]'s avatar sweep-ai[bot] Committed by GitHub
Browse files

Merge main into sweep/add-docstrings

parents 7678d319 4bb92be1
No related branches found
No related tags found
No related merge requests found
---
"llamaindex": patch
---
Added Meta strategy for Llama2
import { LlamaDeuce } from "llamaindex";
import { DeuceChatStrategy, LlamaDeuce } from "llamaindex";
(async () => {
const deuce = new LlamaDeuce();
const deuce = new LlamaDeuce({ chatStrategy: DeuceChatStrategy.META });
const result = await deuce.chat([{ content: "Hello, world!", role: "user" }]);
console.log(result);
})();
......@@ -177,23 +177,56 @@ export const ALL_AVAILABLE_LLAMADEUCE_MODELS = {
},
};
/**
 * Selects how LlamaDeuce turns a list of chat messages into a single prompt string.
 */
export enum DeuceChatStrategy {
// Role-prefixed turns (e.g. "User: ") separated by blank lines, ending in "\n\nAssistant:".
A16Z = "a16z",
// "[INST] ... [/INST]" turns with a "<<SYS>>" system block; used when no strategy is given.
META = "meta",
// Same layout as META, with the literal strings "<s>" and "</s>" added to the prompt text.
METAWBOS = "metawbos",
//^ This is not exactly right because SentencePiece puts the BOS and EOS token IDs in after tokenization
// Unfortunately any string only API won't support these properly.
}
/**
* Llama2 LLM implementation
*/
export class LlamaDeuce implements LLM {
model: keyof typeof ALL_AVAILABLE_LLAMADEUCE_MODELS;
chatStrategy: DeuceChatStrategy;
temperature: number;
maxTokens?: number;
replicateSession: ReplicateSession;
/**
 * Creates a LlamaDeuce instance, filling every field that `init` leaves
 * unset with its default value.
 *
 * @param init Optional overrides for any of the instance's fields.
 */
constructor(init?: Partial<LlamaDeuce>) {
  // Normalise once so each field below is a plain lookup with a fallback.
  const overrides: Partial<LlamaDeuce> = init ?? {};

  this.model = overrides.model ?? "Llama-2-70b-chat";
  this.chatStrategy = overrides.chatStrategy ?? DeuceChatStrategy.META;
  this.temperature = overrides.temperature ?? 0;
  this.maxTokens = overrides.maxTokens ?? undefined;
  // A new session is only constructed when the caller did not supply one.
  this.replicateSession = overrides.replicateSession ?? new ReplicateSession();
}
mapMessageType(messageType: MessageType): string {
/**
 * Converts chat messages into a prompt string using the formatter that
 * matches this instance's `chatStrategy`.
 *
 * @param messages The conversation to render.
 * @returns The prompt text produced by the selected formatter.
 */
mapMessagesToPrompt(messages: ChatMessage[]): string {
  switch (this.chatStrategy) {
    case DeuceChatStrategy.A16Z:
      return this.mapMessagesToPromptA16Z(messages);
    case DeuceChatStrategy.METAWBOS:
      return this.mapMessagesToPromptMeta(messages, true);
    case DeuceChatStrategy.META:
    default:
      // Any unrecognised strategy falls back to the plain Meta format.
      return this.mapMessagesToPromptMeta(messages);
  }
}
/**
 * Renders the conversation as role-prefixed turns separated by blank lines,
 * followed by a trailing "Assistant:" cue.
 *
 * @param messages The conversation to render.
 * @returns The prompt text.
 */
mapMessagesToPromptA16Z(messages: ChatMessage[]): string {
  let transcript = "";
  for (const message of messages) {
    const turn = `${this.mapMessageTypeA16Z(message.role)}${message.content}`;
    // The separator is only inserted once the transcript already has text.
    transcript = transcript ? `${transcript}\n\n${turn}` : turn;
  }
  // Here we're differing from A16Z by omitting the space. Generally spaces at
  // the end of prompts decrease performance due to tokenization.
  return transcript + "\n\nAssistant:";
}
mapMessageTypeA16Z(messageType: MessageType): string {
switch (messageType) {
case "user":
return "User: ";
......@@ -206,26 +239,70 @@ export class LlamaDeuce implements LLM {
}
}
/**
 * Renders the conversation in the "[INST] ... [/INST]" chat layout with a
 * "<<SYS>>" system block folded into the first user turn.
 *
 * Messages at even positions (after an optional leading system message is
 * removed) are wrapped in [INST] tags; messages at odd positions are appended
 * as replies. The input array and its message objects are never modified.
 *
 * @param messages The conversation. May start with a system message, which
 *   must then be followed by a user message.
 * @param withBos When true, the literal "<s>" is written before every
 *   [INST] turn and "</s>" after every reply.
 * @returns The prompt text, or "" when `messages` is empty.
 * @throws Error if a leading system message is not followed by a user message.
 */
mapMessagesToPromptMeta(messages: ChatMessage[], withBos = false): string {
  const DEFAULT_SYSTEM_PROMPT = `You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.`;
  const B_SYS = "<<SYS>>\n";
  const E_SYS = "\n<</SYS>>\n\n";
  const B_INST = "[INST]";
  const E_INST = "[/INST]";
  const BOS = "<s>";
  const EOS = "</s>";

  if (messages.length === 0) {
    return "";
  }

  const hasSystemMessage = messages[0].role === "system";
  const systemPrompt = hasSystemMessage
    ? messages[0].content
    : DEFAULT_SYSTEM_PROMPT;

  // Work on shallow copies so the caller's array and messages stay untouched;
  // mutating them made repeated calls with the same input produce different prompts.
  const turns = (hasSystemMessage ? messages.slice(1) : messages).map(
    (message) => ({ ...message })
  );

  // The message right after the system message is now at index 0. It must
  // exist and be a user message, since the system block is merged into it.
  if (hasSystemMessage && turns[0]?.role !== "user") {
    throw new Error(
      "LlamaDeuce: if there is a system message, the second message must be a user message."
    );
  }

  turns[0].content = `${B_SYS}${systemPrompt}${E_SYS}${turns[0].content}`;

  return turns.reduce((acc, message, index) => {
    const content = message.content.trim();
    if (index % 2 === 0) {
      // BOS belongs directly in front of this turn's [INST], not in front of
      // everything accumulated so far.
      return `${acc}${withBos ? BOS : ""}${B_INST} ${content} ${E_INST}`;
    }
    // Yes, the EOS comes after the space. This is not a mistake.
    return `${acc} ${content} ` + (withBos ? EOS : "");
  }, "");
}
async chat(
messages: ChatMessage[],
_parentEvent?: Event
): Promise<ChatResponse> {
const api = ALL_AVAILABLE_LLAMADEUCE_MODELS[this.model]
.replicateApi as `${string}/${string}:${string}`;
const prompt = this.mapMessagesToPrompt(messages);
const response = await this.replicateSession.replicate.run(api, {
input: {
prompt:
messages.reduce((acc, message) => {
return (
(acc && `${acc}\n\n`) +
`${this.mapMessageType(message.role)}${message.content}`
);
}, "") + "\n\nAssistant:", // Here we're differing from A16Z by omitting the space. Generally spaces at the end of prompts decrease performance due to tokenization
prompt,
},
});
return {
message: {
content: (response as Array<string>).join(""), // We need to do this because replicate returns a list of strings (for streaming functionality)
content: (response as Array<string>).join(""),
// We need to do this because Replicate returns a list of strings (for streaming functionality which is not exposed by the run function)
role: "assistant",
},
};
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment