From 439e1b67b1a30cb08a87bfa883fed1af9db19b81 Mon Sep 17 00:00:00 2001
From: Logan <logan.markewich@live.com>
Date: Wed, 21 Feb 2024 19:46:14 -0600
Subject: [PATCH] update MyMagicAI (#11263)

* update

* address comments
---
 docs/examples/llm/mymagic.ipynb              | 34 +++++++++++--------
 .../llama_index/llms/mymagic/base.py         | 25 ++++++++++----
 .../llama-index-llms-mymagic/pyproject.toml  |  2 +-
 3 files changed, 39 insertions(+), 22 deletions(-)

diff --git a/docs/examples/llm/mymagic.ipynb b/docs/examples/llm/mymagic.ipynb
index df35e377d9..d1ee8a10a3 100644
--- a/docs/examples/llm/mymagic.ipynb
+++ b/docs/examples/llm/mymagic.ipynb
@@ -60,11 +60,12 @@
    "outputs": [],
    "source": [
     "llm = MyMagicAI(\n",
-    "    api_key=\"your_api_key\",\n",
-    "    storage_provider=\"your_storage_provider\",  # s3, gcs\n",
-    "    bucket_name=\"your_bucket_name\",\n",
-    "    session=\"your_session\",  # files should be located in this folder on which batch inference will be run\n",
-    "    system_prompt=\"Answer the question succinctly\",\n",
+    "    api_key=\"your-api-key\",\n",
+    "    storage_provider=\"s3\",  # s3, gcs\n",
+    "    bucket_name=\"your-bucket-name\",\n",
+    "    session=\"your-session-name\",  # files should be located in this folder on which batch inference will be run\n",
+    "    role_arn=\"your-role-arn\",\n",
+    "    system_prompt=\"your-system-prompt\",\n",
     ")"
    ]
   },
@@ -75,9 +76,9 @@
    "outputs": [],
    "source": [
     "resp = llm.complete(\n",
-    "    question=\"Summarize the document!\",\n",
-    "    model=\"mistral7b\",\n",
-    "    max_tokens=10,  # currently we support mistral7b, llama7b, mixtral8x7b,codellama70b, llama70b, more to come...\n",
+    "    question=\"your-question\",\n",
+    "    model=\"chhoose-model\",  # currently we support mistral7b, llama7b, mixtral8x7b,codellama70b, llama70b, more to come...\n",
+    "    max_tokens=5,  # number of tokens to generate, default is 10\n",
     ")"
    ]
   },
@@ -116,14 +117,17 @@
    "source": [
     "async def main():\n",
     "    allm = MyMagicAI(\n",
-    "        api_key=\"your_api_key\",\n",
-    "        storage_provider=\"your_storage_provider\",\n",
-    "        bucket_name=\"your_bucket_name\",\n",
-    "        session=\"your_session_name\",\n",
-    "        system_prompt=\"your_system_prompt\",\n",
+    "        api_key=\"your-api-key\",\n",
+    "        storage_provider=\"s3\",  # s3, gcs\n",
+    "        bucket_name=\"your-bucket-name\",\n",
+    "        session=\"your-session-name\",  # files should be located in this folder on which batch inference will be run\n",
+    "        role_arn=\"your-role-arn\",\n",
+    "        system_prompt=\"your-system-prompt\",\n",
     "    )\n",
     "    response = await allm.acomplete(\n",
-    "        question=\"your_question\", model=\"mistral7b\", max_tokens=10\n",
+    "        question=\"your-question\",\n",
+    "        model=\"chhoose-model\",  # currently we support mistral7b, llama7b, mixtral8x7b,codellama70b, llama70b, more to come...\n",
+    "        max_tokens=5,  # number of tokens to generate, default is 10\n",
     "    )\n",
     "\n",
     "    print(\"Async completion response:\", response)"
@@ -135,7 +139,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "asyncio.run(main())"
+    "await main()"
    ]
   }
  ],
diff --git a/llama-index-integrations/llms/llama-index-llms-mymagic/llama_index/llms/mymagic/base.py b/llama-index-integrations/llms/llama-index-llms-mymagic/llama_index/llms/mymagic/base.py
index f44a29ef4c..9c3f6f3ea3 100644
--- a/llama-index-integrations/llms/llama-index-llms-mymagic/llama_index/llms/mymagic/base.py
+++ b/llama-index-integrations/llms/llama-index-llms-mymagic/llama_index/llms/mymagic/base.py
@@ -23,6 +23,7 @@ class MyMagicAI(LLM):
     max_tokens: int = Field(
         default=10, description="The maximum number of tokens to generate."
     )
+    question = Field(default="", description="The user question.")
     storage_provider: str = Field(
         default="gcs", description="The storage provider to use."
     )
@@ -105,11 +106,17 @@ class MyMagicAI(LLM):
         return resp.json()
 
     async def acomplete(
-        self, question: str, model: str, max_tokens: int, poll_interval: float = 1.0
+        self,
+        question: str,
+        model: Optional[str] = None,
+        max_tokens: Optional[int] = None,
+        poll_interval: float = 1.0,
     ) -> CompletionResponse:
         self.question_data["question"] = question
-        self.question_data["model"] = model
-        self.question_data["max_tokens"] = max_tokens
+        self.model = self.question_data["model"] = model or self.model
+        self.max_tokens = self.question_data["max_tokens"] = (
+            max_tokens or self.max_tokens
+        )
 
         task_response = await self._submit_question(self.question_data)
         task_id = task_response.get("task_id")
@@ -120,11 +127,17 @@ class MyMagicAI(LLM):
             await asyncio.sleep(poll_interval)
 
     def complete(
-        self, question: str, model: str, max_tokens: int, poll_interval: float = 1.0
+        self,
+        question: str,
+        model: Optional[str] = None,
+        max_tokens: Optional[int] = None,
+        poll_interval: float = 1.0,
     ) -> CompletionResponse:
         self.question_data["question"] = question
-        self.question_data["model"] = model
-        self.question_data["max_tokens"] = max_tokens
+        self.model = self.question_data["model"] = model or self.model
+        self.max_tokens = self.question_data["max_tokens"] = (
+            max_tokens or self.max_tokens
+        )
 
         task_response = self._submit_question_sync(self.question_data)
         task_id = task_response.get("task_id")
diff --git a/llama-index-integrations/llms/llama-index-llms-mymagic/pyproject.toml b/llama-index-integrations/llms/llama-index-llms-mymagic/pyproject.toml
index ca9b5679a7..f900161f1c 100644
--- a/llama-index-integrations/llms/llama-index-llms-mymagic/pyproject.toml
+++ b/llama-index-integrations/llms/llama-index-llms-mymagic/pyproject.toml
@@ -26,7 +26,7 @@ description = "llama-index llms mymagic integration"
 license = "MIT"
 name = "llama-index-llms-mymagic"
 readme = "README.md"
-version = "0.1.0"
+version = "0.1.1"
 
 [tool.poetry.dependencies]
 python = ">=3.8.1,<3.12"
--
GitLab
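With this patch, model and max_tokens become optional keyword arguments on complete() and acomplete(); when omitted, the values configured on the MyMagicAI instance are used and written back into question_data. The following is a minimal usage sketch of that fallback behavior, not part of the patch itself: the API key, bucket, session, role ARN, and question strings are placeholders, and it assumes the constructor arguments shown in the notebook cells above are the only ones required.

import asyncio

from llama_index.llms.mymagic import MyMagicAI

llm = MyMagicAI(
    api_key="your-api-key",             # placeholder credentials
    storage_provider="s3",              # s3 or gcs
    bucket_name="your-bucket-name",     # bucket holding the input files
    session="your-session-name",        # folder the batch inference runs over
    role_arn="your-role-arn",
    system_prompt="your-system-prompt",
)

# model and max_tokens may now be omitted; the instance defaults apply.
resp = llm.complete(question="Summarize the documents in this session.")

# They can still be passed explicitly to override the defaults per call.
resp = llm.complete(
    question="Summarize the documents in this session.",
    model="mistral7b",
    max_tokens=20,
)


async def main() -> None:
    # acomplete() follows the same optional-argument behavior asynchronously.
    response = await llm.acomplete(
        question="Summarize the documents in this session."
    )
    print("Async completion response:", response)


# In a script use asyncio.run(main()); in a notebook, await main() directly,
# which is why the example notebook cell above was changed to "await main()".
asyncio.run(main())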