From 2b53e4e5b39f6c3195e4104ff2174aef5e4c3346 Mon Sep 17 00:00:00 2001 From: Mateusz Charytoniuk <mateusz.charytoniuk@protonmail.com> Date: Fri, 15 Mar 2024 22:42:51 +0100 Subject: [PATCH] chore: mark as llama token as such --- src/LlamaCppCompletionIterator.php | 1 + src/LlamaCppCompletionRequest.php | 4 ++- src/LlamaCppCompletionToken.php | 4 +++ src/LlmPrompt/Plain.php | 22 ++++++++++++++++ src/LlmPromptTemplate.php | 5 ++++ src/LlmPromptTemplate/ChainPrompt.php | 21 ++++++++++++++++ src/LlmPromptTemplate/GemmaInstructChat.php | 5 ++++ src/LlmPromptTemplate/HermesChat.php | 25 +++++++++++++++++++ src/LlmPromptTemplate/MistralInstructChat.php | 5 ++++ src/LlmPromptTemplate/Phi2Question.php | 5 ++++ src/LlmPromptTemplate/Plain.php | 5 ++++ ... LlamaCppSubjectActionPromptResponder.php} | 2 +- 12 files changed, 102 insertions(+), 2 deletions(-) create mode 100644 src/LlmPrompt/Plain.php create mode 100644 src/LlmPromptTemplate/HermesChat.php rename src/WebSocketRPCResponder/{LlamaCppPromptResponder.php => LlamaCppSubjectActionPromptResponder.php} (97%) diff --git a/src/LlamaCppCompletionIterator.php b/src/LlamaCppCompletionIterator.php index 5b756c66..a17e7a4b 100644 --- a/src/LlamaCppCompletionIterator.php +++ b/src/LlamaCppCompletionIterator.php @@ -56,6 +56,7 @@ readonly class LlamaCppCompletionIterator implements IteratorAggregate if ($unserializedToken) { yield new LlamaCppCompletionToken( content: $unserializedToken->content, + isLastToken: $unserializedToken->stop, ); } } diff --git a/src/LlamaCppCompletionRequest.php b/src/LlamaCppCompletionRequest.php index e8cb7513..62fd4886 100644 --- a/src/LlamaCppCompletionRequest.php +++ b/src/LlamaCppCompletionRequest.php @@ -17,8 +17,10 @@ readonly class LlamaCppCompletionRequest implements JsonSerializable public function jsonSerialize(): array { $parameters = [ - 'n_predict' => 400, + 'cache_prompt' => true, + // 'n_predict' => 200, 'prompt' => $this->promptTemplate->getPromptTemplateContent(), + 'stop' => $this->promptTemplate->getStopWords(), 'stream' => true, ]; diff --git a/src/LlamaCppCompletionToken.php b/src/LlamaCppCompletionToken.php index 22e71249..2eef07a9 100644 --- a/src/LlamaCppCompletionToken.php +++ b/src/LlamaCppCompletionToken.php @@ -6,10 +6,14 @@ namespace Distantmagic\Resonance; use Stringable; +/** + * @psalm-suppress PossiblyUnusedProperty used in apps + */ readonly class LlamaCppCompletionToken implements Stringable { public function __construct( public string $content, + public bool $isLastToken, ) {} public function __toString(): string diff --git a/src/LlmPrompt/Plain.php b/src/LlmPrompt/Plain.php new file mode 100644 index 00000000..5ad71080 --- /dev/null +++ b/src/LlmPrompt/Plain.php @@ -0,0 +1,22 @@ +<?php + +declare(strict_types=1); + +namespace Distantmagic\Resonance\LlmPrompt; + +use Distantmagic\Resonance\LlmPrompt; + +readonly class Plain extends LlmPrompt +{ + /** + * @param non-empty-string $prompt + */ + public function __construct( + private string $prompt + ) {} + + public function getPromptContent(): string + { + return $this->prompt; + } +} diff --git a/src/LlmPromptTemplate.php b/src/LlmPromptTemplate.php index 29a4467a..9ff1758a 100644 --- a/src/LlmPromptTemplate.php +++ b/src/LlmPromptTemplate.php @@ -7,4 +7,9 @@ namespace Distantmagic\Resonance; abstract readonly class LlmPromptTemplate { abstract public function getPromptTemplateContent(): string; + + /** + * @return list<non-empty-string> + */ + abstract public function getStopWords(): array; } diff --git a/src/LlmPromptTemplate/ChainPrompt.php b/src/LlmPromptTemplate/ChainPrompt.php index 4ae2e9e1..806ae39b 100644 --- a/src/LlmPromptTemplate/ChainPrompt.php +++ b/src/LlmPromptTemplate/ChainPrompt.php @@ -6,12 +6,18 @@ namespace Distantmagic\Resonance\LlmPromptTemplate; use Distantmagic\Resonance\Attribute\Singleton; use Distantmagic\Resonance\LlmPromptTemplate; +use Ds\Set; #[Singleton] readonly class ChainPrompt extends LlmPromptTemplate { private string $prompt; + /** + * @var list<non-empty-string> + */ + private array $stopWords; + /** * @param array<LlmPromptTemplate> $prompts */ @@ -19,15 +25,30 @@ readonly class ChainPrompt extends LlmPromptTemplate { $gluedPrompt = ''; + /** + * @var Set<non-empty-string> + */ + $gluedStopWords = new Set(); + foreach ($prompts as $prompt) { $gluedPrompt .= $prompt->getPromptTemplateContent(); + + foreach ($prompt->getStopWords() as $stopWord) { + $gluedStopWords->add($stopWord); + } } $this->prompt = $gluedPrompt; + $this->stopWords = $gluedStopWords->toArray(); } public function getPromptTemplateContent(): string { return $this->prompt; } + + public function getStopWords(): array + { + return $this->stopWords; + } } diff --git a/src/LlmPromptTemplate/GemmaInstructChat.php b/src/LlmPromptTemplate/GemmaInstructChat.php index 9990925d..8b56ae27 100644 --- a/src/LlmPromptTemplate/GemmaInstructChat.php +++ b/src/LlmPromptTemplate/GemmaInstructChat.php @@ -21,4 +21,9 @@ readonly class GemmaInstructChat extends LlmPromptTemplate $this->prompt, ); } + + public function getStopWords(): array + { + return ['<start_of_turn>', '<end_of_turn>']; + } } diff --git a/src/LlmPromptTemplate/HermesChat.php b/src/LlmPromptTemplate/HermesChat.php new file mode 100644 index 00000000..facaaba7 --- /dev/null +++ b/src/LlmPromptTemplate/HermesChat.php @@ -0,0 +1,25 @@ +<?php + +declare(strict_types=1); + +namespace Distantmagic\Resonance\LlmPromptTemplate; + +use Distantmagic\Resonance\LlmPromptTemplate; + +readonly class HermesChat extends LlmPromptTemplate +{ + public function __construct(private string $prompt) {} + + public function getPromptTemplateContent(): string + { + return sprintf( + '<|im_start|%s<|im_end|>', + $this->prompt, + ); + } + + public function getStopWords(): array + { + return ['<|im_start|>', '<|im_end|>']; + } +} diff --git a/src/LlmPromptTemplate/MistralInstructChat.php b/src/LlmPromptTemplate/MistralInstructChat.php index e6799750..09cee6ba 100644 --- a/src/LlmPromptTemplate/MistralInstructChat.php +++ b/src/LlmPromptTemplate/MistralInstructChat.php @@ -17,4 +17,9 @@ readonly class MistralInstructChat extends LlmPromptTemplate $this->prompt, ); } + + public function getStopWords(): array + { + return ['[INST]', '[/INST]']; + } } diff --git a/src/LlmPromptTemplate/Phi2Question.php b/src/LlmPromptTemplate/Phi2Question.php index 5808fbd1..6a1079c6 100644 --- a/src/LlmPromptTemplate/Phi2Question.php +++ b/src/LlmPromptTemplate/Phi2Question.php @@ -17,4 +17,9 @@ readonly class Phi2Question extends LlmPromptTemplate $this->prompt, ); } + + public function getStopWords(): array + { + return ['Question:', 'Answer:']; + } } diff --git a/src/LlmPromptTemplate/Plain.php b/src/LlmPromptTemplate/Plain.php index e33fac9c..93db78d7 100644 --- a/src/LlmPromptTemplate/Plain.php +++ b/src/LlmPromptTemplate/Plain.php @@ -14,4 +14,9 @@ readonly class Plain extends LlmPromptTemplate { return $this->prompt; } + + public function getStopWords(): array + { + return []; + } } diff --git a/src/WebSocketRPCResponder/LlamaCppPromptResponder.php b/src/WebSocketRPCResponder/LlamaCppSubjectActionPromptResponder.php similarity index 97% rename from src/WebSocketRPCResponder/LlamaCppPromptResponder.php rename to src/WebSocketRPCResponder/LlamaCppSubjectActionPromptResponder.php index 94b0ad93..1387ad96 100644 --- a/src/WebSocketRPCResponder/LlamaCppPromptResponder.php +++ b/src/WebSocketRPCResponder/LlamaCppSubjectActionPromptResponder.php @@ -26,7 +26,7 @@ use WeakMap; * * @template-extends WebSocketRPCResponder<TPayload> */ -abstract readonly class LlamaCppPromptResponder extends WebSocketRPCResponder +abstract readonly class LlamaCppSubjectActionPromptResponder extends WebSocketRPCResponder { /** * @var WeakMap<WebSocketConnection,LlamaCppCompletionIterator> -- GitLab