From 9c44a6cf02f24711efe6113e04fd1230e3a80c50 Mon Sep 17 00:00:00 2001 From: Mateusz Charytoniuk <mateusz.charytoniuk@protonmail.com> Date: Fri, 22 Mar 2024 19:32:05 +0100 Subject: [PATCH] chore: report llm response chunk status --- src/LlamaCppClient.php | 16 +++++++--- src/LlamaCppClientResponseChunk.php | 13 +++++++++ src/LlamaCppCompletionIterator.php | 21 ++++++++++---- src/LlamaCppCompletionToken.php | 1 + src/ObservableTaskStatus.php | 22 ++++++++++---- src/ObservableTaskTable.php | 2 +- src/ObservableTaskTimeoutIterator.php | 7 ++++- src/PromptSubjectResponderAggregate.php | 27 +++++++++++++---- src/PromptSubjectResponse.php | 24 +++++++++------ src/PromptSubjectResponseChunk.php | 14 +++++++++ src/SubjectActionTokenReader.php | 5 ++++ .../LlamaCppSubjectActionPromptResponder.php | 29 +++++++++---------- 12 files changed, 135 insertions(+), 46 deletions(-) create mode 100644 src/LlamaCppClientResponseChunk.php create mode 100644 src/PromptSubjectResponseChunk.php diff --git a/src/LlamaCppClient.php b/src/LlamaCppClient.php index a05ca2d4..5f756c57 100644 --- a/src/LlamaCppClient.php +++ b/src/LlamaCppClient.php @@ -81,7 +81,7 @@ readonly class LlamaCppClient /** * @var object{ content: string } */ - $token = $this->jsonSerializer->unserialize($responseChunk); + $token = $this->jsonSerializer->unserialize($responseChunk->chunk); yield new LlamaCppInfill( after: $request->after, @@ -160,7 +160,7 @@ readonly class LlamaCppClient } /** - * @return SwooleChannelIterator<string> + * @return SwooleChannelIterator<LlamaCppClientResponseChunk> */ private function streamResponse(JsonSerializable $request, string $path): SwooleChannelIterator { @@ -184,7 +184,10 @@ readonly class LlamaCppClient curl_setopt($curlHandle, CURLOPT_RETURNTRANSFER, false); curl_setopt($curlHandle, CURLOPT_URL, $this->llamaCppLinkBuilder->build($path)); curl_setopt($curlHandle, CURLOPT_WRITEFUNCTION, function (CurlHandle $curlHandle, string $data) use ($channel): int { - if ($channel->push($data, $this->llamaCppConfiguration->completionTokenTimeout)) { + if ($channel->push(new LlamaCppClientResponseChunk( + status: ObservableTaskStatus::Running, + chunk: $data + ), $this->llamaCppConfiguration->completionTokenTimeout)) { return strlen($data); } @@ -195,6 +198,11 @@ readonly class LlamaCppClient if (CURLE_WRITE_ERROR !== $curlErrno) { $this->logger->error(new CurlErrorMessage($curlHandle)); + + $channel->push(new LlamaCppClientResponseChunk( + status: ObservableTaskStatus::Failed, + chunk: '', + )); } } else { $this->assertStatusCode($curlHandle, 200); @@ -202,7 +210,7 @@ readonly class LlamaCppClient }); /** - * @var SwooleChannelIterator<string> + * @var SwooleChannelIterator<LlamaCppClientResponseChunk> */ return new SwooleChannelIterator( channel: $channel, diff --git a/src/LlamaCppClientResponseChunk.php b/src/LlamaCppClientResponseChunk.php new file mode 100644 index 00000000..89de6198 --- /dev/null +++ b/src/LlamaCppClientResponseChunk.php @@ -0,0 +1,13 @@ +<?php + +declare(strict_types=1); + +namespace Distantmagic\Resonance; + +readonly class LlamaCppClientResponseChunk +{ + public function __construct( + public ObservableTaskStatus $status, + public string $chunk, + ) {} +} diff --git a/src/LlamaCppCompletionIterator.php b/src/LlamaCppCompletionIterator.php index a17e7a4b..be43c8dc 100644 --- a/src/LlamaCppCompletionIterator.php +++ b/src/LlamaCppCompletionIterator.php @@ -16,7 +16,7 @@ readonly class LlamaCppCompletionIterator implements IteratorAggregate private const COMPLETION_CHUNKED_DATA_PREFIX_LENGTH = 6; /** - * @param SwooleChannelIterator<string> $responseChunks + * @param SwooleChannelIterator<LlamaCppClientResponseChunk> $responseChunks */ public function __construct( private JsonSerializer $jsonSerializer, @@ -35,7 +35,19 @@ readonly class LlamaCppCompletionIterator implements IteratorAggregate $previousChunk = ''; foreach ($this->responseChunks as $responseChunk) { - $previousChunk .= $responseChunk; + if (ObservableTaskStatus::Failed === $responseChunk->status) { + $previousChunk = ''; + + yield new LlamaCppCompletionToken( + content: '', + isFailed: true, + isLastToken: true, + ); + + break; + } + + $previousChunk .= $responseChunk->chunk; /** * @var null|object{ @@ -56,6 +68,7 @@ readonly class LlamaCppCompletionIterator implements IteratorAggregate if ($unserializedToken) { yield new LlamaCppCompletionToken( content: $unserializedToken->content, + isFailed: false, isLastToken: $unserializedToken->stop, ); } @@ -72,8 +85,6 @@ readonly class LlamaCppCompletionIterator implements IteratorAggregate return; } - if (!$this->responseChunks->channel->close()) { - throw new RuntimeException('Unable to close coroutine channel'); - } + $this->responseChunks->channel->close(); } } diff --git a/src/LlamaCppCompletionToken.php b/src/LlamaCppCompletionToken.php index 2eef07a9..4db369ea 100644 --- a/src/LlamaCppCompletionToken.php +++ b/src/LlamaCppCompletionToken.php @@ -13,6 +13,7 @@ readonly class LlamaCppCompletionToken implements Stringable { public function __construct( public string $content, + public bool $isFailed, public bool $isLastToken, ) {} diff --git a/src/ObservableTaskStatus.php b/src/ObservableTaskStatus.php index 3ea5fc38..62b33678 100644 --- a/src/ObservableTaskStatus.php +++ b/src/ObservableTaskStatus.php @@ -6,9 +6,21 @@ namespace Distantmagic\Resonance; enum ObservableTaskStatus: string { - case Cancelled = 'Cancelled'; - case Failed = 'Failed'; - case Finished = 'Finished'; - case Pending = 'Pending'; - case Running = 'Running'; + case Cancelled = 'cancelled'; + case Failed = 'failed'; + case Finished = 'finished'; + case Pending = 'pending'; + case Running = 'running'; + case TimedOut = 'timed_out'; + + public function isFinal(): bool + { + return match ($this) { + ObservableTaskStatus::Cancelled, + ObservableTaskStatus::Failed, + ObservableTaskStatus::Finished, + ObservableTaskStatus::TimedOut => true, + default => false, + }; + } } diff --git a/src/ObservableTaskTable.php b/src/ObservableTaskTable.php index d88e9b49..7deb2d4b 100644 --- a/src/ObservableTaskTable.php +++ b/src/ObservableTaskTable.php @@ -108,7 +108,7 @@ readonly class ObservableTaskTable implements IteratorAggregate } } - if (ObservableTaskStatus::Running !== $statusUpdate->status) { + if ($statusUpdate->status->isFinal()) { break; } } diff --git a/src/ObservableTaskTimeoutIterator.php b/src/ObservableTaskTimeoutIterator.php index f72ba432..83b2f3d4 100644 --- a/src/ObservableTaskTimeoutIterator.php +++ b/src/ObservableTaskTimeoutIterator.php @@ -52,7 +52,12 @@ readonly class ObservableTaskTimeoutIterator implements IteratorAggregate $channel = new Channel(1); - $swooleTimeout = new SwooleTimeout(static function () use (&$generatorCoroutineId) { + $swooleTimeout = new SwooleTimeout(static function () use ($channel, &$generatorCoroutineId) { + $channel->push(new ObservableTaskStatusUpdate( + ObservableTaskStatus::TimedOut, + null, + )); + if (is_int($generatorCoroutineId)) { Coroutine::cancel($generatorCoroutineId); } diff --git a/src/PromptSubjectResponderAggregate.php b/src/PromptSubjectResponderAggregate.php index f97ca716..cd974eb1 100644 --- a/src/PromptSubjectResponderAggregate.php +++ b/src/PromptSubjectResponderAggregate.php @@ -16,10 +16,13 @@ readonly class PromptSubjectResponderAggregate private PromptSubjectResponderCollection $promptSubjectResponderCollection, ) {} + /** + * @return Generator<PromptSubjectResponseChunk> + */ public function createResponseFromTokens( ?AuthenticatedUser $authenticatedUser, LlamaCppCompletionIterator $completion, - float $timeout = DM_POOL_CONNECTION_TIMEOUT, + float $inactivityTimeout = DM_POOL_CONNECTION_TIMEOUT, ): Generator { $subjectActionTokenReader = new SubjectActionTokenReader(); @@ -31,24 +34,36 @@ readonly class PromptSubjectResponderAggregate break; } + + if ($token->isFailed) { + yield new PromptSubjectResponseChunk( + isFailed: true, + isLastChunk: true, + payload: '', + ); + } } $action = $subjectActionTokenReader->getAction(); $subject = $subjectActionTokenReader->getSubject(); + if ($subjectActionTokenReader->isEmpty()) { + return; + } + if ($subjectActionTokenReader->isUnknown() || !isset($action, $subject)) { yield from $this->respondWithSubjectAction( $authenticatedUser, 'unknown', 'unknown', - $timeout, + $inactivityTimeout, ); } else { yield from $this->respondWithSubjectAction( $authenticatedUser, $subject, $action, - $timeout, + $inactivityTimeout, ); } } @@ -56,12 +71,14 @@ readonly class PromptSubjectResponderAggregate /** * @param non-empty-string $subject * @param non-empty-string $action + * + * @return Generator<PromptSubjectResponseChunk> */ private function respondWithSubjectAction( ?AuthenticatedUser $authenticatedUser, string $subject, string $action, - float $timeout, + float $inactivityTimeout, ): Generator { $responder = $this ->promptSubjectResponderCollection @@ -81,7 +98,7 @@ readonly class PromptSubjectResponderAggregate } $request = new PromptSubjectRequest($authenticatedUser); - $response = new PromptSubjectResponse($timeout); + $response = new PromptSubjectResponse($inactivityTimeout); SwooleCoroutineHelper::mustGo(static function () use ($request, $responder, $response) { $responder->respondToPromptSubject($request, $response); diff --git a/src/PromptSubjectResponse.php b/src/PromptSubjectResponse.php index c239e929..71d12be8 100644 --- a/src/PromptSubjectResponse.php +++ b/src/PromptSubjectResponse.php @@ -8,14 +8,14 @@ use IteratorAggregate; use Swoole\Coroutine\Channel; /** - * @template-implements IteratorAggregate<mixed> + * @template-implements IteratorAggregate<PromptSubjectResponseChunk> */ readonly class PromptSubjectResponse implements IteratorAggregate { private Channel $channel; public function __construct( - private float $timeout, + private float $inactivityTimeout, ) { $this->channel = new Channel(1); } @@ -28,30 +28,36 @@ readonly class PromptSubjectResponse implements IteratorAggregate public function end(mixed $payload = null): void { try { - if (null !== $payload) { - $this->write($payload); - } + $this->channel->push(new PromptSubjectResponseChunk( + isFailed: false, + isLastChunk: true, + payload: $payload, + )); } finally { $this->channel->close(); } } /** - * @return SwooleChannelIterator<mixed> + * @return SwooleChannelIterator<PromptSubjectResponseChunk> */ public function getIterator(): SwooleChannelIterator { /** - * @var SwooleChannelIterator<mixed> + * @var SwooleChannelIterator<PromptSubjectResponseChunk> */ return new SwooleChannelIterator( channel: $this->channel, - timeout: $this->timeout, + timeout: $this->inactivityTimeout, ); } public function write(mixed $payload): void { - $this->channel->push($payload); + $this->channel->push(new PromptSubjectResponseChunk( + isFailed: false, + isLastChunk: false, + payload: $payload, + )); } } diff --git a/src/PromptSubjectResponseChunk.php b/src/PromptSubjectResponseChunk.php new file mode 100644 index 00000000..756ed319 --- /dev/null +++ b/src/PromptSubjectResponseChunk.php @@ -0,0 +1,14 @@ +<?php + +declare(strict_types=1); + +namespace Distantmagic\Resonance; + +readonly class PromptSubjectResponseChunk +{ + public function __construct( + public bool $isFailed, + public bool $isLastChunk, + public mixed $payload, + ) {} +} diff --git a/src/SubjectActionTokenReader.php b/src/SubjectActionTokenReader.php index 062d451c..46c53838 100644 --- a/src/SubjectActionTokenReader.php +++ b/src/SubjectActionTokenReader.php @@ -37,6 +37,11 @@ class SubjectActionTokenReader return $this->subject; } + public function isEmpty(): bool + { + return !isset($this->action, $this->subject); + } + public function isUnknown(): bool { return 'unknown' === $this->action || 'unknown' === $this->subject; diff --git a/src/WebSocketRPCResponder/LlamaCppSubjectActionPromptResponder.php b/src/WebSocketRPCResponder/LlamaCppSubjectActionPromptResponder.php index 188961df..ef0f5304 100644 --- a/src/WebSocketRPCResponder/LlamaCppSubjectActionPromptResponder.php +++ b/src/WebSocketRPCResponder/LlamaCppSubjectActionPromptResponder.php @@ -90,7 +90,7 @@ abstract readonly class LlamaCppSubjectActionPromptResponder extends WebSocketRP yield new ObservableTaskStatusUpdate(ObservableTaskStatus::Running, null); try { - $this->onObservableRequest( + yield from $this->onObservableRequest( $webSocketAuthResolution, $webSocketConnection, $rpcRequest, @@ -106,12 +106,14 @@ abstract readonly class LlamaCppSubjectActionPromptResponder extends WebSocketRP /** * @param RPCRequest<TPayload> $rpcRequest + * + * @return Generator<ObservableTaskStatusUpdate> */ private function onObservableRequest( WebSocketAuthResolution $webSocketAuthResolution, WebSocketConnection $webSocketConnection, RPCRequest $rpcRequest, - ): void { + ): Generator { $request = new LlamaCppCompletionRequest( backusNaurFormGrammar: $this->subjectActionGrammar, promptTemplate: new ChainPrompt([ @@ -129,29 +131,24 @@ abstract readonly class LlamaCppSubjectActionPromptResponder extends WebSocketRP ->createResponseFromTokens( authenticatedUser: $webSocketAuthResolution->authenticatedUser, completion: $completion, - timeout: 1.0, + inactivityTimeout: 1.0, ) ; - /** - * @var mixed $responseChunk explicitly mixed for typechecks - */ foreach ($response as $responseChunk) { + if ($responseChunk->isFailed) { + yield new ObservableTaskStatusUpdate(ObservableTaskStatus::Failed, null); + + break; + } + $this->onResponseChunk( webSocketAuthResolution: $webSocketAuthResolution, webSocketConnection: $webSocketConnection, rpcRequest: $rpcRequest, - responseChunk: $responseChunk, - isLastChunk: false, + responseChunk: $responseChunk->payload, + isLastChunk: $responseChunk->isLastChunk, ); } - - $this->onResponseChunk( - webSocketAuthResolution: $webSocketAuthResolution, - webSocketConnection: $webSocketConnection, - rpcRequest: $rpcRequest, - responseChunk: '', - isLastChunk: true, - ); } } -- GitLab