diff --git a/src/BackusNaurFormGrammar.php b/src/BackusNaurFormGrammar.php index 554403a3a0b19540901a6467f41af4b041109b42..c0abb4d1280a42d046b4741d2ce10c8058f638ed 100644 --- a/src/BackusNaurFormGrammar.php +++ b/src/BackusNaurFormGrammar.php @@ -4,14 +4,7 @@ declare(strict_types=1); namespace Distantmagic\Resonance; -use JsonSerializable; - -abstract readonly class BackusNaurFormGrammar implements JsonSerializable +abstract readonly class BackusNaurFormGrammar { abstract public function getGrammarContent(): string; - - public function jsonSerialize(): mixed - { - return $this->getGrammarContent(); - } } diff --git a/src/LlamaCppCompletionRequest.php b/src/LlamaCppCompletionRequest.php index 7422e3e4c1c75bd7d9309c85513cd69b8cb6b320..e8cb75135e7e4463e7f5405e06372a6df22a26c6 100644 --- a/src/LlamaCppCompletionRequest.php +++ b/src/LlamaCppCompletionRequest.php @@ -11,23 +11,25 @@ readonly class LlamaCppCompletionRequest implements JsonSerializable public function __construct( public LlmPromptTemplate $promptTemplate, public ?BackusNaurFormGrammar $backusNaurFormGrammar = null, - public ?LlmSystemPrompt $llmSystemPrompt = null, + public ?LlmPrompt $llmSystemPrompt = null, ) {} public function jsonSerialize(): array { $parameters = [ 'n_predict' => 400, - 'prompt' => $this->promptTemplate, + 'prompt' => $this->promptTemplate->getPromptTemplateContent(), 'stream' => true, ]; if ($this->backusNaurFormGrammar) { - $parameters['grammar'] = $this->backusNaurFormGrammar; + $parameters['grammar'] = $this->backusNaurFormGrammar->getGrammarContent(); } if ($this->llmSystemPrompt) { - $parameters['system_prompt'] = $this->llmSystemPrompt; + $parameters['system_prompt'] = [ + 'prompt' => $this->llmSystemPrompt->getPromptContent(), + ]; } return $parameters; diff --git a/src/LlmPrompt.php b/src/LlmPrompt.php new file mode 100644 index 0000000000000000000000000000000000000000..53c04caf0edd8a8c8c308b693f05d4074d167ea7 --- /dev/null +++ b/src/LlmPrompt.php @@ -0,0 +1,10 @@ +<?php + +declare(strict_types=1); + +namespace Distantmagic\Resonance; + +abstract readonly class LlmPrompt +{ + abstract public function getPromptContent(): string; +} diff --git a/src/LlmSystemPrompt/SubjectActionSystemPrompt.php b/src/LlmPrompt/SubjectActionPrompt.php similarity index 95% rename from src/LlmSystemPrompt/SubjectActionSystemPrompt.php rename to src/LlmPrompt/SubjectActionPrompt.php index a68e8f6906e15a4fba841de9822b241d38892932..217555285a557f3ee852b464aac2db9b0fbfff25 100644 --- a/src/LlmSystemPrompt/SubjectActionSystemPrompt.php +++ b/src/LlmPrompt/SubjectActionPrompt.php @@ -2,16 +2,16 @@ declare(strict_types=1); -namespace Distantmagic\Resonance\LlmSystemPrompt; +namespace Distantmagic\Resonance\LlmPrompt; use Distantmagic\Resonance\Attribute\Singleton; -use Distantmagic\Resonance\LlmSystemPrompt; +use Distantmagic\Resonance\LlmPrompt; use Distantmagic\Resonance\PromptSubjectResponderCollection; use Distantmagic\Resonance\RespondsToPromptSubjectAttributeCollection; use Ds\Set; #[Singleton] -readonly class SubjectActionSystemPrompt extends LlmSystemPrompt +readonly class SubjectActionPrompt extends LlmPrompt { /** * @var non-empty-string $prompt diff --git a/src/LlmPromptTemplate.php b/src/LlmPromptTemplate.php index 4a24bb7dd4adb91faa5a4ed7deef1187e518746d..29a4467a3fdb1dfce4f6e21b4a08e932f922a5be 100644 --- a/src/LlmPromptTemplate.php +++ b/src/LlmPromptTemplate.php @@ -4,13 +4,7 @@ declare(strict_types=1); namespace Distantmagic\Resonance; -use JsonSerializable; -use Stringable; - -abstract readonly class LlmPromptTemplate implements JsonSerializable, Stringable +abstract readonly class LlmPromptTemplate { - public function jsonSerialize(): string - { - return (string) $this; - } + abstract public function getPromptTemplateContent(): string; } diff --git a/src/LlmPromptTemplate/ChainPrompt.php b/src/LlmPromptTemplate/ChainPrompt.php new file mode 100644 index 0000000000000000000000000000000000000000..4ae2e9e170420347ee0d80d9bd445b3ecc3c2d09 --- /dev/null +++ b/src/LlmPromptTemplate/ChainPrompt.php @@ -0,0 +1,33 @@ +<?php + +declare(strict_types=1); + +namespace Distantmagic\Resonance\LlmPromptTemplate; + +use Distantmagic\Resonance\Attribute\Singleton; +use Distantmagic\Resonance\LlmPromptTemplate; + +#[Singleton] +readonly class ChainPrompt extends LlmPromptTemplate +{ + private string $prompt; + + /** + * @param array<LlmPromptTemplate> $prompts + */ + public function __construct(array $prompts) + { + $gluedPrompt = ''; + + foreach ($prompts as $prompt) { + $gluedPrompt .= $prompt->getPromptTemplateContent(); + } + + $this->prompt = $gluedPrompt; + } + + public function getPromptTemplateContent(): string + { + return $this->prompt; + } +} diff --git a/src/LlmPromptTemplate/MistralInstructChat.php b/src/LlmPromptTemplate/MistralInstructChat.php index 708cd97e08ca3b332153340eff6727a417dd1e7a..e6799750709b79ec7447138798987129c1dcf5e9 100644 --- a/src/LlmPromptTemplate/MistralInstructChat.php +++ b/src/LlmPromptTemplate/MistralInstructChat.php @@ -8,12 +8,9 @@ use Distantmagic\Resonance\LlmPromptTemplate; readonly class MistralInstructChat extends LlmPromptTemplate { - /** - * @param non-empty-string $prompt - */ public function __construct(private string $prompt) {} - public function __toString(): string + public function getPromptTemplateContent(): string { return sprintf( '[INST]%s[/INST]', diff --git a/src/LlmPromptTemplate/Phi2Question.php b/src/LlmPromptTemplate/Phi2Question.php index b4c75401f370c7bda16a7e7e63fff058c359aaea..5808fbd12f6ca933a6669913cfdc67f5f3196cab 100644 --- a/src/LlmPromptTemplate/Phi2Question.php +++ b/src/LlmPromptTemplate/Phi2Question.php @@ -8,12 +8,9 @@ use Distantmagic\Resonance\LlmPromptTemplate; readonly class Phi2Question extends LlmPromptTemplate { - /** - * @param non-empty-string $prompt - */ public function __construct(private string $prompt) {} - public function __toString(): string + public function getPromptTemplateContent(): string { return sprintf( "Question: %s\nAnswer: ", diff --git a/src/LlmPromptTemplate/Plain.php b/src/LlmPromptTemplate/Plain.php index e2074570c2a4f5e7147eec4e46e6a36346ce638b..e33fac9cc2196aa5f857778900b40dd58c2963a1 100644 --- a/src/LlmPromptTemplate/Plain.php +++ b/src/LlmPromptTemplate/Plain.php @@ -8,12 +8,9 @@ use Distantmagic\Resonance\LlmPromptTemplate; readonly class Plain extends LlmPromptTemplate { - /** - * @param non-empty-string $prompt - */ public function __construct(private string $prompt) {} - public function __toString(): string + public function getPromptTemplateContent(): string { return $this->prompt; } diff --git a/src/LlmSystemPrompt.php b/src/LlmSystemPrompt.php deleted file mode 100644 index 70dd9289501d3d48f7a3a4698542164d0e0b778c..0000000000000000000000000000000000000000 --- a/src/LlmSystemPrompt.php +++ /dev/null @@ -1,19 +0,0 @@ -<?php - -declare(strict_types=1); - -namespace Distantmagic\Resonance; - -use JsonSerializable; - -abstract readonly class LlmSystemPrompt implements JsonSerializable -{ - abstract public function getPromptContent(): string; - - public function jsonSerialize(): array - { - return [ - 'prompt' => $this->getPromptContent(), - ]; - } -} diff --git a/src/WebSocketRPCResponder/LlamaCppPromptResponder.php b/src/WebSocketRPCResponder/LlamaCppPromptResponder.php index c214ff3eaa10dfa1aca4fdd9f7c00b97fcf67a0a..583a8ae67f0c0eadee2da3c3e4e03c2c792b5772 100644 --- a/src/WebSocketRPCResponder/LlamaCppPromptResponder.php +++ b/src/WebSocketRPCResponder/LlamaCppPromptResponder.php @@ -8,8 +8,9 @@ use Distantmagic\Resonance\BackusNaurFormGrammar\SubjectActionGrammar; use Distantmagic\Resonance\LlamaCppClient; use Distantmagic\Resonance\LlamaCppCompletionIterator; use Distantmagic\Resonance\LlamaCppCompletionRequest; +use Distantmagic\Resonance\LlmPrompt\SubjectActionPrompt; use Distantmagic\Resonance\LlmPromptTemplate; -use Distantmagic\Resonance\LlmSystemPrompt\SubjectActionSystemPrompt; +use Distantmagic\Resonance\LlmPromptTemplate\ChainPrompt; use Distantmagic\Resonance\PromptSubjectResponderAggregate; use Distantmagic\Resonance\RPCNotification; use Distantmagic\Resonance\WebSocketAuthResolution; @@ -30,23 +31,25 @@ abstract readonly class LlamaCppPromptResponder extends WebSocketRPCResponder */ private WeakMap $runningCompletions; + /** + * @param TPayload $payload + */ + abstract protected function getPromptFromPayload(mixed $payload): string; + abstract protected function onResponseChunk( WebSocketAuthResolution $webSocketAuthResolution, WebSocketConnection $webSocketConnection, mixed $responseChunk, ): void; - /** - * @param TPayload $payload - */ - abstract protected function toPromptTemplate(mixed $payload): LlmPromptTemplate; + abstract protected function toPromptTemplate(string $prompt): LlmPromptTemplate; public function __construct( private LlamaCppClient $llamaCppClient, private LoggerInterface $logger, private PromptSubjectResponderAggregate $promptSubjectResponderAggregate, private SubjectActionGrammar $subjectActionGrammar, - private SubjectActionSystemPrompt $subjectActionSystemPrompt, + private SubjectActionPrompt $subjectActionPrompt, ) { /** * @var WeakMap<WebSocketConnection,LlamaCppCompletionIterator> @@ -70,8 +73,10 @@ abstract readonly class LlamaCppPromptResponder extends WebSocketRPCResponder ): void { $request = new LlamaCppCompletionRequest( backusNaurFormGrammar: $this->subjectActionGrammar, - llmSystemPrompt: $this->subjectActionSystemPrompt, - promptTemplate: $this->toPromptTemplate($rpcNotification->payload), + promptTemplate: new ChainPrompt([ + $this->toPromptTemplate($this->subjectActionPrompt->getPromptContent()), + $this->toPromptTemplate($this->getPromptFromPayload($rpcNotification->payload)), + ]), ); $completion = $this->llamaCppClient->generateCompletion($request);