Skip to content
Snippets Groups Projects
Commit 68bbd073 authored by Mateusz Charytoniuk's avatar Mateusz Charytoniuk
Browse files

fix: do not swap the system prompt all the time

parent f504ca77
No related branches found
No related tags found
No related merge requests found
......@@ -4,14 +4,7 @@ declare(strict_types=1);
namespace Distantmagic\Resonance;
use JsonSerializable;
abstract readonly class BackusNaurFormGrammar implements JsonSerializable
abstract readonly class BackusNaurFormGrammar
{
abstract public function getGrammarContent(): string;
public function jsonSerialize(): mixed
{
return $this->getGrammarContent();
}
}
......@@ -11,23 +11,25 @@ readonly class LlamaCppCompletionRequest implements JsonSerializable
public function __construct(
public LlmPromptTemplate $promptTemplate,
public ?BackusNaurFormGrammar $backusNaurFormGrammar = null,
public ?LlmSystemPrompt $llmSystemPrompt = null,
public ?LlmPrompt $llmSystemPrompt = null,
) {}
public function jsonSerialize(): array
{
$parameters = [
'n_predict' => 400,
'prompt' => $this->promptTemplate,
'prompt' => $this->promptTemplate->getPromptTemplateContent(),
'stream' => true,
];
if ($this->backusNaurFormGrammar) {
$parameters['grammar'] = $this->backusNaurFormGrammar;
$parameters['grammar'] = $this->backusNaurFormGrammar->getGrammarContent();
}
if ($this->llmSystemPrompt) {
$parameters['system_prompt'] = $this->llmSystemPrompt;
$parameters['system_prompt'] = [
'prompt' => $this->llmSystemPrompt->getPromptContent(),
];
}
return $parameters;
......
......@@ -4,16 +4,7 @@ declare(strict_types=1);
namespace Distantmagic\Resonance;
use JsonSerializable;
abstract readonly class LlmSystemPrompt implements JsonSerializable
abstract readonly class LlmPrompt
{
abstract public function getPromptContent(): string;
public function jsonSerialize(): array
{
return [
'prompt' => $this->getPromptContent(),
];
}
}
......@@ -2,16 +2,16 @@
declare(strict_types=1);
namespace Distantmagic\Resonance\LlmSystemPrompt;
namespace Distantmagic\Resonance\LlmPrompt;
use Distantmagic\Resonance\Attribute\Singleton;
use Distantmagic\Resonance\LlmSystemPrompt;
use Distantmagic\Resonance\LlmPrompt;
use Distantmagic\Resonance\PromptSubjectResponderCollection;
use Distantmagic\Resonance\RespondsToPromptSubjectAttributeCollection;
use Ds\Set;
#[Singleton]
readonly class SubjectActionSystemPrompt extends LlmSystemPrompt
readonly class SubjectActionPrompt extends LlmPrompt
{
/**
* @var non-empty-string $prompt
......
......@@ -4,13 +4,7 @@ declare(strict_types=1);
namespace Distantmagic\Resonance;
use JsonSerializable;
use Stringable;
abstract readonly class LlmPromptTemplate implements JsonSerializable, Stringable
abstract readonly class LlmPromptTemplate
{
public function jsonSerialize(): string
{
return (string) $this;
}
abstract public function getPromptTemplateContent(): string;
}
<?php
declare(strict_types=1);
namespace Distantmagic\Resonance\LlmPromptTemplate;
use Distantmagic\Resonance\Attribute\Singleton;
use Distantmagic\Resonance\LlmPromptTemplate;
#[Singleton]
readonly class ChainPrompt extends LlmPromptTemplate
{
private string $prompt;
/**
* @param array<LlmPromptTemplate> $prompts
*/
public function __construct(array $prompts)
{
$gluedPrompt = '';
foreach ($prompts as $prompt) {
$gluedPrompt .= $prompt->getPromptTemplateContent();
}
$this->prompt = $gluedPrompt;
}
public function getPromptTemplateContent(): string
{
return $this->prompt;
}
}
......@@ -8,12 +8,9 @@ use Distantmagic\Resonance\LlmPromptTemplate;
readonly class MistralInstructChat extends LlmPromptTemplate
{
/**
* @param non-empty-string $prompt
*/
public function __construct(private string $prompt) {}
public function __toString(): string
public function getPromptTemplateContent(): string
{
return sprintf(
'[INST]%s[/INST]',
......
......@@ -8,12 +8,9 @@ use Distantmagic\Resonance\LlmPromptTemplate;
readonly class Phi2Question extends LlmPromptTemplate
{
/**
* @param non-empty-string $prompt
*/
public function __construct(private string $prompt) {}
public function __toString(): string
public function getPromptTemplateContent(): string
{
return sprintf(
"Question: %s\nAnswer: ",
......
......@@ -8,12 +8,9 @@ use Distantmagic\Resonance\LlmPromptTemplate;
readonly class Plain extends LlmPromptTemplate
{
/**
* @param non-empty-string $prompt
*/
public function __construct(private string $prompt) {}
public function __toString(): string
public function getPromptTemplateContent(): string
{
return $this->prompt;
}
......
......@@ -8,8 +8,9 @@ use Distantmagic\Resonance\BackusNaurFormGrammar\SubjectActionGrammar;
use Distantmagic\Resonance\LlamaCppClient;
use Distantmagic\Resonance\LlamaCppCompletionIterator;
use Distantmagic\Resonance\LlamaCppCompletionRequest;
use Distantmagic\Resonance\LlmPrompt\SubjectActionPrompt;
use Distantmagic\Resonance\LlmPromptTemplate;
use Distantmagic\Resonance\LlmSystemPrompt\SubjectActionSystemPrompt;
use Distantmagic\Resonance\LlmPromptTemplate\ChainPrompt;
use Distantmagic\Resonance\PromptSubjectResponderAggregate;
use Distantmagic\Resonance\RPCNotification;
use Distantmagic\Resonance\WebSocketAuthResolution;
......@@ -30,23 +31,25 @@ abstract readonly class LlamaCppPromptResponder extends WebSocketRPCResponder
*/
private WeakMap $runningCompletions;
/**
* @param TPayload $payload
*/
abstract protected function getPromptFromPayload(mixed $payload): string;
abstract protected function onResponseChunk(
WebSocketAuthResolution $webSocketAuthResolution,
WebSocketConnection $webSocketConnection,
mixed $responseChunk,
): void;
/**
* @param TPayload $payload
*/
abstract protected function toPromptTemplate(mixed $payload): LlmPromptTemplate;
abstract protected function toPromptTemplate(string $prompt): LlmPromptTemplate;
public function __construct(
private LlamaCppClient $llamaCppClient,
private LoggerInterface $logger,
private PromptSubjectResponderAggregate $promptSubjectResponderAggregate,
private SubjectActionGrammar $subjectActionGrammar,
private SubjectActionSystemPrompt $subjectActionSystemPrompt,
private SubjectActionPrompt $subjectActionPrompt,
) {
/**
* @var WeakMap<WebSocketConnection,LlamaCppCompletionIterator>
......@@ -70,8 +73,10 @@ abstract readonly class LlamaCppPromptResponder extends WebSocketRPCResponder
): void {
$request = new LlamaCppCompletionRequest(
backusNaurFormGrammar: $this->subjectActionGrammar,
llmSystemPrompt: $this->subjectActionSystemPrompt,
promptTemplate: $this->toPromptTemplate($rpcNotification->payload),
promptTemplate: new ChainPrompt([
$this->toPromptTemplate($this->subjectActionPrompt->getPromptContent()),
$this->toPromptTemplate($this->getPromptFromPayload($rpcNotification->payload)),
]),
);
$completion = $this->llamaCppClient->generateCompletion($request);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment