diff --git a/docs/examples/llm/mymagic.ipynb b/docs/examples/llm/mymagic.ipynb
index af38d4a83547fa699d3835ffca2dd781eb645edf..81a187cbf4abc3b8c7ddbbfeaefcd9f1da4e572e 100644
--- a/docs/examples/llm/mymagic.ipynb
+++ b/docs/examples/llm/mymagic.ipynb
@@ -68,9 +68,18 @@
     "    system_prompt=\"your-system-prompt\",\n",
     "    region=\"your-bucket-region\",\n",
     "    return_output=False, # Whether you want MyMagic API to return the output json\n",
+    "    input_json_file=None, # name of the input JSON file (stored in the bucket)\n",
+    "    structured_output=None, # JSON schema of the output\n",
     ")"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Note: if return_output is set to True above, max_tokens should be set to at least 100."
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
@@ -118,18 +127,9 @@
    "outputs": [],
    "source": [
     "async def main():\n",
-    "    allm = MyMagicAI(\n",
-    "        api_key=\"your-api-key\",\n",
-    "        storage_provider=\"s3\", # s3, gcs\n",
-    "        bucket_name=\"your-bucket-name\",\n",
-    "        session=\"your-session-name\", # files should be located in this folder on which batch inference will be run\n",
-    "        role_arn=\"your-role-arn\",\n",
-    "        system_prompt=\"your-system-prompt\",\n",
-    "        region=\"your-bucket-region\",\n",
-    "    )\n",
-    "    response = await allm.acomplete(\n",
+    "    response = await llm.acomplete(\n",
     "        question=\"your-question\",\n",
-    "        model=\"chhoose-model\", # currently we support mistral7b, llama7b, mixtral8x7b,codellama70b, llama70b, more to come...\n",
+    "        model=\"choose-model\", # currently we support mistral7b, llama7b, mixtral8x7b, codellama70b, llama70b, more to come...\n",
     "        max_tokens=5, # number of tokens to generate, default is 10\n",
     "    )\n",
     "\n",
diff --git a/llama-index-integrations/llms/llama-index-llms-mymagic/llama_index/llms/mymagic/base.py b/llama-index-integrations/llms/llama-index-llms-mymagic/llama_index/llms/mymagic/base.py
index 7051e10759e2fd0c93060090aaf4d07443705e3c..f1851c3f6cb9ae6253e715fbe83d1eb2aa52e309 100644
--- a/llama-index-integrations/llms/llama-index-llms-mymagic/llama_index/llms/mymagic/base.py
+++ b/llama-index-integrations/llms/llama-index-llms-mymagic/llama_index/llms/mymagic/base.py
@@ -51,6 +51,11 @@ class MyMagicAI(LLM):
     return_output: Optional[bool] = Field(
         False, description="Whether MyMagic API should return the output json"
     )
+    input_json_file: Optional[str] = None
+
+    structured_output: Optional[Dict[str, Any]] = Field(
+        None, description="User-defined structure for the response output"
+    )
 
     def __init__(
         self,
@@ -62,9 +67,12 @@
         role_arn: Optional[str] = None,
         region: Optional[str] = None,
         return_output: Optional[bool] = False,
+        input_json_file: Optional[str] = None,
+        structured_output: Optional[Dict[str, Any]] = None,
         **kwargs: Any,
     ) -> None:
         super().__init__(**kwargs)
+        self.return_output = return_output
 
         self.question_data = {
             "storage_provider": storage_provider,
@@ -76,6 +84,8 @@
             "system_prompt": system_prompt,
             "region": region,
             "return_output": return_output,
+            "input_json_file": input_json_file,
+            "structured_output": structured_output,
         }
 
     @classmethod
@@ -87,7 +97,9 @@
         return self.base_url_template.format(model=model)
 
     async def _submit_question(self, question_data: Dict[str, Any]) -> Dict[str, Any]:
-        async with httpx.AsyncClient() as client:
+        timeout_config = httpx.Timeout(600.0, connect=60.0)
+
+        async with httpx.AsyncClient(timeout=timeout_config) as client:
             url = f"{self._construct_url(self.model)}/submit_question"
             resp = await client.post(url, json=question_data)
             resp.raise_for_status()
@@ -129,6 +141,10 @@
         )
 
         task_response = await self._submit_question(self.question_data)
+
+        if self.return_output:
+            return task_response
+
         task_id = task_response.get("task_id")
         while True:
             result = await self._get_result(task_id)
@@ -150,6 +166,9 @@
         )
 
         task_response = self._submit_question_sync(self.question_data)
+        if self.return_output:
+            return task_response
+
         task_id = task_response.get("task_id")
         while True:
             result = self._get_result_sync(task_id)
diff --git a/llama-index-integrations/llms/llama-index-llms-mymagic/pyproject.toml b/llama-index-integrations/llms/llama-index-llms-mymagic/pyproject.toml
index 0e7ba7cea8313e8818ee2bdfb419df90d7064c3f..ac3c1348a19849623b368c0ce4e17c9b10430c95 100644
--- a/llama-index-integrations/llms/llama-index-llms-mymagic/pyproject.toml
+++ b/llama-index-integrations/llms/llama-index-llms-mymagic/pyproject.toml
@@ -27,7 +27,7 @@ exclude = ["**/BUILD"]
 license = "MIT"
 name = "llama-index-llms-mymagic"
 readme = "README.md"
-version = "0.1.5"
+version = "0.1.6"
 
 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"
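
For context, a minimal usage sketch of the parameters this change introduces, pieced together only from the notebook and base.py hunks above. The file name "questions.json" and the schema dict are hypothetical placeholders, and the keyword arguments to complete() are assumed to match the acomplete() call shown in the notebook:

from llama_index.llms.mymagic import MyMagicAI

llm = MyMagicAI(
    api_key="your-api-key",
    storage_provider="s3",  # s3 or gcs
    bucket_name="your-bucket-name",
    session="your-session-name",  # folder on the bucket holding the input files
    role_arn="your-role-arn",
    system_prompt="your-system-prompt",
    region="your-bucket-region",
    return_output=True,  # take the new early-return path added in this diff
    input_json_file="questions.json",  # hypothetical: name of the input file on the bucket
    structured_output={  # hypothetical: JSON schema the output should follow
        "type": "object",
        "properties": {"answer": {"type": "string"}},
    },
)

# With return_output=True, complete() returns the raw /submit_question
# response instead of polling for a task result, and per the docs note
# above max_tokens should be at least 100 in this mode.
response = llm.complete(
    question="your-question",
    model="mixtral8x7b",
    max_tokens=100,
)
print(response)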