Add stop words to LLM prompt templates and flag the last streamed token.

- LlamaCppCompletionToken gains `isLastToken`, filled from the `stop` field
  of each streamed token.
- LlamaCppCompletionRequest now sends `cache_prompt` and the template's
  `stop` words; the fixed `n_predict` of 400 is no longer sent.
- LlmPromptTemplate declares getStopWords(); every template implements it.
  ChainPrompt collects the stop words of all chained prompts in a Ds\Set.
- New classes: LlmPrompt\Plain and LlmPromptTemplate\HermesChat.
- LlamaCppPromptResponder is renamed to LlamaCppSubjectActionPromptResponder.

Review fix: the sprintf() template in HermesChat opened with '<|im_start|'
(missing the closing '>'), which does not match the '<|im_start|>' token
listed in the same class's stop words. It now reads
'<|im_start|>%s<|im_end|>'. The postimage blob id on HermesChat's index
line predates this fix.

diff --git a/src/LlamaCppCompletionIterator.php b/src/LlamaCppCompletionIterator.php
index 5b756c66e479dc49b1c2160a91773c4ae403c0dd..a17e7a4b30f9df11c90ea93f88c3eb12209aaea8 100644
--- a/src/LlamaCppCompletionIterator.php
+++ b/src/LlamaCppCompletionIterator.php
@@ -56,6 +56,7 @@ readonly class LlamaCppCompletionIterator implements IteratorAggregate
                 if ($unserializedToken) {
                     yield new LlamaCppCompletionToken(
                         content: $unserializedToken->content,
+                        isLastToken: $unserializedToken->stop,
                     );
                 }
             }
diff --git a/src/LlamaCppCompletionRequest.php b/src/LlamaCppCompletionRequest.php
index e8cb75135e7e4463e7f5405e06372a6df22a26c6..62fd4886f5e14dadc4a86454d88f721ec3a63ae5 100644
--- a/src/LlamaCppCompletionRequest.php
+++ b/src/LlamaCppCompletionRequest.php
@@ -17,8 +17,10 @@ readonly class LlamaCppCompletionRequest implements JsonSerializable
     public function jsonSerialize(): array
     {
         $parameters = [
-            'n_predict' => 400,
+            'cache_prompt' => true,
+            // 'n_predict' => 200,
             'prompt' => $this->promptTemplate->getPromptTemplateContent(),
+            'stop' => $this->promptTemplate->getStopWords(),
             'stream' => true,
         ];
 
diff --git a/src/LlamaCppCompletionToken.php b/src/LlamaCppCompletionToken.php
index 22e71249237df6011e5c991403f566ffb11d4240..2eef07a91dc19d163ec948b948f575ee13849eb6 100644
--- a/src/LlamaCppCompletionToken.php
+++ b/src/LlamaCppCompletionToken.php
@@ -6,10 +6,14 @@ namespace Distantmagic\Resonance;
 
 use Stringable;
 
+/**
+ * @psalm-suppress PossiblyUnusedProperty used in apps
+ */
 readonly class LlamaCppCompletionToken implements Stringable
 {
     public function __construct(
         public string $content,
+        public bool $isLastToken,
     ) {}
 
     public function __toString(): string
diff --git a/src/LlmPrompt/Plain.php b/src/LlmPrompt/Plain.php
new file mode 100644
index 0000000000000000000000000000000000000000..5ad710800d100d9f8d6e25765b3b87512888fa28
--- /dev/null
+++ b/src/LlmPrompt/Plain.php
@@ -0,0 +1,22 @@
+<?php
+
+declare(strict_types=1);
+
+namespace Distantmagic\Resonance\LlmPrompt;
+
+use Distantmagic\Resonance\LlmPrompt;
+
+readonly class Plain extends LlmPrompt
+{
+    /**
+     * @param non-empty-string $prompt
+     */
+    public function __construct(
+        private string $prompt
+    ) {}
+
+    public function getPromptContent(): string
+    {
+        return $this->prompt;
+    }
+}
diff --git a/src/LlmPromptTemplate.php b/src/LlmPromptTemplate.php
index 29a4467a3fdb1dfce4f6e21b4a08e932f922a5be..9ff1758a8195b5d76f29b5e38fd8e38c06519866 100644
--- a/src/LlmPromptTemplate.php
+++ b/src/LlmPromptTemplate.php
@@ -7,4 +7,9 @@ namespace Distantmagic\Resonance;
 abstract readonly class LlmPromptTemplate
 {
     abstract public function getPromptTemplateContent(): string;
+
+    /**
+     * @return list<non-empty-string>
+     */
+    abstract public function getStopWords(): array;
 }
diff --git a/src/LlmPromptTemplate/ChainPrompt.php b/src/LlmPromptTemplate/ChainPrompt.php
index 4ae2e9e170420347ee0d80d9bd445b3ecc3c2d09..806ae39b8e55a7a8b58d515a7296f3e2de1a3ce4 100644
--- a/src/LlmPromptTemplate/ChainPrompt.php
+++ b/src/LlmPromptTemplate/ChainPrompt.php
@@ -6,12 +6,18 @@ namespace Distantmagic\Resonance\LlmPromptTemplate;
 
 use Distantmagic\Resonance\Attribute\Singleton;
 use Distantmagic\Resonance\LlmPromptTemplate;
+use Ds\Set;
 
 #[Singleton]
 readonly class ChainPrompt extends LlmPromptTemplate
 {
     private string $prompt;
 
+    /**
+     * @var list<non-empty-string>
+     */
+    private array $stopWords;
+
     /**
      * @param array<LlmPromptTemplate> $prompts
      */
@@ -19,15 +25,30 @@ readonly class ChainPrompt extends LlmPromptTemplate
     {
         $gluedPrompt = '';
 
+        /**
+         * @var Set<non-empty-string>
+         */
+        $gluedStopWords = new Set();
+
         foreach ($prompts as $prompt) {
             $gluedPrompt .= $prompt->getPromptTemplateContent();
+
+            foreach ($prompt->getStopWords() as $stopWord) {
+                $gluedStopWords->add($stopWord);
+            }
         }
 
         $this->prompt = $gluedPrompt;
+        $this->stopWords = $gluedStopWords->toArray();
     }
 
     public function getPromptTemplateContent(): string
     {
         return $this->prompt;
     }
+
+    public function getStopWords(): array
+    {
+        return $this->stopWords;
+    }
 }
diff --git a/src/LlmPromptTemplate/GemmaInstructChat.php b/src/LlmPromptTemplate/GemmaInstructChat.php
index 9990925d23958d3603b6ac4f0c7bfa2964f26dfc..8b56ae2784cb96a072d88b6522b44223e4fab5f5 100644
--- a/src/LlmPromptTemplate/GemmaInstructChat.php
+++ b/src/LlmPromptTemplate/GemmaInstructChat.php
@@ -21,4 +21,9 @@ readonly class GemmaInstructChat extends LlmPromptTemplate
             $this->prompt,
         );
     }
