From 2b53e4e5b39f6c3195e4104ff2174aef5e4c3346 Mon Sep 17 00:00:00 2001
From: Mateusz Charytoniuk <mateusz.charytoniuk@protonmail.com>
Date: Fri, 15 Mar 2024 22:42:51 +0100
Subject: [PATCH] chore: mark last llama token as such

---
 src/LlamaCppCompletionIterator.php            |  1 +
 src/LlamaCppCompletionRequest.php             |  4 ++-
 src/LlamaCppCompletionToken.php               |  4 +++
 src/LlmPrompt/Plain.php                       | 22 ++++++++++++++++
 src/LlmPromptTemplate.php                     |  5 ++++
 src/LlmPromptTemplate/ChainPrompt.php         | 21 ++++++++++++++++
 src/LlmPromptTemplate/GemmaInstructChat.php   |  5 ++++
 src/LlmPromptTemplate/HermesChat.php          | 25 +++++++++++++++++++
 src/LlmPromptTemplate/MistralInstructChat.php |  5 ++++
 src/LlmPromptTemplate/Phi2Question.php        |  5 ++++
 src/LlmPromptTemplate/Plain.php               |  5 ++++
 ... LlamaCppSubjectActionPromptResponder.php} |  2 +-
 12 files changed, 102 insertions(+), 2 deletions(-)
 create mode 100644 src/LlmPrompt/Plain.php
 create mode 100644 src/LlmPromptTemplate/HermesChat.php
 rename src/WebSocketRPCResponder/{LlamaCppPromptResponder.php => LlamaCppSubjectActionPromptResponder.php} (97%)

diff --git a/src/LlamaCppCompletionIterator.php b/src/LlamaCppCompletionIterator.php
index 5b756c66..a17e7a4b 100644
--- a/src/LlamaCppCompletionIterator.php
+++ b/src/LlamaCppCompletionIterator.php
@@ -56,6 +56,7 @@ readonly class LlamaCppCompletionIterator implements IteratorAggregate
             if ($unserializedToken) {
                 yield new LlamaCppCompletionToken(
                     content: $unserializedToken->content,
+                    isLastToken: $unserializedToken->stop,
                 );
             }
         }
diff --git a/src/LlamaCppCompletionRequest.php b/src/LlamaCppCompletionRequest.php
index e8cb7513..62fd4886 100644
--- a/src/LlamaCppCompletionRequest.php
+++ b/src/LlamaCppCompletionRequest.php
@@ -17,8 +17,10 @@ readonly class LlamaCppCompletionRequest implements JsonSerializable
     public function jsonSerialize(): array
     {
         $parameters = [
-            'n_predict' => 400,
+            'cache_prompt' => true,
+            // 'n_predict' => 200,
             'prompt' => $this->promptTemplate->getPromptTemplateContent(),
+            'stop' => $this->promptTemplate->getStopWords(),
             'stream' => true,
         ];
 
diff --git a/src/LlamaCppCompletionToken.php b/src/LlamaCppCompletionToken.php
index 22e71249..2eef07a9 100644
--- a/src/LlamaCppCompletionToken.php
+++ b/src/LlamaCppCompletionToken.php
@@ -6,10 +6,14 @@ namespace Distantmagic\Resonance;
 
 use Stringable;
 
+/**
+ * @psalm-suppress PossiblyUnusedProperty used in apps
+ */
 readonly class LlamaCppCompletionToken implements Stringable
 {
     public function __construct(
         public string $content,
+        public bool $isLastToken,
     ) {}
 
     public function __toString(): string
diff --git a/src/LlmPrompt/Plain.php b/src/LlmPrompt/Plain.php
new file mode 100644
index 00000000..5ad71080
--- /dev/null
+++ b/src/LlmPrompt/Plain.php
@@ -0,0 +1,22 @@
+<?php
+
+declare(strict_types=1);
+
+namespace Distantmagic\Resonance\LlmPrompt;
+
+use Distantmagic\Resonance\LlmPrompt;
+
+/** Prompt that exposes the text it was constructed with, unchanged. */
+readonly class Plain extends LlmPrompt
+{
+    /**
+     * @param non-empty-string $prompt text handed back by getPromptContent()
+     */
+    public function __construct(private string $prompt) {}
+
+    /** Returns the stored prompt exactly as it was supplied. */
+    public function getPromptContent(): string
+    {
+        return $this->prompt;
+    }
+}
diff --git a/src/LlmPromptTemplate.php b/src/LlmPromptTemplate.php
index 29a4467a..9ff1758a 100644
--- a/src/LlmPromptTemplate.php
+++ b/src/LlmPromptTemplate.php
@@ -7,4 +7,9 @@ namespace Distantmagic\Resonance;
 abstract readonly class LlmPromptTemplate
 {
     abstract public function getPromptTemplateContent(): string;
+
+    /**
+     * @return list<non-empty-string> strings sent as the completion request's `stop` parameter
+     */
+    abstract public function getStopWords(): array;
 }
diff --git a/src/LlmPromptTemplate/ChainPrompt.php b/src/LlmPromptTemplate/ChainPrompt.php
index 4ae2e9e1..806ae39b 100644
--- a/src/LlmPromptTemplate/ChainPrompt.php
+++ b/src/LlmPromptTemplate/ChainPrompt.php
@@ -6,12 +6,18 @@ namespace Distantmagic\Resonance\LlmPromptTemplate;
 
 use Distantmagic\Resonance\Attribute\Singleton;
 use Distantmagic\Resonance\LlmPromptTemplate;
+use Ds\Set;
 
 #[Singleton]
 readonly class ChainPrompt extends LlmPromptTemplate
 {
     private string $prompt;
 
+    /**
+     * @var list<non-empty-string>
+     */
+    private array $stopWords;
+
     /**
      * @param array<LlmPromptTemplate> $prompts
      */
@@ -19,15 +25,30 @@ readonly class ChainPrompt extends LlmPromptTemplate
     {
         $gluedPrompt = '';
 
+        /**
+         * @var Set<non-empty-string>
+         */
+        $gluedStopWords = new Set();
+
         foreach ($prompts as $prompt) {
             $gluedPrompt .= $prompt->getPromptTemplateContent();
+
+            foreach ($prompt->getStopWords() as $stopWord) {
+                $gluedStopWords->add($stopWord);
+            }
         }
 
         $this->prompt = $gluedPrompt;
+        $this->stopWords = $gluedStopWords->toArray();
     }
 
     public function getPromptTemplateContent(): string
     {
         return $this->prompt;
     }
+
+    public function getStopWords(): array
+    {
+        return $this->stopWords;
+    }
 }
diff --git a/src/LlmPromptTemplate/GemmaInstructChat.php b/src/LlmPromptTemplate/GemmaInstructChat.php
index 9990925d..8b56ae27 100644
--- a/src/LlmPromptTemplate/GemmaInstructChat.php
+++ b/src/LlmPromptTemplate/GemmaInstructChat.php
@@ -21,4 +21,9 @@ readonly class GemmaInstructChat extends LlmPromptTemplate
             $this->prompt,
         );
     }
