From bffdc67a4ec788e9f93ab9b359848b19537e50a3 Mon Sep 17 00:00:00 2001
From: Mikhail Khludnev <mkhludnev@users.noreply.github.com>
Date: Thu, 21 Mar 2024 20:35:14 +0300
Subject: [PATCH] llms-vllm: fix VllmServer to work without CUDA-required vllm
 core. (#12003)

---
 docs/examples/llm/vllm.ipynb                  | 113 +++--------------
 .../llama_index/llms/vllm/base.py             |  51 ++++----
 .../llms/llama-index-llms-vllm/pyproject.toml |   2 +-
 .../tests/test_integration.py                 | 120 ++++++++++++++++++
 .../tests/test_llms_vllm.py                   |  25 +++-
 5 files changed, 193 insertions(+), 118 deletions(-)
 create mode 100644 llama-index-integrations/llms/llama-index-llms-vllm/tests/test_integration.py

diff --git a/docs/examples/llm/vllm.ipynb b/docs/examples/llm/vllm.ipynb
index 29f06de50c..57070674c9 100644
--- a/docs/examples/llm/vllm.ipynb
+++ b/docs/examples/llm/vllm.ipynb
@@ -5,10 +5,14 @@
    "id": "a577777f-a994-493b-bf39-4d19f4fdc5f8",
    "metadata": {},
    "source": [
-    "# Install Vllm  \n",
+    "# vLLM  \n",
+    "\n",
+    "There's two modes of using vLLM local and remote. Let's start form the former one, which requeries CUDA environment availabe locally. \n",
+    "\n",
+    "### Install vLLM\n",
+    "\n",
     "`pip install vllm` <br>\n",
-    "or if you want to compile you can compile from <br>\n",
-    "https://docs.vllm.ai/en/latest/getting_started/installation.html"
+    "or if you want to compile you can [compile from source](https://docs.vllm.ai/en/latest/getting_started/installation.html)"
    ]
   },
   {
@@ -16,7 +20,7 @@
    "id": "00d46316-8399-4731-8e97-5d8c8c436d99",
    "metadata": {},
    "source": [
-    "# Orca-7b Completion Example\n"
+    "### Orca-7b Completion Example\n"
    ]
   },
   {
@@ -155,7 +159,7 @@
    "id": "f67b22d6-37c0-4b42-b9c0-d259ee41fced",
    "metadata": {},
    "source": [
-    "# LLama-2-7b Completion Example\n"
+    "### LLama-2-7b Completion Example\n"
    ]
   },
   {
@@ -282,7 +286,7 @@
    "id": "5f3b7693-a0b2-46a0-ade8-1abb36691a49",
    "metadata": {},
    "source": [
-    "# mistral chat 7b Completion Example\n"
+    "### Mistral chat 7b Completion Example\n"
    ]
   },
   {
@@ -401,88 +405,11 @@
    "id": "9d02f871-56a1-4118-949a-2322177b58e5",
    "metadata": {},
    "source": [
-    "## Completion Example"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "51569399-1d66-43b3-a707-71220d19cd58",
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "from llama_index.core.llms.vllm import VllmServer"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "a4fc8e4c-0207-4f23-a5e9-e0851186e9f0",
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "llm = VllmServer(\n",
-    "    api_url=\"http://localhost:8000/generate\", max_new_tokens=100, temperature=0\n",
-    ")"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "a6b95731-0361-44fe-9219-bed582cb9cbe",
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "[CompletionResponse(text='what is a black hole ?\\n\\nA black hole is a region in space where the gravitational pull is so strong that nothing, not even light, can escape from it. It is formed when a massive star collapses under its own gravity after it has exhausted its nuclear fuel. The boundary around the black hole, called the event horizon, marks the point of no return, beyond which anything that gets too close will be pulled in and cannot escape.', additional_kwargs={}, raw=None, delta=None)]"
-      ]
-     },
-     "execution_count": null,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "llm.complete(\"what is a black hole ?\")"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "898eadce-8349-417e-9cdc-3299f69894d6",
-   "metadata": {},
-   "source": [
-    "## Streaming Response"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "82b43656-2144-475b-bde3-fe830c9489e7",
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "CompletionResponse(text='what is a black hole?\\n\\nA black hole is a region in space where the gravitational pull is so strong that nothing, not even light, can escape from it. It is formed when a massive star collapses under its own gravity after it has exhausted its nuclear fuel. The boundary around the black hole, called the event horizon, marks the point of no return, beyond which anything that gets too close will be pulled in and cannot escape.', additional_kwargs={}, raw=None, delta=None)"
-      ]
-     },
-     "execution_count": null,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "list(llm.stream_complete(\"what is a black hole\"))[-1]"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "c078acbc-f2a8-4b3b-8d3c-89858adf431e",
-   "metadata": {},
-   "source": [
-    "# Api Response\n",
-    "To setup the api you can follow the guide present here -> https://docs.vllm.ai/en/latest/serving/distributed_serving.html"
+    "# Calling vLLM via HTTP\n",
+    "\n",
+    "In this mode there is no need to install `vllm` model nor have CUDA available locally. To setup the vLLM API you can follow the guide present [here](https://docs.vllm.ai/en/latest/serving/distributed_serving.html). \n",
+    "Note: `llama-index-llms-vllm` module is a client for `vllm.entrypoints.api_server` which is only [a demo](https://github.com/vllm-project/vllm/blob/abfc4f3387c436d46d6701e9ba916de8f9ed9329/vllm/entrypoints/api_server.py#L2). <br>\n",
+    "If vLLM server is launched with `vllm.entrypoints.openai.api_server` as [OpenAI Compatible Server](https://docs.vllm.ai/en/latest/getting_started/quickstart.html#openai-compatible-server)  or via [Docker](https://docs.vllm.ai/en/latest/serving/deploying_with_docker.html) you need `OpenAILike` class from `llama-index-llms-openai-like` [module](localai.ipynb#llamaindex-interaction)"
    ]
   },
   {
@@ -490,7 +417,7 @@
    "id": "d867ee6a-2a1b-4152-8616-1d2b507d9fab",
    "metadata": {},
    "source": [
-    "## completion Response "
+    "### Completion Response "
    ]
   },
   {
@@ -555,7 +482,7 @@
     }
    ],
    "source": [
-    "message = [ChatMessage(content=\"hello\", author=\"user\")]\n",
+    "message = [ChatMessage(content=\"hello\", role=\"user\")]\n",
     "llm.chat(message)"
    ]
   },
@@ -564,7 +491,7 @@
    "id": "6236b3e7-090e-45e0-807c-c3f8fab365d0",
    "metadata": {},
    "source": [
-    "## Streaming Response"
+    "### Streaming Response"
    ]
   },
   {
@@ -606,7 +533,7 @@
     }
    ],
    "source": [
-    "message = [ChatMessage(content=\"what is a black hole\", author=\"user\")]\n",
+    "message = [ChatMessage(content=\"what is a black hole\", role=\"user\")]\n",
     "[x for x in llm.stream_chat(message)][-1]"
    ]
   },
@@ -615,7 +542,7 @@
    "id": "71b8e03f-fcf5-4650-b701-585ca68f4bb8",
    "metadata": {},
    "source": [
-    "## Async Response"
+    "### Async Response"
    ]
   },
   {
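For quick reference, a minimal sketch of the two usage modes the notebook above describes, assuming `vllm` plus CUDA are available for the local snippet and a `vllm.entrypoints.api_server` is reachable at `http://localhost:8000/generate` for the remote one:

```python
# Minimal sketch of the two modes (assumptions noted above, values taken from the notebook/tests).
from llama_index.llms.vllm import Vllm, VllmServer

# Local mode: loads the model in-process, so the vllm package and CUDA are required.
local_llm = Vllm(model="facebook/opt-350m")
print(local_llm.complete("what is a black hole?"))

# Remote mode: a plain HTTP client, no vllm package or CUDA needed locally.
remote_llm = VllmServer(api_url="http://localhost:8000/generate", max_new_tokens=100)
print(remote_llm.complete("what is a black hole?"))
```

Both calls return a single `CompletionResponse`, which is what the `complete` signature change for `VllmServer` in `base.py` below makes explicit.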
diff --git a/llama-index-integrations/llms/llama-index-llms-vllm/llama_index/llms/vllm/base.py b/llama-index-integrations/llms/llama-index-llms-vllm/llama_index/llms/vllm/base.py
index 5b34ca956f..aefe1b79ae 100644
--- a/llama-index-integrations/llms/llama-index-llms-vllm/llama_index/llms/vllm/base.py
+++ b/llama-index-integrations/llms/llama-index-llms-vllm/llama_index/llms/vllm/base.py
@@ -142,14 +142,14 @@ class Vllm(LLM):
         pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
         output_parser: Optional[BaseOutputParser] = None,
     ) -> None:
-        try:
-            from vllm import LLM as VLLModel
-        except ImportError:
-            raise ImportError(
-                "Could not import vllm python package. "
-                "Please install it with `pip install vllm`."
-            )
-        if model != "":
+        if not api_url:
+            try:
+                from vllm import LLM as VLLModel
+            except ImportError:
+                raise ImportError(
+                    "Could not import vllm python package. "
+                    "Please install it with `pip install vllm`."
+                )
             self._client = VLLModel(
                 model=model,
                 tensor_parallel_size=tensor_parallel_size,
@@ -179,6 +179,7 @@ class Vllm(LLM):
             download_dir=download_dir,
             vllm_kwargs=vllm_kwargs,
             api_url=api_url,
+            callback_manager=callback_manager,
             system_prompt=system_prompt,
             messages_to_prompt=messages_to_prompt,
             completion_to_prompt=completion_to_prompt,
@@ -322,7 +323,6 @@ class VllmServer(Vllm):
         completion_to_prompt = completion_to_prompt or (lambda x: x)
         callback_manager = callback_manager or CallbackManager([])
 
-        model = ""
         super().__init__(
             model=model,
             temperature=temperature,
@@ -351,17 +351,18 @@ class VllmServer(Vllm):
     def class_name(cls) -> str:
         return "VllmServer"
 
+    def __del__(self) -> None:
+        ...  # no-op override: the remote server client has no local vLLM engine to clean up
+
     @llm_completion_callback()
     def complete(
         self, prompt: str, formatted: bool = False, **kwargs: Any
-    ) -> List[CompletionResponse]:
+    ) -> CompletionResponse:
         kwargs = kwargs if kwargs else {}
         params = {**self._model_kwargs, **kwargs}
 
-        from vllm import SamplingParams
-
         # build sampling parameters
-        sampling_params = SamplingParams(**params).__dict__
+        sampling_params = dict(**params)
         sampling_params["prompt"] = prompt
         response = post_http_request(self.api_url, sampling_params, stream=False)
         output = get_response(response)
@@ -375,23 +376,25 @@ class VllmServer(Vllm):
         kwargs = kwargs if kwargs else {}
         params = {**self._model_kwargs, **kwargs}
 
-        from vllm import SamplingParams
-
-        # build sampling parameters
-        sampling_params = SamplingParams(**params).__dict__
+        sampling_params = dict(**params)
         sampling_params["prompt"] = prompt
         response = post_http_request(self.api_url, sampling_params, stream=True)
 
         def gen() -> CompletionResponseGen:
             response_str = ""
+            prev_prefix_len = len(prompt)
             for chunk in response.iter_lines(
                 chunk_size=8192, decode_unicode=False, delimiter=b"\0"
             ):
                 if chunk:
                     data = json.loads(chunk.decode("utf-8"))
 
-                    response_str += data["text"][0]
-                    yield CompletionResponse(text=response_str, delta=data["text"][0])
+                    increasing_concat = data["text"][0]
+                    pref = prev_prefix_len
+                    prev_prefix_len = len(increasing_concat)
+                    yield CompletionResponse(
+                        text=increasing_concat, delta=increasing_concat[pref:]
+                    )
 
         return gen()
 
@@ -409,10 +412,8 @@ class VllmServer(Vllm):
         kwargs = kwargs if kwargs else {}
         params = {**self._model_kwargs, **kwargs}
 
-        from vllm import SamplingParams
-
         # build sampling parameters
-        sampling_params = SamplingParams(**params).__dict__
+        sampling_params = dict(**params)
         sampling_params["prompt"] = prompt
 
         async def gen() -> CompletionResponseAsyncGen:
@@ -433,4 +434,8 @@ class VllmServer(Vllm):
     async def astream_chat(
         self, messages: Sequence[ChatMessage], **kwargs: Any
     ) -> ChatResponseAsyncGen:
-        return self.stream_chat(messages, **kwargs)
+        async def gen() -> ChatResponseAsyncGen:
+            for message in self.stream_chat(messages, **kwargs):
+                yield message
+
+        return gen()
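To see why the streaming hunk above slices the cumulative text, recall that the demo `api_server` returns the full prompt-plus-generation so far in each chunk, so each delta is the suffix that grew since the previous chunk. A minimal sketch with made-up chunk payloads:

```python
# Sketch of the delta slicing used in VllmServer.stream_complete above.
# The cumulative strings below are illustrative stand-ins for data["text"][0].
prompt = "what is a black hole"
cumulative_chunks = [
    prompt + " ?\n\nA black hole is a region",
    prompt + " ?\n\nA black hole is a region in space where the gravitational pull",
]

prev_prefix_len = len(prompt)
for cumulative in cumulative_chunks:
    delta = cumulative[prev_prefix_len:]  # only the newly generated text
    prev_prefix_len = len(cumulative)
    print(repr(delta))
```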
diff --git a/llama-index-integrations/llms/llama-index-llms-vllm/pyproject.toml b/llama-index-integrations/llms/llama-index-llms-vllm/pyproject.toml
index e8176a7309..d917eea1e3 100644
--- a/llama-index-integrations/llms/llama-index-llms-vllm/pyproject.toml
+++ b/llama-index-integrations/llms/llama-index-llms-vllm/pyproject.toml
@@ -28,7 +28,7 @@ exclude = ["**/BUILD"]
 license = "MIT"
 name = "llama-index-llms-vllm"
 readme = "README.md"
-version = "0.1.6"
+version = "0.1.7"
 
 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"
diff --git a/llama-index-integrations/llms/llama-index-llms-vllm/tests/test_integration.py b/llama-index-integrations/llms/llama-index-llms-vllm/tests/test_integration.py
new file mode 100644
index 0000000000..45cf3b96cd
--- /dev/null
+++ b/llama-index-integrations/llms/llama-index-llms-vllm/tests/test_integration.py
@@ -0,0 +1,120 @@
+import asyncio
+import os
+
+import pytest
+from llama_index.core.base.llms.types import (
+    MessageRole, ChatMessage, ChatResponse, CompletionResponse
+)
+
+
+def remote_vllm():
+    from llama_index.llms.vllm import VllmServer
+
+    return VllmServer(
+        api_url=os.environ.get("VLLM", "http://localhost:8000/generate"),
+    )
+
+
+def local_vllm():
+    from llama_index.llms.vllm import Vllm
+
+    return Vllm(
+        model="facebook/opt-350m",
+    )
+
+
+@pytest.mark.skip(reason="requires remotely running `vllm.entrypoints.api_server`")
+class TestVllmIntegration:
+    # replace with local_vllm() to run locally; it requires vllm installed, and the ..stream.. tests fail as not implemented
+    vllm = remote_vllm()
+
+    def test_completion(self):
+        completion = self.vllm.complete("When AI hype is over?")
+        assert (
+            isinstance(completion, CompletionResponse)
+            and isinstance(completion.text, str)
+            and len(completion.text) > 0
+        )
+
+    def test_acompletion(self):
+        completion = asyncio.run(self.vllm.acomplete("When AI hype is over?"))
+        assert (
+            isinstance(completion, CompletionResponse)
+            and isinstance(completion.text, str)
+            and len(completion.text) > 0
+        )
+
+    def test_chat(self):
+        from llama_index.core.base.llms.types import ChatMessage
+
+        chat = self.vllm.chat(
+            [ChatMessage(content="When AI hype is over?", role=MessageRole.USER)]
+        )
+        assert (
+            isinstance(chat.message, ChatMessage)
+            and chat.message.role == MessageRole.ASSISTANT
+            and isinstance(chat.message.content, str)
+            and len(chat.message.content) > 0
+        )
+
+    def test_achat(self):
+        from llama_index.core.base.llms.types import ChatMessage
+
+        chat = asyncio.run(
+            self.vllm.achat(
+                [ChatMessage(content="When AI hype is over?", role=MessageRole.USER)]
+            )
+        )
+        assert (
+            isinstance(chat.message, ChatMessage)
+            and chat.message.role == MessageRole.ASSISTANT
+            and isinstance(chat.message.content, str)
+            and len(chat.message.content) > 0
+        )
+
+    def test_stream_completion(self):
+        prompt = "When AI hype is over?"
+        completion = list(self.vllm.stream_complete(prompt))[-1]
+        assert isinstance(completion, CompletionResponse)
+        assert completion.text.count(prompt) == 1
+        print(completion)
+
+    def test_astream_completion(self):
+        prompt = "When AI hype is over?"
+
+        async def concat():
+            return [c async for c in await self.vllm.astream_complete(prompt)]
+
+        completion = asyncio.run(concat())[-1]
+        assert isinstance(completion, CompletionResponse)
+        assert completion.text.count(prompt) == 1
+        print(completion)
+
+    def test_stream_chat(self):
+        prompt = "When AI hype is over?"
+        chat = list(
+            self.vllm.stream_chat([ChatMessage(content=prompt, role=MessageRole.USER)])
+        )[-1]
+        assert isinstance(chat, ChatResponse)
+        assert isinstance(chat.message, ChatMessage)
+        assert chat.message.role == MessageRole.ASSISTANT
+        assert chat.message.content.count(prompt) == 1
+        print(chat)
+
+    def test_astream_chat(self):
+        prompt = "When AI hype is over?"
+
+        async def concat():
+            return [
+                c
+                async for c in await self.vllm.astream_chat(
+                    [ChatMessage(content=prompt, role=MessageRole.USER)]
+                )
+            ]
+
+        chat = asyncio.run(concat())[-1]
+        assert isinstance(chat, ChatResponse)
+        assert isinstance(chat.message, ChatMessage)
+        assert chat.message.role == MessageRole.ASSISTANT
+        assert chat.message.content.count(prompt) == 1
+        print(chat)
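The class above is skipped by default; a minimal sketch of driving it against a live demo server, where the endpoint is a hypothetical placeholder passed through the `VLLM` environment variable that `remote_vllm()` reads:

```python
# Hypothetical driver: point the skipped integration tests at a live vllm.entrypoints.api_server.
# Drop or override the @pytest.mark.skip marker first; the URL below is a placeholder.
import os
import pytest

os.environ["VLLM"] = "http://my-vllm-host:8000/generate"  # hypothetical endpoint
pytest.main(["-q", "llama-index-integrations/llms/llama-index-llms-vllm/tests/test_integration.py"])
```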
diff --git a/llama-index-integrations/llms/llama-index-llms-vllm/tests/test_llms_vllm.py b/llama-index-integrations/llms/llama-index-llms-vllm/tests/test_llms_vllm.py
index 1be18dd2d4..b1ece4e7c4 100644
--- a/llama-index-integrations/llms/llama-index-llms-vllm/tests/test_llms_vllm.py
+++ b/llama-index-integrations/llms/llama-index-llms-vllm/tests/test_llms_vllm.py
@@ -1,7 +1,30 @@
 from llama_index.core.base.llms.base import BaseLLM
-from llama_index.llms.vllm import Vllm
+from llama_index.core.callbacks import CallbackManager
 
 
 def test_embedding_class():
+    from llama_index.llms.vllm import Vllm
+
     names_of_base_classes = [b.__name__ for b in Vllm.__mro__]
     assert BaseLLM.__name__ in names_of_base_classes
+
+
+def test_server_class():
+    from llama_index.llms.vllm import VllmServer
+
+    names_of_base_classes = [b.__name__ for b in VllmServer.__mro__]
+    assert BaseLLM.__name__ in names_of_base_classes
+
+
+def test_server_callback() -> None:
+    from llama_index.llms.vllm import VllmServer
+
+    callback_manager = CallbackManager()
+    remote = VllmServer(
+        api_url="http://localhost:8000",
+        model="modelstub",
+        max_new_tokens=123,
+        callback_manager=callback_manager,
+    )
+    assert remote.callback_manager == callback_manager
+    del remote
-- 
GitLab