diff --git a/llama-index-integrations/llms/llama-index-llms-vllm/llama_index/llms/vllm/base.py b/llama-index-integrations/llms/llama-index-llms-vllm/llama_index/llms/vllm/base.py
index 5fdd577c7dedbc56cce123b0299459694fb64b1b..93c795871b23fc28a86ab134e401c762b17d65a0 100644
--- a/llama-index-integrations/llms/llama-index-llms-vllm/llama_index/llms/vllm/base.py
+++ b/llama-index-integrations/llms/llama-index-llms-vllm/llama_index/llms/vllm/base.py
@@ -212,6 +212,17 @@ class Vllm(LLM):
         }
         return {**base_kwargs}
 
+    def __del__(self) -> None:
+        import torch
+        from vllm.model_executor.parallel_utils.parallel_state import (
+            destroy_model_parallel,
+        )
+
+        destroy_model_parallel()
+        del self._client
+        if torch.cuda.is_available():
+            torch.cuda.synchronize()
+
     def _get_all_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
         return {
             **self._model_kwargs,
@@ -262,7 +273,8 @@ class Vllm(LLM):
     async def acomplete(
         self, prompt: str, formatted: bool = False, **kwargs: Any
     ) -> CompletionResponse:
-        raise (ValueError("Not Implemented"))
+        kwargs = kwargs if kwargs else {}
+        return self.complete(prompt, **kwargs)
 
     @llm_chat_callback()
     async def astream_chat(