diff --git a/examples/llama3.ts b/examples/llama3.ts index f17a6c2d6e7938a99e3259cbeb4c2c5014cba732..f5da9f23bb4152478e2fd783cb9cc15447152be4 100644 --- a/examples/llama3.ts +++ b/examples/llama3.ts @@ -2,8 +2,12 @@ import { ReplicateLLM } from "llamaindex"; (async () => { const tres = new ReplicateLLM({ model: "llama-3-70b-instruct" }); - const result = await tres.chat({ + const stream = await tres.chat({ messages: [{ content: "Hello, world!", role: "user" }], + stream: true, }); - console.log(result); + for await (const chunk of stream) { + process.stdout.write(chunk.delta); + } + console.log("\n\ndone"); })(); diff --git a/package.json b/package.json index 6b572316babdbb2533c98364c019913df77c5e87..be9513dffad6f995418d757a073c3427b510683c 100644 --- a/package.json +++ b/package.json @@ -33,7 +33,7 @@ "turbo": "^1.13.2", "typescript": "^5.4.5" }, - "packageManager": "pnpm@9.0.1+sha256.46d50ee2afecb42b185ebbd662dc7bdd52ef5be56bf035bb615cab81a75345df", + "packageManager": "pnpm@9.0.4", "pnpm": { "overrides": { "trim": "1.0.1", diff --git a/packages/core/src/llm/replicate_ai.ts b/packages/core/src/llm/replicate_ai.ts index acc807700930c8e053d8704dde47f8131795f362..99d5fb332aa80423628707a8ed2eba9f5c457722 100644 --- a/packages/core/src/llm/replicate_ai.ts +++ b/packages/core/src/llm/replicate_ai.ts @@ -9,7 +9,12 @@ import type { LLMChatParamsStreaming, MessageType, } from "./types.js"; -import { extractText, wrapLLMEvent } from "./utils.js"; +import { + extractText, + streamCallbacks, + streamConverter, + wrapLLMEvent, +} from "./utils.js"; export class ReplicateSession { replicateKey: string | null = null; @@ -332,9 +337,25 @@ If a question does not make any sense, or is not factually coherent, explain why replicateOptions.input.prompt_template = "{prompt}"; } - //TODO: Add streaming for this if (stream) { - throw new Error("Streaming not supported for ReplicateLLM"); + const controller = new AbortController(); + const stream = this.replicateSession.replicate.stream(api, { + ...replicateOptions, + signal: controller.signal, + }); + // replicate.stream is not closing if used as AsyncIterable, force closing after consumption with the AbortController + return streamCallbacks( + streamConverter(stream, (chunk) => { + if (chunk.event === "done") { + return null; + } + return { + raw: chunk, + delta: chunk.data, + }; + }), + { finished: () => controller.abort() }, + ); } //Non-streaming @@ -342,6 +363,7 @@ If a question does not make any sense, or is not factually coherent, explain why api, replicateOptions, ); + return { raw: response, message: { diff --git a/packages/core/src/llm/utils.ts b/packages/core/src/llm/utils.ts index 4d90a66b0a2b506b1cf9d424aa2e4760c4bfa3c2..1b22a9c6e0b774a9d8b239d69fd7ae3bd19e54da 100644 --- a/packages/core/src/llm/utils.ts +++ b/packages/core/src/llm/utils.ts @@ -11,10 +11,29 @@ import type { export async function* streamConverter<S, D>( stream: AsyncIterable<S>, - converter: (s: S) => D, + converter: (s: S) => D | null, ): AsyncIterable<D> { for await (const data of stream) { - yield converter(data); + const newData = converter(data); + if (newData === null) { + return; + } + yield newData; + } +} + +export async function* streamCallbacks<S>( + stream: AsyncIterable<S>, + callbacks: { + finished?: (value?: S) => void; + }, +): AsyncIterable<S> { + let value: S | undefined; + for await (value of stream) { + yield value; + } + if (callbacks.finished) { + callbacks.finished(value); } } diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 34d730e580429c68539997bad3fea96e07545de8..cb4615cce9f7fe87ef028542522e917d2b389b4f 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -148,7 +148,7 @@ importers: version: 1.0.11 llamaindex: specifier: latest - version: 0.2.9(encoding@0.1.13)(node-fetch@2.7.0(encoding@0.1.13))(readable-stream@4.5.2)(typescript@5.4.5) + version: link:../packages/core mongodb: specifier: ^6.5.0 version: 6.5.0 @@ -15173,65 +15173,6 @@ snapshots: - typescript - utf-8-validate - llamaindex@0.2.9(encoding@0.1.13)(node-fetch@2.7.0(encoding@0.1.13))(readable-stream@4.5.2)(typescript@5.4.5): - dependencies: - '@anthropic-ai/sdk': 0.20.6(encoding@0.1.13) - '@aws-crypto/sha256-js': 5.2.0 - '@datastax/astra-db-ts': 0.1.4 - '@grpc/grpc-js': 1.10.6 - '@llamaindex/cloud': 0.0.5(node-fetch@2.7.0(encoding@0.1.13)) - '@llamaindex/env': 0.0.7(@aws-crypto/sha256-js@5.2.0)(pathe@1.1.2)(readable-stream@4.5.2) - '@mistralai/mistralai': 0.1.3(encoding@0.1.13) - '@notionhq/client': 2.2.15(encoding@0.1.13) - '@pinecone-database/pinecone': 2.2.0 - '@qdrant/js-client-rest': 1.8.2(typescript@5.4.5) - '@types/lodash': 4.17.0 - '@types/node': 20.12.7 - '@types/papaparse': 5.3.14 - '@types/pg': 8.11.5 - '@xenova/transformers': 2.17.1 - '@zilliz/milvus2-sdk-node': 2.4.1 - ajv: 8.12.0 - assemblyai: 4.4.1 - chromadb: 1.7.3(cohere-ai@7.9.5(encoding@0.1.13))(encoding@0.1.13)(openai@4.38.0(encoding@0.1.13)) - cohere-ai: 7.9.5(encoding@0.1.13) - js-tiktoken: 1.0.11 - lodash: 4.17.21 - magic-bytes.js: 1.10.0 - mammoth: 1.7.1 - md-utils-ts: 2.0.0 - mongodb: 6.5.0 - notion-md-crawler: 0.0.2(encoding@0.1.13) - openai: 4.38.0(encoding@0.1.13) - papaparse: 5.4.1 - pathe: 1.1.2 - pdf2json: 3.0.5 - pg: 8.11.5 - pgvector: 0.1.8 - portkey-ai: 0.1.16 - rake-modified: 1.0.8 - replicate: 0.25.2 - string-strip-html: 13.4.8 - wikipedia: 2.1.2 - wink-nlp: 1.14.3 - transitivePeerDependencies: - - '@aws-sdk/credential-providers' - - '@google/generative-ai' - - '@mongodb-js/zstd' - - bufferutil - - debug - - encoding - - gcp-metadata - - kerberos - - mongodb-client-encryption - - node-fetch - - pg-native - - readable-stream - - snappy - - socks - - typescript - - utf-8-validate - load-yaml-file@0.2.0: dependencies: graceful-fs: 4.2.11