+
+    public function getStopWords(): array
+    {
+        return ['<start_of_turn>', '<end_of_turn>'];
+    }
 }
diff --git a/src/LlmPromptTemplate/HermesChat.php b/src/LlmPromptTemplate/HermesChat.php
new file mode 100644
index 00000000..facaaba7
--- /dev/null
+++ b/src/LlmPromptTemplate/HermesChat.php
@@ -0,0 +1,25 @@
+<?php
+
+declare(strict_types=1);
+
+namespace Distantmagic\Resonance\LlmPromptTemplate;
+
+use Distantmagic\Resonance\LlmPromptTemplate;
+
+/** Wraps the prompt between the `<|im_start|>` and `<|im_end|>` markers. */
+readonly class HermesChat extends LlmPromptTemplate
+{
+    public function __construct(private string $prompt) {}
+
+    /** Opening marker is written in full (`<|im_start|>`), matching getStopWords(). */
+    public function getPromptTemplateContent(): string
+    {
+        return sprintf('<|im_start|>%s<|im_end|>', $this->prompt);
+    }
+
+    /** @return list<non-empty-string> the two markers used by the template */
+    public function getStopWords(): array
+    {
+        return ['<|im_start|>', '<|im_end|>'];
+    }
+}
diff --git a/src/LlmPromptTemplate/MistralInstructChat.php b/src/LlmPromptTemplate/MistralInstructChat.php
index e6799750..09cee6ba 100644
--- a/src/LlmPromptTemplate/MistralInstructChat.php
+++ b/src/LlmPromptTemplate/MistralInstructChat.php
@@ -17,4 +17,9 @@ readonly class MistralInstructChat extends LlmPromptTemplate
             $this->prompt,
         );
     }
+
+    public function getStopWords(): array
+    {
+        return ['[INST]', '[/INST]'];
+    }
 }
diff --git a/src/LlmPromptTemplate/Phi2Question.php b/src/LlmPromptTemplate/Phi2Question.php
index 5808fbd1..6a1079c6 100644
--- a/src/LlmPromptTemplate/Phi2Question.php
+++ b/src/LlmPromptTemplate/Phi2Question.php
@@ -17,4 +17,9 @@ readonly class Phi2Question extends LlmPromptTemplate
             $this->prompt,
         );
     }
+
+    public function getStopWords(): array
+    {
+        return ['Question:', 'Answer:'];
+    }
 }
diff --git a/src/LlmPromptTemplate/Plain.php b/src/LlmPromptTemplate/Plain.php
index e33fac9c..93db78d7 100644
--- a/src/LlmPromptTemplate/Plain.php
+++ b/src/LlmPromptTemplate/Plain.php
@@ -14,4 +14,9 @@ readonly class Plain extends LlmPromptTemplate
     {
         return $this->prompt;
     }
+
+    public function getStopWords(): array
+    {
+        return [];
+    }
 }
diff --git a/src/WebSocketRPCResponder/LlamaCppPromptResponder.php b/src/WebSocketRPCResponder/LlamaCppSubjectActionPromptResponder.php
similarity index 97%
rename from src/WebSocketRPCResponder/LlamaCppPromptResponder.php
rename to src/WebSocketRPCResponder/LlamaCppSubjectActionPromptResponder.php
index 94b0ad93..1387ad96 100644
--- a/src/WebSocketRPCResponder/LlamaCppPromptResponder.php
+++ b/src/WebSocketRPCResponder/LlamaCppSubjectActionPromptResponder.php
@@ -26,7 +26,7 @@ use WeakMap;
  *
  * @template-extends WebSocketRPCResponder<TPayload>
  */
-abstract readonly class LlamaCppPromptResponder extends WebSocketRPCResponder
+abstract readonly class LlamaCppSubjectActionPromptResponder extends WebSocketRPCResponder
 {
     /**
      * @var WeakMap<WebSocketConnection,LlamaCppCompletionIterator>
-- 
GitLab