Skip to content
Snippets Groups Projects
Commit 9c44a6cf authored by Mateusz Charytoniuk's avatar Mateusz Charytoniuk
Browse files

chore: report llm response chunk status

parent 7cd158ed
No related branches found
No related tags found
No related merge requests found
......@@ -81,7 +81,7 @@ readonly class LlamaCppClient
/**
* @var object{ content: string }
*/
$token = $this->jsonSerializer->unserialize($responseChunk);
$token = $this->jsonSerializer->unserialize($responseChunk->chunk);
yield new LlamaCppInfill(
after: $request->after,
......@@ -160,7 +160,7 @@ readonly class LlamaCppClient
}
/**
* @return SwooleChannelIterator<string>
* @return SwooleChannelIterator<LlamaCppClientResponseChunk>
*/
private function streamResponse(JsonSerializable $request, string $path): SwooleChannelIterator
{
......@@ -184,7 +184,10 @@ readonly class LlamaCppClient
curl_setopt($curlHandle, CURLOPT_RETURNTRANSFER, false);
curl_setopt($curlHandle, CURLOPT_URL, $this->llamaCppLinkBuilder->build($path));
curl_setopt($curlHandle, CURLOPT_WRITEFUNCTION, function (CurlHandle $curlHandle, string $data) use ($channel): int {
if ($channel->push($data, $this->llamaCppConfiguration->completionTokenTimeout)) {
if ($channel->push(new LlamaCppClientResponseChunk(
status: ObservableTaskStatus::Running,
chunk: $data
), $this->llamaCppConfiguration->completionTokenTimeout)) {
return strlen($data);
}
......@@ -195,6 +198,11 @@ readonly class LlamaCppClient
if (CURLE_WRITE_ERROR !== $curlErrno) {
$this->logger->error(new CurlErrorMessage($curlHandle));
$channel->push(new LlamaCppClientResponseChunk(
status: ObservableTaskStatus::Failed,
chunk: '',
));
}
} else {
$this->assertStatusCode($curlHandle, 200);
......@@ -202,7 +210,7 @@ readonly class LlamaCppClient
});
/**
* @var SwooleChannelIterator<string>
* @var SwooleChannelIterator<LlamaCppClientResponseChunk>
*/
return new SwooleChannelIterator(
channel: $channel,
......
<?php
declare(strict_types=1);
namespace Distantmagic\Resonance;
readonly class LlamaCppClientResponseChunk
{
public function __construct(
public ObservableTaskStatus $status,
public string $chunk,
) {}
}
......@@ -16,7 +16,7 @@ readonly class LlamaCppCompletionIterator implements IteratorAggregate
private const COMPLETION_CHUNKED_DATA_PREFIX_LENGTH = 6;
/**
* @param SwooleChannelIterator<string> $responseChunks
* @param SwooleChannelIterator<LlamaCppClientResponseChunk> $responseChunks
*/
public function __construct(
private JsonSerializer $jsonSerializer,
......@@ -35,7 +35,19 @@ readonly class LlamaCppCompletionIterator implements IteratorAggregate
$previousChunk = '';
foreach ($this->responseChunks as $responseChunk) {
$previousChunk .= $responseChunk;
if (ObservableTaskStatus::Failed === $responseChunk->status) {
$previousChunk = '';
yield new LlamaCppCompletionToken(
content: '',
isFailed: true,
isLastToken: true,
);
break;
}
$previousChunk .= $responseChunk->chunk;
/**
* @var null|object{
......@@ -56,6 +68,7 @@ readonly class LlamaCppCompletionIterator implements IteratorAggregate
if ($unserializedToken) {
yield new LlamaCppCompletionToken(
content: $unserializedToken->content,
isFailed: false,
isLastToken: $unserializedToken->stop,
);
}
......@@ -72,8 +85,6 @@ readonly class LlamaCppCompletionIterator implements IteratorAggregate
return;
}
if (!$this->responseChunks->channel->close()) {
throw new RuntimeException('Unable to close coroutine channel');
}
$this->responseChunks->channel->close();
}
}
......@@ -13,6 +13,7 @@ readonly class LlamaCppCompletionToken implements Stringable
{
public function __construct(
public string $content,
public bool $isFailed,
public bool $isLastToken,
) {}
......
......@@ -6,9 +6,21 @@ namespace Distantmagic\Resonance;
enum ObservableTaskStatus: string
{
case Cancelled = 'Cancelled';
case Failed = 'Failed';
case Finished = 'Finished';
case Pending = 'Pending';
case Running = 'Running';
case Cancelled = 'cancelled';
case Failed = 'failed';
case Finished = 'finished';
case Pending = 'pending';
case Running = 'running';
case TimedOut = 'timed_out';
public function isFinal(): bool
{
return match ($this) {
ObservableTaskStatus::Cancelled,
ObservableTaskStatus::Failed,
ObservableTaskStatus::Finished,
ObservableTaskStatus::TimedOut => true,
default => false,
};
}
}
......@@ -108,7 +108,7 @@ readonly class ObservableTaskTable implements IteratorAggregate
}
}
if (ObservableTaskStatus::Running !== $statusUpdate->status) {
if ($statusUpdate->status->isFinal()) {
break;
}
}
......
......@@ -52,7 +52,12 @@ readonly class ObservableTaskTimeoutIterator implements IteratorAggregate
$channel = new Channel(1);
$swooleTimeout = new SwooleTimeout(static function () use (&$generatorCoroutineId) {
$swooleTimeout = new SwooleTimeout(static function () use ($channel, &$generatorCoroutineId) {
$channel->push(new ObservableTaskStatusUpdate(
ObservableTaskStatus::TimedOut,
null,
));
if (is_int($generatorCoroutineId)) {
Coroutine::cancel($generatorCoroutineId);
}
......
......@@ -16,10 +16,13 @@ readonly class PromptSubjectResponderAggregate
private PromptSubjectResponderCollection $promptSubjectResponderCollection,
) {}
/**
* @return Generator<PromptSubjectResponseChunk>
*/
public function createResponseFromTokens(
?AuthenticatedUser $authenticatedUser,
LlamaCppCompletionIterator $completion,
float $timeout = DM_POOL_CONNECTION_TIMEOUT,
float $inactivityTimeout = DM_POOL_CONNECTION_TIMEOUT,
): Generator {
$subjectActionTokenReader = new SubjectActionTokenReader();
......@@ -31,24 +34,36 @@ readonly class PromptSubjectResponderAggregate
break;
}
if ($token->isFailed) {
yield new PromptSubjectResponseChunk(
isFailed: true,
isLastChunk: true,
payload: '',
);
}
}
$action = $subjectActionTokenReader->getAction();
$subject = $subjectActionTokenReader->getSubject();
if ($subjectActionTokenReader->isEmpty()) {
return;
}
if ($subjectActionTokenReader->isUnknown() || !isset($action, $subject)) {
yield from $this->respondWithSubjectAction(
$authenticatedUser,
'unknown',
'unknown',
$timeout,
$inactivityTimeout,
);
} else {
yield from $this->respondWithSubjectAction(
$authenticatedUser,
$subject,
$action,
$timeout,
$inactivityTimeout,
);
}
}
......@@ -56,12 +71,14 @@ readonly class PromptSubjectResponderAggregate
/**
* @param non-empty-string $subject
* @param non-empty-string $action
*
* @return Generator<PromptSubjectResponseChunk>
*/
private function respondWithSubjectAction(
?AuthenticatedUser $authenticatedUser,
string $subject,
string $action,
float $timeout,
float $inactivityTimeout,
): Generator {
$responder = $this
->promptSubjectResponderCollection
......@@ -81,7 +98,7 @@ readonly class PromptSubjectResponderAggregate
}
$request = new PromptSubjectRequest($authenticatedUser);
$response = new PromptSubjectResponse($timeout);
$response = new PromptSubjectResponse($inactivityTimeout);
SwooleCoroutineHelper::mustGo(static function () use ($request, $responder, $response) {
$responder->respondToPromptSubject($request, $response);
......
......@@ -8,14 +8,14 @@ use IteratorAggregate;
use Swoole\Coroutine\Channel;
/**
* @template-implements IteratorAggregate<mixed>
* @template-implements IteratorAggregate<PromptSubjectResponseChunk>
*/
readonly class PromptSubjectResponse implements IteratorAggregate
{
private Channel $channel;
public function __construct(
private float $timeout,
private float $inactivityTimeout,
) {
$this->channel = new Channel(1);
}
......@@ -28,30 +28,36 @@ readonly class PromptSubjectResponse implements IteratorAggregate
public function end(mixed $payload = null): void
{
try {
if (null !== $payload) {
$this->write($payload);
}
$this->channel->push(new PromptSubjectResponseChunk(
isFailed: false,
isLastChunk: true,
payload: $payload,
));
} finally {
$this->channel->close();
}
}
/**
* @return SwooleChannelIterator<mixed>
* @return SwooleChannelIterator<PromptSubjectResponseChunk>
*/
public function getIterator(): SwooleChannelIterator
{
/**
* @var SwooleChannelIterator<mixed>
* @var SwooleChannelIterator<PromptSubjectResponseChunk>
*/
return new SwooleChannelIterator(
channel: $this->channel,
timeout: $this->timeout,
timeout: $this->inactivityTimeout,
);
}
public function write(mixed $payload): void
{
$this->channel->push($payload);
$this->channel->push(new PromptSubjectResponseChunk(
isFailed: false,
isLastChunk: false,
payload: $payload,
));
}
}
<?php
declare(strict_types=1);
namespace Distantmagic\Resonance;
readonly class PromptSubjectResponseChunk
{
public function __construct(
public bool $isFailed,
public bool $isLastChunk,
public mixed $payload,
) {}
}
......@@ -37,6 +37,11 @@ class SubjectActionTokenReader
return $this->subject;
}
public function isEmpty(): bool
{
return !isset($this->action, $this->subject);
}
public function isUnknown(): bool
{
return 'unknown' === $this->action || 'unknown' === $this->subject;
......
......@@ -90,7 +90,7 @@ abstract readonly class LlamaCppSubjectActionPromptResponder extends WebSocketRP
yield new ObservableTaskStatusUpdate(ObservableTaskStatus::Running, null);
try {
$this->onObservableRequest(
yield from $this->onObservableRequest(
$webSocketAuthResolution,
$webSocketConnection,
$rpcRequest,
......@@ -106,12 +106,14 @@ abstract readonly class LlamaCppSubjectActionPromptResponder extends WebSocketRP
/**
* @param RPCRequest<TPayload> $rpcRequest
*
* @return Generator<ObservableTaskStatusUpdate>
*/
private function onObservableRequest(
WebSocketAuthResolution $webSocketAuthResolution,
WebSocketConnection $webSocketConnection,
RPCRequest $rpcRequest,
): void {
): Generator {
$request = new LlamaCppCompletionRequest(
backusNaurFormGrammar: $this->subjectActionGrammar,
promptTemplate: new ChainPrompt([
......@@ -129,29 +131,24 @@ abstract readonly class LlamaCppSubjectActionPromptResponder extends WebSocketRP
->createResponseFromTokens(
authenticatedUser: $webSocketAuthResolution->authenticatedUser,
completion: $completion,
timeout: 1.0,
inactivityTimeout: 1.0,
)
;
/**
* @var mixed $responseChunk explicitly mixed for typechecks
*/
foreach ($response as $responseChunk) {
if ($responseChunk->isFailed) {
yield new ObservableTaskStatusUpdate(ObservableTaskStatus::Failed, null);
break;
}
$this->onResponseChunk(
webSocketAuthResolution: $webSocketAuthResolution,
webSocketConnection: $webSocketConnection,
rpcRequest: $rpcRequest,
responseChunk: $responseChunk,
isLastChunk: false,
responseChunk: $responseChunk->payload,
isLastChunk: $responseChunk->isLastChunk,
);
}
$this->onResponseChunk(
webSocketAuthResolution: $webSocketAuthResolution,
webSocketConnection: $webSocketConnection,
rpcRequest: $rpcRequest,
responseChunk: '',
isLastChunk: true,
);
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment