From 9c44a6cf02f24711efe6113e04fd1230e3a80c50 Mon Sep 17 00:00:00 2001
From: Mateusz Charytoniuk <mateusz.charytoniuk@protonmail.com>
Date: Fri, 22 Mar 2024 19:32:05 +0100
Subject: [PATCH] chore: report llm response chunk status

---
 src/LlamaCppClient.php                        | 16 +++++++---
 src/LlamaCppClientResponseChunk.php           | 13 +++++++++
 src/LlamaCppCompletionIterator.php            | 21 ++++++++++----
 src/LlamaCppCompletionToken.php               |  1 +
 src/ObservableTaskStatus.php                  | 22 ++++++++++----
 src/ObservableTaskTable.php                   |  2 +-
 src/ObservableTaskTimeoutIterator.php         |  7 ++++-
 src/PromptSubjectResponderAggregate.php       | 27 +++++++++++++----
 src/PromptSubjectResponse.php                 | 24 +++++++++------
 src/PromptSubjectResponseChunk.php            | 14 +++++++++
 src/SubjectActionTokenReader.php              |  5 ++++
 .../LlamaCppSubjectActionPromptResponder.php  | 29 +++++++++----------
 12 files changed, 135 insertions(+), 46 deletions(-)
 create mode 100644 src/LlamaCppClientResponseChunk.php
 create mode 100644 src/PromptSubjectResponseChunk.php

diff --git a/src/LlamaCppClient.php b/src/LlamaCppClient.php
index a05ca2d4..5f756c57 100644
--- a/src/LlamaCppClient.php
+++ b/src/LlamaCppClient.php
@@ -81,7 +81,7 @@ readonly class LlamaCppClient
             /**
              * @var object{ content: string }
              */
