From 439e1b67b1a30cb08a87bfa883fed1af9db19b81 Mon Sep 17 00:00:00 2001
From: Logan <logan.markewich@live.com>
Date: Wed, 21 Feb 2024 19:46:14 -0600
Subject: [PATCH] update MyMagicAI (#11263)

* make model and max_tokens optional in MyMagicAI.complete() and
  acomplete(), falling back to the values already set on the instance

* document the role_arn parameter in the notebook examples and replace
  asyncio.run(main()) with a top-level await main(), since notebook
  kernels already run an event loop (see the note after the patch)

* bump llama-index-llms-mymagic to 0.1.1
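
A minimal synchronous usage sketch of the updated call signature (all
values are placeholders; assumes MyMagicAI is exported from
llama_index.llms.mymagic, as the notebook does):

    from llama_index.llms.mymagic import MyMagicAI

    # Placeholder credentials and bucket details.
    llm = MyMagicAI(
        api_key="your-api-key",
        storage_provider="s3",  # s3 or gcs
        bucket_name="your-bucket-name",
        session="your-session-name",  # folder holding the files to batch-infer over
        role_arn="your-role-arn",
        system_prompt="your-system-prompt",
    )

    # model and max_tokens are now optional; when omitted, complete()
    # falls back to the values stored on the instance (max_tokens
    # defaults to 10 per base.py).
    resp = llm.complete(question="Summarize the document!")
    print(resp)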
---
 docs/examples/llm/mymagic.ipynb               | 34 +++++++++++--------
 .../llama_index/llms/mymagic/base.py          | 25 ++++++++++----
 .../llama-index-llms-mymagic/pyproject.toml   |  2 +-
 3 files changed, 39 insertions(+), 22 deletions(-)

diff --git a/docs/examples/llm/mymagic.ipynb b/docs/examples/llm/mymagic.ipynb
index df35e377d9..d1ee8a10a3 100644
--- a/docs/examples/llm/mymagic.ipynb
+++ b/docs/examples/llm/mymagic.ipynb
@@ -60,11 +60,12 @@
    "outputs": [],
    "source": [
     "llm = MyMagicAI(\n",
-    "    api_key=\"your_api_key\",\n",
-    "    storage_provider=\"your_storage_provider\",  # s3, gcs\n",
-    "    bucket_name=\"your_bucket_name\",\n",
-    "    session=\"your_session\",  # files should be located in this folder on which batch inference will be run\n",
-    "    system_prompt=\"Answer the question succinctly\",\n",
+    "    api_key=\"your-api-key\",\n",
+    "    storage_provider=\"s3\",  # s3, gcs\n",
+    "    bucket_name=\"your-bucket-name\",\n",
+    "    session=\"your-session-name\",  # files should be located in this folder on which batch inference will be run\n",
+    "    role_arn=\"your-role-arn\",\n",
+    "    system_prompt=\"your-system-prompt\",\n",
     ")"
    ]
   },
@@ -75,9 +76,9 @@
    "outputs": [],
    "source": [
     "resp = llm.complete(\n",
-    "    question=\"Summarize the document!\",\n",
-    "    model=\"mistral7b\",\n",
-    "    max_tokens=10,  # currently we support mistral7b, llama7b, mixtral8x7b,codellama70b, llama70b, more to come...\n",
+    "    question=\"your-question\",\n",
+    "    model=\"chhoose-model\",  # currently we support mistral7b, llama7b, mixtral8x7b,codellama70b, llama70b, more to come...\n",
+    "    max_tokens=5,  # number of tokens to generate, default is 10\n",
     ")"
    ]
   },
@@ -116,14 +117,17 @@
    "source": [
     "async def main():\n",
     "    allm = MyMagicAI(\n",
-    "        api_key=\"your_api_key\",\n",
-    "        storage_provider=\"your_storage_provider\",\n",
-    "        bucket_name=\"your_bucket_name\",\n",
-    "        session=\"your_session_name\",\n",
-    "        system_prompt=\"your_system_prompt\",\n",
+    "        api_key=\"your-api-key\",\n",
+    "        storage_provider=\"s3\",  # s3, gcs\n",
+    "        bucket_name=\"your-bucket-name\",\n",
+    "        session=\"your-session-name\",  # files should be located in this folder on which batch inference will be run\n",
+    "        role_arn=\"your-role-arn\",\n",
+    "        system_prompt=\"your-system-prompt\",\n",
     "    )\n",
     "    response = await allm.acomplete(\n",
-    "        question=\"your_question\", model=\"mistral7b\", max_tokens=10\n",
+    "        question=\"your-question\",\n",
+    "        model=\"chhoose-model\",  # currently we support mistral7b, llama7b, mixtral8x7b,codellama70b, llama70b, more to come...\n",
+    "        max_tokens=5,  # number of tokens to generate, default is 10\n",
     "    )\n",
     "\n",
     "    print(\"Async completion response:\", response)"
@@ -135,7 +139,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "asyncio.run(main())"
+    "await main()"
    ]
   }
  ],
diff --git a/llama-index-integrations/llms/llama-index-llms-mymagic/llama_index/llms/mymagic/base.py b/llama-index-integrations/llms/llama-index-llms-mymagic/llama_index/llms/mymagic/base.py
index f44a29ef4c..9c3f6f3ea3 100644
--- a/llama-index-integrations/llms/llama-index-llms-mymagic/llama_index/llms/mymagic/base.py
+++ b/llama-index-integrations/llms/llama-index-llms-mymagic/llama_index/llms/mymagic/base.py
@@ -23,6 +23,7 @@ class MyMagicAI(LLM):
     max_tokens: int = Field(
         default=10, description="The maximum number of tokens to generate."
     )
+    question: str = Field(default="", description="The user question.")
     storage_provider: str = Field(
         default="gcs", description="The storage provider to use."
     )
@@ -105,11 +106,17 @@ class MyMagicAI(LLM):
             return resp.json()
 
     async def acomplete(
-        self, question: str, model: str, max_tokens: int, poll_interval: float = 1.0
+        self,
+        question: str,
+        model: Optional[str] = None,
+        max_tokens: Optional[int] = None,
+        poll_interval: float = 1.0,
     ) -> CompletionResponse:
         self.question_data["question"] = question
-        self.question_data["model"] = model
-        self.question_data["max_tokens"] = max_tokens
+        self.model = self.question_data["model"] = model or self.model
+        self.max_tokens = self.question_data["max_tokens"] = (
+            max_tokens or self.max_tokens
+        )
 
         task_response = await self._submit_question(self.question_data)
         task_id = task_response.get("task_id")
@@ -120,11 +127,17 @@ class MyMagicAI(LLM):
             await asyncio.sleep(poll_interval)
 
     def complete(
-        self, question: str, model: str, max_tokens: int, poll_interval: float = 1.0
+        self,
+        question: str,
+        model: Optional[str] = None,
+        max_tokens: Optional[int] = None,
+        poll_interval: float = 1.0,
     ) -> CompletionResponse:
         self.question_data["question"] = question
-        self.question_data["model"] = model
-        self.question_data["max_tokens"] = max_tokens
+        self.model = self.question_data["model"] = model or self.model
+        self.max_tokens = self.question_data["max_tokens"] = (
+            max_tokens or self.max_tokens
+        )
 
         task_response = self._submit_question_sync(self.question_data)
         task_id = task_response.get("task_id")
diff --git a/llama-index-integrations/llms/llama-index-llms-mymagic/pyproject.toml b/llama-index-integrations/llms/llama-index-llms-mymagic/pyproject.toml
index ca9b5679a7..f900161f1c 100644
--- a/llama-index-integrations/llms/llama-index-llms-mymagic/pyproject.toml
+++ b/llama-index-integrations/llms/llama-index-llms-mymagic/pyproject.toml
@@ -26,7 +26,7 @@ description = "llama-index llms mymagic integration"
 license = "MIT"
 name = "llama-index-llms-mymagic"
 readme = "README.md"
-version = "0.1.0"
+version = "0.1.1"
 
 [tool.poetry.dependencies]
 python = ">=3.8.1,<3.12"
-- 
GitLab
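
Note on the notebook change from asyncio.run(main()) to await main():
Jupyter kernels already run an event loop, so asyncio.run() raises
"RuntimeError: asyncio.run() cannot be called from a running event loop"
there; a top-level await is the idiomatic replacement. A minimal async
sketch (placeholder values, mirroring the notebook):

    import asyncio

    from llama_index.llms.mymagic import MyMagicAI

    async def main() -> None:
        allm = MyMagicAI(
            api_key="your-api-key",
            storage_provider="s3",  # s3 or gcs
            bucket_name="your-bucket-name",
            session="your-session-name",
            role_arn="your-role-arn",
            system_prompt="your-system-prompt",
        )
        # model/max_tokens may be omitted; they fall back to the instance values.
        response = await allm.acomplete(question="your-question")
        print("Async completion response:", response)

    # Fine in a plain script; in a notebook, call `await main()` instead.
    asyncio.run(main())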