From 68bbd073a3aec1eca8a5d6e006d839ecf36f94fc Mon Sep 17 00:00:00 2001 From: Mateusz Charytoniuk <mateusz.charytoniuk@protonmail.com> Date: Sat, 17 Feb 2024 10:22:25 +0100 Subject: [PATCH] fix: do not swap the system prompt all the time --- src/BackusNaurFormGrammar.php | 9 +---- src/LlamaCppCompletionRequest.php | 10 +++--- src/LlmPrompt.php | 10 ++++++ .../SubjectActionPrompt.php} | 6 ++-- src/LlmPromptTemplate.php | 10 ++---- src/LlmPromptTemplate/ChainPrompt.php | 33 +++++++++++++++++++ src/LlmPromptTemplate/MistralInstructChat.php | 5 +-- src/LlmPromptTemplate/Phi2Question.php | 5 +-- src/LlmPromptTemplate/Plain.php | 5 +-- src/LlmSystemPrompt.php | 19 ----------- .../LlamaCppPromptResponder.php | 21 +++++++----- 11 files changed, 71 insertions(+), 62 deletions(-) create mode 100644 src/LlmPrompt.php rename src/{LlmSystemPrompt/SubjectActionSystemPrompt.php => LlmPrompt/SubjectActionPrompt.php} (95%) create mode 100644 src/LlmPromptTemplate/ChainPrompt.php delete mode 100644 src/LlmSystemPrompt.php diff --git a/src/BackusNaurFormGrammar.php b/src/BackusNaurFormGrammar.php index 554403a3..c0abb4d1 100644 --- a/src/BackusNaurFormGrammar.php +++ b/src/BackusNaurFormGrammar.php @@ -4,14 +4,7 @@ declare(strict_types=1); namespace Distantmagic\Resonance; -use JsonSerializable; - -abstract readonly class BackusNaurFormGrammar implements JsonSerializable +abstract readonly class BackusNaurFormGrammar { abstract public function getGrammarContent(): string; - - public function jsonSerialize(): mixed - { - return $this->getGrammarContent(); - } } diff --git a/src/LlamaCppCompletionRequest.php b/src/LlamaCppCompletionRequest.php index 7422e3e4..e8cb7513 100644 --- a/src/LlamaCppCompletionRequest.php +++ b/src/LlamaCppCompletionRequest.php @@ -11,23 +11,25 @@ readonly class LlamaCppCompletionRequest implements JsonSerializable public function __construct( public LlmPromptTemplate $promptTemplate, public ?BackusNaurFormGrammar $backusNaurFormGrammar = null, - public ?LlmSystemPrompt $llmSystemPrompt = null, + public ?LlmPrompt $llmSystemPrompt = null, ) {} public function jsonSerialize(): array { $parameters = [ 'n_predict' => 400, - 'prompt' => $this->promptTemplate, + 'prompt' => $this->promptTemplate->getPromptTemplateContent(), 'stream' => true, ]; if ($this->backusNaurFormGrammar) { - $parameters['grammar'] = $this->backusNaurFormGrammar; + $parameters['grammar'] = $this->backusNaurFormGrammar->getGrammarContent(); } if ($this->llmSystemPrompt) { - $parameters['system_prompt'] = $this->llmSystemPrompt; + $parameters['system_prompt'] = [ + 'prompt' => $this->llmSystemPrompt->getPromptContent(), + ]; } return $parameters; diff --git a/src/LlmPrompt.php b/src/LlmPrompt.php new file mode 100644 index 00000000..53c04caf --- /dev/null +++ b/src/LlmPrompt.php @@ -0,0 +1,10 @@ +<?php + +declare(strict_types=1); + +namespace Distantmagic\Resonance; + +abstract readonly class LlmPrompt +{ + abstract public function getPromptContent(): string; +} diff --git a/src/LlmSystemPrompt/SubjectActionSystemPrompt.php b/src/LlmPrompt/SubjectActionPrompt.php similarity index 95% rename from src/LlmSystemPrompt/SubjectActionSystemPrompt.php rename to src/LlmPrompt/SubjectActionPrompt.php index a68e8f69..21755528 100644 --- a/src/LlmSystemPrompt/SubjectActionSystemPrompt.php +++ b/src/LlmPrompt/SubjectActionPrompt.php @@ -2,16 +2,16 @@ declare(strict_types=1); -namespace Distantmagic\Resonance\LlmSystemPrompt; +namespace Distantmagic\Resonance\LlmPrompt; use Distantmagic\Resonance\Attribute\Singleton; -use Distantmagic\Resonance\LlmSystemPrompt; +use Distantmagic\Resonance\LlmPrompt; use Distantmagic\Resonance\PromptSubjectResponderCollection; use Distantmagic\Resonance\RespondsToPromptSubjectAttributeCollection; use Ds\Set; #[Singleton] -readonly class SubjectActionSystemPrompt extends LlmSystemPrompt +readonly class SubjectActionPrompt extends LlmPrompt { /** * @var non-empty-string $prompt diff --git a/src/LlmPromptTemplate.php b/src/LlmPromptTemplate.php index 4a24bb7d..29a4467a 100644 --- a/src/LlmPromptTemplate.php +++ b/src/LlmPromptTemplate.php @@ -4,13 +4,7 @@ declare(strict_types=1); namespace Distantmagic\Resonance; -use JsonSerializable; -use Stringable; - -abstract readonly class LlmPromptTemplate implements JsonSerializable, Stringable +abstract readonly class LlmPromptTemplate { - public function jsonSerialize(): string - { - return (string) $this; - } + abstract public function getPromptTemplateContent(): string; } diff --git a/src/LlmPromptTemplate/ChainPrompt.php b/src/LlmPromptTemplate/ChainPrompt.php new file mode 100644 index 00000000..4ae2e9e1 --- /dev/null +++ b/src/LlmPromptTemplate/ChainPrompt.php @@ -0,0 +1,33 @@ +<?php + +declare(strict_types=1); + +namespace Distantmagic\Resonance\LlmPromptTemplate; + +use Distantmagic\Resonance\Attribute\Singleton; +use Distantmagic\Resonance\LlmPromptTemplate; + +#[Singleton] +readonly class ChainPrompt extends LlmPromptTemplate +{ + private string $prompt; + + /** + * @param array<LlmPromptTemplate> $prompts + */ + public function __construct(array $prompts) + { + $gluedPrompt = ''; + + foreach ($prompts as $prompt) { + $gluedPrompt .= $prompt->getPromptTemplateContent(); + } + + $this->prompt = $gluedPrompt; + } + + public function getPromptTemplateContent(): string + { + return $this->prompt; + } +} diff --git a/src/LlmPromptTemplate/MistralInstructChat.php b/src/LlmPromptTemplate/MistralInstructChat.php index 708cd97e..e6799750 100644 --- a/src/LlmPromptTemplate/MistralInstructChat.php +++ b/src/LlmPromptTemplate/MistralInstructChat.php @@ -8,12 +8,9 @@ use Distantmagic\Resonance\LlmPromptTemplate; readonly class MistralInstructChat extends LlmPromptTemplate { - /** - * @param non-empty-string $prompt - */ public function __construct(private string $prompt) {} - public function __toString(): string + public function getPromptTemplateContent(): string { return sprintf( '[INST]%s[/INST]', diff --git a/src/LlmPromptTemplate/Phi2Question.php b/src/LlmPromptTemplate/Phi2Question.php index b4c75401..5808fbd1 100644 --- a/src/LlmPromptTemplate/Phi2Question.php +++ b/src/LlmPromptTemplate/Phi2Question.php @@ -8,12 +8,9 @@ use Distantmagic\Resonance\LlmPromptTemplate; readonly class Phi2Question extends LlmPromptTemplate { - /** - * @param non-empty-string $prompt - */ public function __construct(private string $prompt) {} - public function __toString(): string + public function getPromptTemplateContent(): string { return sprintf( "Question: %s\nAnswer: ", diff --git a/src/LlmPromptTemplate/Plain.php b/src/LlmPromptTemplate/Plain.php index e2074570..e33fac9c 100644 --- a/src/LlmPromptTemplate/Plain.php +++ b/src/LlmPromptTemplate/Plain.php @@ -8,12 +8,9 @@ use Distantmagic\Resonance\LlmPromptTemplate; readonly class Plain extends LlmPromptTemplate { - /** - * @param non-empty-string $prompt - */ public function __construct(private string $prompt) {} - public function __toString(): string + public function getPromptTemplateContent(): string { return $this->prompt; } diff --git a/src/LlmSystemPrompt.php b/src/LlmSystemPrompt.php deleted file mode 100644 index 70dd9289..00000000 --- a/src/LlmSystemPrompt.php +++ /dev/null @@ -1,19 +0,0 @@ -<?php - -declare(strict_types=1); - -namespace Distantmagic\Resonance; - -use JsonSerializable; - -abstract readonly class LlmSystemPrompt implements JsonSerializable -{ - abstract public function getPromptContent(): string; - - public function jsonSerialize(): array - { - return [ - 'prompt' => $this->getPromptContent(), - ]; - } -} diff --git a/src/WebSocketRPCResponder/LlamaCppPromptResponder.php b/src/WebSocketRPCResponder/LlamaCppPromptResponder.php index c214ff3e..583a8ae6 100644 --- a/src/WebSocketRPCResponder/LlamaCppPromptResponder.php +++ b/src/WebSocketRPCResponder/LlamaCppPromptResponder.php @@ -8,8 +8,9 @@ use Distantmagic\Resonance\BackusNaurFormGrammar\SubjectActionGrammar; use Distantmagic\Resonance\LlamaCppClient; use Distantmagic\Resonance\LlamaCppCompletionIterator; use Distantmagic\Resonance\LlamaCppCompletionRequest; +use Distantmagic\Resonance\LlmPrompt\SubjectActionPrompt; use Distantmagic\Resonance\LlmPromptTemplate; -use Distantmagic\Resonance\LlmSystemPrompt\SubjectActionSystemPrompt; +use Distantmagic\Resonance\LlmPromptTemplate\ChainPrompt; use Distantmagic\Resonance\PromptSubjectResponderAggregate; use Distantmagic\Resonance\RPCNotification; use Distantmagic\Resonance\WebSocketAuthResolution; @@ -30,23 +31,25 @@ abstract readonly class LlamaCppPromptResponder extends WebSocketRPCResponder */ private WeakMap $runningCompletions; + /** + * @param TPayload $payload + */ + abstract protected function getPromptFromPayload(mixed $payload): string; + abstract protected function onResponseChunk( WebSocketAuthResolution $webSocketAuthResolution, WebSocketConnection $webSocketConnection, mixed $responseChunk, ): void; - /** - * @param TPayload $payload - */ - abstract protected function toPromptTemplate(mixed $payload): LlmPromptTemplate; + abstract protected function toPromptTemplate(string $prompt): LlmPromptTemplate; public function __construct( private LlamaCppClient $llamaCppClient, private LoggerInterface $logger, private PromptSubjectResponderAggregate $promptSubjectResponderAggregate, private SubjectActionGrammar $subjectActionGrammar, - private SubjectActionSystemPrompt $subjectActionSystemPrompt, + private SubjectActionPrompt $subjectActionPrompt, ) { /** * @var WeakMap<WebSocketConnection,LlamaCppCompletionIterator> @@ -70,8 +73,10 @@ abstract readonly class LlamaCppPromptResponder extends WebSocketRPCResponder ): void { $request = new LlamaCppCompletionRequest( backusNaurFormGrammar: $this->subjectActionGrammar, - llmSystemPrompt: $this->subjectActionSystemPrompt, - promptTemplate: $this->toPromptTemplate($rpcNotification->payload), + promptTemplate: new ChainPrompt([ + $this->toPromptTemplate($this->subjectActionPrompt->getPromptContent()), + $this->toPromptTemplate($this->getPromptFromPayload($rpcNotification->payload)), + ]), ); $completion = $this->llamaCppClient->generateCompletion($request); -- GitLab