-            $token = $this->jsonSerializer->unserialize($responseChunk);
+            $token = $this->jsonSerializer->unserialize($responseChunk->chunk);
 
             yield new LlamaCppInfill(
                 after: $request->after,
@@ -160,7 +160,7 @@ readonly class LlamaCppClient
     }
 
     /**
-     * @return SwooleChannelIterator<string>
+     * @return SwooleChannelIterator<LlamaCppClientResponseChunk>
      */
     private function streamResponse(JsonSerializable $request, string $path): SwooleChannelIterator
     {
@@ -184,7 +184,10 @@ readonly class LlamaCppClient
             curl_setopt($curlHandle, CURLOPT_RETURNTRANSFER, false);
             curl_setopt($curlHandle, CURLOPT_URL, $this->llamaCppLinkBuilder->build($path));
             curl_setopt($curlHandle, CURLOPT_WRITEFUNCTION, function (CurlHandle $curlHandle, string $data) use ($channel): int {
-                if ($channel->push($data, $this->llamaCppConfiguration->completionTokenTimeout)) {
+                if ($channel->push(new LlamaCppClientResponseChunk(
+                    status: ObservableTaskStatus::Running,
+                    chunk: $data
+                ), $this->llamaCppConfiguration->completionTokenTimeout)) {
                     return strlen($data);
                 }
 
@@ -195,6 +198,11 @@ readonly class LlamaCppClient
 
                 if (CURLE_WRITE_ERROR !== $curlErrno) {
                     $this->logger->error(new CurlErrorMessage($curlHandle));
+                    // Bounded like the data push: do not wait forever when nobody reads the channel.
+                    $channel->push(new LlamaCppClientResponseChunk(
+                        status: ObservableTaskStatus::Failed,
+                        chunk: '',
+                    ), $this->llamaCppConfiguration->completionTokenTimeout);
                 }
             } else {
                 $this->assertStatusCode($curlHandle, 200);
@@ -202,7 +210,7 @@ readonly class LlamaCppClient
         });
 
         /**
-         * @var SwooleChannelIterator<string>
+         * @var SwooleChannelIterator<LlamaCppClientResponseChunk>
          */
         return new SwooleChannelIterator(
             channel: $channel,
diff --git a/src/LlamaCppClientResponseChunk.php b/src/LlamaCppClientResponseChunk.php
new file mode 100644
index 00000000..89de6198
--- /dev/null
+++ b/src/LlamaCppClientResponseChunk.php
@@ -0,0 +1,20 @@
+<?php
+
+declare(strict_types=1);
+
+namespace Distantmagic\Resonance;
+
+/**
+ * One item of a streamed llama.cpp response: a piece of raw response data
+ * paired with the status of the stream when it was produced.
+ *
+ * LlamaCppClient emits Running chunks for received data and a Failed chunk
+ * with an empty string after logging a curl error.
+ */
+readonly class LlamaCppClientResponseChunk
+{
+    public function __construct(
+        public ObservableTaskStatus $status,
+        public string $chunk,
+    ) {}
+}
diff --git a/src/LlamaCppCompletionIterator.php b/src/LlamaCppCompletionIterator.php
index a17e7a4b..be43c8dc 100644
--- a/src/LlamaCppCompletionIterator.php
+++ b/src/LlamaCppCompletionIterator.php
@@ -16,7 +16,7 @@ readonly class LlamaCppCompletionIterator implements IteratorAggregate
     private const COMPLETION_CHUNKED_DATA_PREFIX_LENGTH = 6;
 
     /**
-     * @param SwooleChannelIterator<string> $responseChunks
+     * @param SwooleChannelIterator<LlamaCppClientResponseChunk> $responseChunks
      */
     public function __construct(
         private JsonSerializer $jsonSerializer,
@@ -35,7 +35,19 @@ readonly class LlamaCppCompletionIterator implements IteratorAggregate
         $previousChunk = '';
 
         foreach ($this->responseChunks as $responseChunk) {
-            $previousChunk .= $responseChunk;
+            if (ObservableTaskStatus::Failed === $responseChunk->status) {
+                $previousChunk = '';
+
+                yield new LlamaCppCompletionToken(
+                    content: '',
+                    isFailed: true,
+                    isLastToken: true,
+                );
+
+                break;
+            }
+
+            $previousChunk .= $responseChunk->chunk;
 
             /**
              * @var null|object{
@@ -56,6 +68,7 @@ readonly class LlamaCppCompletionIterator implements IteratorAggregate
             if ($unserializedToken) {
                 yield new LlamaCppCompletionToken(
                     content: $unserializedToken->content,
+                    isFailed: false,
                     isLastToken: $unserializedToken->stop,
                 );
             }
@@ -72,8 +85,6 @@ readonly class LlamaCppCompletionIterator implements IteratorAggregate
             return;
         }
 
-        if (!$this->responseChunks->channel->close()) {
-            throw new RuntimeException('Unable to close coroutine channel');
-        }
+        $this->responseChunks->channel->close();
     }
 }
diff --git a/src/LlamaCppCompletionToken.php b/src/LlamaCppCompletionToken.php
index 2eef07a9..4db369ea 100644
--- a/src/LlamaCppCompletionToken.php
+++ b/src/LlamaCppCompletionToken.php
@@ -13,6 +13,7 @@ readonly class LlamaCppCompletionToken implements Stringable
 {
     public function __construct(
         public string $content,
+        public bool $isFailed,
         public bool $isLastToken,
     ) {}
 
diff --git a/src/ObservableTaskStatus.php b/src/ObservableTaskStatus.php
index 3ea5fc38..62b33678 100644
--- a/src/ObservableTaskStatus.php
+++ b/src/ObservableTaskStatus.php
@@ -6,9 +6,21 @@ namespace Distantmagic\Resonance;
 
 enum ObservableTaskStatus: string
 {
-    case Cancelled = 'Cancelled';
-    case Failed = 'Failed';
-    case Finished = 'Finished';
-    case Pending = 'Pending';
-    case Running = 'Running';
+    case Cancelled = 'cancelled';
+    case Failed = 'failed';
+    case Finished = 'finished';
+    case Pending = 'pending';
+    case Running = 'running';
+    case TimedOut = 'timed_out';
+
+    /**
+     * True for terminal statuses. Exhaustive on purpose: a new case must be classified here.
+     */
+    public function isFinal(): bool
+    {
+        return match ($this) {
+            self::Cancelled, self::Failed, self::Finished, self::TimedOut => true,
+            self::Pending, self::Running => false,
+        };
+    }
 }
diff --git a/src/ObservableTaskTable.php b/src/ObservableTaskTable.php
index d88e9b49..7deb2d4b 100644
--- a/src/ObservableTaskTable.php
+++ b/src/ObservableTaskTable.php
@@ -108,7 +108,7 @@ readonly class ObservableTaskTable implements IteratorAggregate
                     }
                 }
 
-                if (ObservableTaskStatus::Running !== $statusUpdate->status) {
+                if ($statusUpdate->status->isFinal()) {
                     break;
                 }
             }
diff --git a/src/ObservableTaskTimeoutIterator.php b/src/ObservableTaskTimeoutIterator.php
index f72ba432..83b2f3d4 100644
--- a/src/ObservableTaskTimeoutIterator.php
+++ b/src/ObservableTaskTimeoutIterator.php
@@ -52,7 +52,12 @@ readonly class ObservableTaskTimeoutIterator implements IteratorAggregate
 
         $channel = new Channel(1);
 
-        $swooleTimeout = new SwooleTimeout(static function () use (&$generatorCoroutineId) {
+        $swooleTimeout = new SwooleTimeout(static function () use ($channel, &$generatorCoroutineId) {
+            $channel->push(new ObservableTaskStatusUpdate(
+                ObservableTaskStatus::TimedOut,
+                null,
+            ));
+
             if (is_int($generatorCoroutineId)) {
                 Coroutine::cancel($generatorCoroutineId);
             }
diff --git a/src/PromptSubjectResponderAggregate.php b/src/PromptSubjectResponderAggregate.php
index f97ca716..cd974eb1 100644
--- a/src/PromptSubjectResponderAggregate.php
+++ b/src/PromptSubjectResponderAggregate.php
@@ -16,10 +16,13 @@ readonly class PromptSubjectResponderAggregate
         private PromptSubjectResponderCollection $promptSubjectResponderCollection,
     ) {}
 
+    /**
+     * @return Generator<PromptSubjectResponseChunk>
+     */
     public function createResponseFromTokens(
         ?AuthenticatedUser $authenticatedUser,
         LlamaCppCompletionIterator $completion,
-        float $timeout = DM_POOL_CONNECTION_TIMEOUT,
+        float $inactivityTimeout = DM_POOL_CONNECTION_TIMEOUT,
     ): Generator {
         $subjectActionTokenReader = new SubjectActionTokenReader();
 
@@ -31,24 +34,36 @@ readonly class PromptSubjectResponderAggregate
 
                 break;
             }
+
+            if ($token->isFailed) {
+                yield new PromptSubjectResponseChunk(
+                    isFailed: true,
+                    isLastChunk: true,
+                    payload: '',
+                );
+            }
         }
 
         $action = $subjectActionTokenReader->getAction();
         $subject = $subjectActionTokenReader->getSubject();
 
+        if ($subjectActionTokenReader->isEmpty()) {
+            return;
+        }
+
         if ($subjectActionTokenReader->isUnknown() || !isset($action, $subject)) {
             yield from $this->respondWithSubjectAction(
                 $authenticatedUser,
                 'unknown',
                 'unknown',
-                $timeout,
+                $inactivityTimeout,
             );
         } else {
             yield from $this->respondWithSubjectAction(
                 $authenticatedUser,
                 $subject,
                 $action,
-                $timeout,
+                $inactivityTimeout,
             );
         }
     }
@@ -56,12 +71,14 @@ readonly class PromptSubjectResponderAggregate
     /**
      * @param non-empty-string $subject
      * @param non-empty-string $action
+     *
+     * @return Generator<PromptSubjectResponseChunk>
      */
     private function respondWithSubjectAction(
         ?AuthenticatedUser $authenticatedUser,
         string $subject,
         string $action,
-        float $timeout,
+        float $inactivityTimeout,
     ): Generator {
         $responder = $this
             ->promptSubjectResponderCollection
@@ -81,7 +98,7 @@ readonly class PromptSubjectResponderAggregate
         }
 
         $request = new PromptSubjectRequest($authenticatedUser);
-        $response = new PromptSubjectResponse($timeout);
+        $response = new PromptSubjectResponse($inactivityTimeout);
 
         SwooleCoroutineHelper::mustGo(static function () use ($request, $responder, $response) {
             $responder->respondToPromptSubject($request, $response);
diff --git a/src/PromptSubjectResponse.php b/src/PromptSubjectResponse.php
index c239e929..71d12be8 100644
--- a/src/PromptSubjectResponse.php
+++ b/src/PromptSubjectResponse.php
@@ -8,14 +8,14 @@ use IteratorAggregate;
 use Swoole\Coroutine\Channel;
 
 /**
- * @template-implements IteratorAggregate<mixed>
+ * @template-implements IteratorAggregate<PromptSubjectResponseChunk>
  */
 readonly class PromptSubjectResponse implements IteratorAggregate
 {
     private Channel $channel;
 
     public function __construct(
-        private float $timeout,
+        private float $inactivityTimeout,
     ) {
         $this->channel = new Channel(1);
     }
@@ -28,30 +28,41 @@ readonly class PromptSubjectResponse implements IteratorAggregate
     public function end(mixed $payload = null): void
     {
         try {
-            if (null !== $payload) {
-                $this->write($payload);
-            }
+            // A closing chunk is pushed on every call, even when $payload is null.
+            $this->channel->push(new PromptSubjectResponseChunk(
+                isFailed: false,
+                isLastChunk: true,
+                payload: $payload,
+            ));
         } finally {
+            // Runs even if the push above throws.
             $this->channel->close();
         }
     }
 
     /**
-     * @return SwooleChannelIterator<mixed>
+     * @return SwooleChannelIterator<PromptSubjectResponseChunk>
      */
     public function getIterator(): SwooleChannelIterator
     {
         /**
-         * @var SwooleChannelIterator<mixed>
+         * @var SwooleChannelIterator<PromptSubjectResponseChunk>
          */
         return new SwooleChannelIterator(
             channel: $this->channel,
-            timeout: $this->timeout,
+            timeout: $this->inactivityTimeout,
         );
     }
 
+    /**
+     * Pushes $payload to the channel wrapped in a non-final, non-failed chunk.
+     */
     public function write(mixed $payload): void
     {
-        $this->channel->push($payload);
+        $this->channel->push(new PromptSubjectResponseChunk(
+            isFailed: false,
+            isLastChunk: false,
+            payload: $payload,
+        ));
     }
 }
diff --git a/src/PromptSubjectResponseChunk.php b/src/PromptSubjectResponseChunk.php
new file mode 100644
index 00000000..756ed319
--- /dev/null
+++ b/src/PromptSubjectResponseChunk.php
@@ -0,0 +1,20 @@
+<?php
+
+declare(strict_types=1);
+
+namespace Distantmagic\Resonance;
+
+/**
+ * Envelope for one item of a prompt subject response.
+ *
+ * `payload` carries the value being delivered, `isLastChunk` marks the
+ * closing item and `isFailed` flags a failed completion.
+ */
+readonly class PromptSubjectResponseChunk
+{
+    public function __construct(
+        public bool $isFailed,
+        public bool $isLastChunk,
+        public mixed $payload,
+    ) {}
+}
diff --git a/src/SubjectActionTokenReader.php b/src/SubjectActionTokenReader.php
index 062d451c..46c53838 100644
--- a/src/SubjectActionTokenReader.php
+++ b/src/SubjectActionTokenReader.php
@@ -37,6 +37,15 @@ class SubjectActionTokenReader
         return $this->subject;
     }
 
+    /**
+     * True while the action or the subject (or both) is unset or null;
+     * false only once both hold a value.
+     */
+    public function isEmpty(): bool
+    {
+        return !isset($this->action, $this->subject);
+    }
+
     public function isUnknown(): bool
     {
         return 'unknown' === $this->action || 'unknown' === $this->subject;
diff --git a/src/WebSocketRPCResponder/LlamaCppSubjectActionPromptResponder.php b/src/WebSocketRPCResponder/LlamaCppSubjectActionPromptResponder.php
index 188961df..ef0f5304 100644
--- a/src/WebSocketRPCResponder/LlamaCppSubjectActionPromptResponder.php
+++ b/src/WebSocketRPCResponder/LlamaCppSubjectActionPromptResponder.php
@@ -90,7 +90,7 @@ abstract readonly class LlamaCppSubjectActionPromptResponder extends WebSocketRP
                     yield new ObservableTaskStatusUpdate(ObservableTaskStatus::Running, null);
 
                     try {
-                        $this->onObservableRequest(
+                        yield from $this->onObservableRequest(
                             $webSocketAuthResolution,
                             $webSocketConnection,
                             $rpcRequest,
@@ -106,12 +106,14 @@ abstract readonly class LlamaCppSubjectActionPromptResponder extends WebSocketRP
 
     /**
      * @param RPCRequest<TPayload> $rpcRequest
+     *
+     * @return Generator<ObservableTaskStatusUpdate>
      */
     private function onObservableRequest(
         WebSocketAuthResolution $webSocketAuthResolution,
         WebSocketConnection $webSocketConnection,
         RPCRequest $rpcRequest,
-    ): void {
+    ): Generator {
         $request = new LlamaCppCompletionRequest(
             backusNaurFormGrammar: $this->subjectActionGrammar,
             promptTemplate: new ChainPrompt([
@@ -129,29 +131,24 @@ abstract readonly class LlamaCppSubjectActionPromptResponder extends WebSocketRP
             ->createResponseFromTokens(
                 authenticatedUser: $webSocketAuthResolution->authenticatedUser,
                 completion: $completion,
-                timeout: 1.0,
+                inactivityTimeout: 1.0,
             )
         ;
 
-        /**
-         * @var mixed $responseChunk explicitly mixed for typechecks
-         */
         foreach ($response as $responseChunk) {
+            if ($responseChunk->isFailed) {
+                yield new ObservableTaskStatusUpdate(ObservableTaskStatus::Failed, null);
+
+                break;
+            }
+
             $this->onResponseChunk(
                 webSocketAuthResolution: $webSocketAuthResolution,
                 webSocketConnection: $webSocketConnection,
                 rpcRequest: $rpcRequest,
-                responseChunk: $responseChunk,
-                isLastChunk: false,
+                responseChunk: $responseChunk->payload,
+                isLastChunk: $responseChunk->isLastChunk,
             );
         }
-
-        $this->onResponseChunk(
-            webSocketAuthResolution: $webSocketAuthResolution,
-            webSocketConnection: $webSocketConnection,
-            rpcRequest: $rpcRequest,
-            responseChunk: '',
-            isLastChunk: true,
-        );
     }
 }
-- 
GitLab