+
+    public function getStopWords(): array
+    {
+        return ['<start_of_turn>', '<end_of_turn>'];
+    }
 }
diff --git a/src/LlmPromptTemplate/HermesChat.php b/src/LlmPromptTemplate/HermesChat.php
new file mode 100644
index 0000000000000000000000000000000000000000..facaaba713302519bdd764bf9cd60a24eeac1203
--- /dev/null
+++ b/src/LlmPromptTemplate/HermesChat.php
@@ -0,0 +1,25 @@
+<?php
+
+declare(strict_types=1);
+
+namespace Distantmagic\Resonance\LlmPromptTemplate;
+
+use Distantmagic\Resonance\LlmPromptTemplate;
+
+readonly class HermesChat extends LlmPromptTemplate
+{
+    public function __construct(private string $prompt) {}
+
+    public function getPromptTemplateContent(): string
+    {
+        return sprintf(
+            '<|im_start|>%s<|im_end|>',
+            $this->prompt,
+        );
+    }
+
+    public function getStopWords(): array
+    {
+        return ['<|im_start|>', '<|im_end|>'];
+    }
+}
diff --git a/src/LlmPromptTemplate/MistralInstructChat.php b/src/LlmPromptTemplate/MistralInstructChat.php
index e6799750709b79ec7447138798987129c1dcf5e9..09cee6ba8e2838ce1fb8520c309038dae2fdc481 100644
--- a/src/LlmPromptTemplate/MistralInstructChat.php
+++ b/src/LlmPromptTemplate/MistralInstructChat.php
@@ -17,4 +17,9 @@ readonly class MistralInstructChat extends LlmPromptTemplate
             $this->prompt,
         );
     }
+
+    public function getStopWords(): array
+    {
+        return ['[INST]', '[/INST]'];
+    }
 }
diff --git a/src/LlmPromptTemplate/Phi2Question.php b/src/LlmPromptTemplate/Phi2Question.php
index 5808fbd12f6ca933a6669913cfdc67f5f3196cab..6a1079c6b1fa94511f98119ca512cb98de65cdac 100644
--- a/src/LlmPromptTemplate/Phi2Question.php
+++ b/src/LlmPromptTemplate/Phi2Question.php
@@ -17,4 +17,9 @@ readonly class Phi2Question extends LlmPromptTemplate
             $this->prompt,
         );
     }
+
+    public function getStopWords(): array
+    {
+        return ['Question:', 'Answer:'];
+    }
 }
diff --git a/src/LlmPromptTemplate/Plain.php b/src/LlmPromptTemplate/Plain.php
index e33fac9cc2196aa5f857778900b40dd58c2963a1..93db78d74572e252b95cc7468f6dd371e47c5a13 100644
--- a/src/LlmPromptTemplate/Plain.php
+++ b/src/LlmPromptTemplate/Plain.php
@@ -14,4 +14,9 @@ readonly class Plain extends LlmPromptTemplate
     {
         return $this->prompt;
     }
+
+    public function getStopWords(): array
+    {
+        return [];
+    }
 }
diff --git a/src/WebSocketRPCResponder/LlamaCppPromptResponder.php b/src/WebSocketRPCResponder/LlamaCppSubjectActionPromptResponder.php
similarity index 97%
rename from src/WebSocketRPCResponder/LlamaCppPromptResponder.php
rename to src/WebSocketRPCResponder/LlamaCppSubjectActionPromptResponder.php
index 94b0ad938d2cc8cd556dcac42954c4ed56520f2b..1387ad96e6eca5bc92123cabfd316baf1c9cbf9b 100644
--- a/src/WebSocketRPCResponder/LlamaCppPromptResponder.php
+++ b/src/WebSocketRPCResponder/LlamaCppSubjectActionPromptResponder.php
@@ -26,7 +26,7 @@ use WeakMap;
  *
  * @template-extends WebSocketRPCResponder<TPayload>
  */
-abstract readonly class LlamaCppPromptResponder extends WebSocketRPCResponder
+abstract readonly class LlamaCppSubjectActionPromptResponder extends WebSocketRPCResponder
 {
     /**
      * @var WeakMap<WebSocketConnection,LlamaCppCompletionIterator>