diff --git a/docs/core_modules/model_modules/llms/modules.md b/docs/core_modules/model_modules/llms/modules.md
index 616b971b9400539624265fccd2444e3e02ca3542..ffa41c8b9348c1b58a81d2ba8a94eea1d20b655d 100644
--- a/docs/core_modules/model_modules/llms/modules.md
+++ b/docs/core_modules/model_modules/llms/modules.md
@@ -13,6 +13,15 @@ maxdepth: 1
 ```
 
+## AI21
+
+```{toctree}
+---
+maxdepth: 1
+---
+/examples/llm/ai21.ipynb
+```
+
 ## Anthropic
 
 ```{toctree}
diff --git a/docs/examples/llm/ai21.ipynb b/docs/examples/llm/ai21.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..b1bd58dcf71803d7c0e163c74c60cb5ab9283458
--- /dev/null
+++ b/docs/examples/llm/ai21.ipynb
@@ -0,0 +1,224 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# AI21"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Basic Usage"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### Call `complete` with a prompt"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from llama_index.llms import AI21\n",
+    "\n",
+    "api_key = \"Your api key\"\n",
+    "resp = AI21(api_key=api_key).complete(\"Paul Graham is \")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "an American computer scientist, essayist, and venture capitalist. He is best known for his work on Lisp, programming language design, and entrepreneurship. Graham has written several books on these topics, including \" ANSI Common Lisp\" and \" Hackers and Painters.\" He is also the co-founder of Y Combinator, a venture capital firm that invests in early-stage technology companies.\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(resp)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### Call `chat` with a list of messages"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from llama_index.llms import ChatMessage, AI21\n",
+    "\n",
+    "messages = [\n",
+    "    ChatMessage(role=\"user\", content=\"hello there\"),\n",
+    "    ChatMessage(\n",
+    "        role=\"assistant\", content=\"Arrrr, matey! How can I help ye today?\"\n",
+    "    ),\n",
+    "    ChatMessage(role=\"user\", content=\"What is your name\"),\n",
+    "]\n",
+    "\n",
+    "resp = AI21(api_key=api_key).chat(messages)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "assistant: yer talkin' to Captain Jack Sparrow\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(resp)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Configure Model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from llama_index.llms import AI21\n",
+    "\n",
+    "llm = AI21(model=\"j2-mid\", api_key=api_key)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "resp = llm.complete(\"Paul Graham is \")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "an American computer scientist, essayist, and venture capitalist. He is best known for his work on Lisp, programming language design, and entrepreneurship. Graham has written several books on these topics, including \" ANSI Common Lisp\" and \" Hackers and Painters.\" He is also the co-founder of Y Combinator, a venture capital firm that invests in early-stage technology companies.\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(resp)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Set API Key at a per-instance level\n",
+    "If desired, you can have separate LLM instances use separate API keys."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "an American computer scientist, essayist, and venture capitalist. He is best known for his work on Lisp, programming language design, and entrepreneurship. Graham has written several books on these topics, including \"Hackers and Painters\" and \"On Lisp.\" He is also the co-founder of Y Combinator, a venture capital firm that invests in early-stage technology companies.\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Calling POST https://api.ai21.com/studio/v1/j2-mid/complete failed with a non-200 response code: 401\n"
+     ]
+    },
+    {
+     "ename": "Unauthorized",
+     "evalue": "Failed with http status code: 401 (Unauthorized). Details: {\"detail\":\"Forbidden: Bad or missing API token.\"}",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[0;31mUnauthorized\u001b[0m Traceback (most recent call last)",
+      "\u001b[1;32m/home/amit/Desktop/projects/lindex/llama_index/docs/examples/llm/ai21.ipynb Cell 14\u001b[0m line \u001b[0;36m9\n\u001b[1;32m <a href='vscode-notebook-cell:/home/amit/Desktop/projects/lindex/llama_index/docs/examples/llm/ai21.ipynb#X42sZmlsZQ%3D%3D?line=5'>6</a>\u001b[0m resp \u001b[39m=\u001b[39m llm_good\u001b[39m.\u001b[39mcomplete(\u001b[39m\"\u001b[39m\u001b[39mPaul Graham is \u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m <a href='vscode-notebook-cell:/home/amit/Desktop/projects/lindex/llama_index/docs/examples/llm/ai21.ipynb#X42sZmlsZQ%3D%3D?line=6'>7</a>\u001b[0m \u001b[39mprint\u001b[39m(resp)\n\u001b[0;32m----> <a href='vscode-notebook-cell:/home/amit/Desktop/projects/lindex/llama_index/docs/examples/llm/ai21.ipynb#X42sZmlsZQ%3D%3D?line=8'>9</a>\u001b[0m resp \u001b[39m=\u001b[39m llm_bad\u001b[39m.\u001b[39;49mcomplete(\u001b[39m\"\u001b[39;49m\u001b[39mPaul Graham is \u001b[39;49m\u001b[39m\"\u001b[39;49m)\n\u001b[1;32m <a href='vscode-notebook-cell:/home/amit/Desktop/projects/lindex/llama_index/docs/examples/llm/ai21.ipynb#X42sZmlsZQ%3D%3D?line=9'>10</a>\u001b[0m \u001b[39mprint\u001b[39m(resp)\n",
+      "File \u001b[0;32m~/Desktop/projects/lindex/llama_index/llama_index/llms/base.py:312\u001b[0m, in \u001b[0;36mllm_completion_callback.<locals>.wrap.<locals>.wrapped_llm_predict\u001b[0;34m(_self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 302\u001b[0m \u001b[39mwith\u001b[39;00m wrapper_logic(_self) \u001b[39mas\u001b[39;00m callback_manager:\n\u001b[1;32m 303\u001b[0m event_id \u001b[39m=\u001b[39m callback_manager\u001b[39m.\u001b[39mon_event_start(\n\u001b[1;32m 304\u001b[0m CBEventType\u001b[39m.\u001b[39mLLM,\n\u001b[1;32m 305\u001b[0m payload\u001b[39m=\u001b[39m{\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 309\u001b[0m },\n\u001b[1;32m 310\u001b[0m )\n\u001b[0;32m--> 312\u001b[0m f_return_val \u001b[39m=\u001b[39m f(_self, \u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 313\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(f_return_val, Generator):\n\u001b[1;32m 314\u001b[0m \u001b[39m# intercept the generator and add a callback to the end\u001b[39;00m\n\u001b[1;32m 315\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mwrapped_gen\u001b[39m() \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m CompletionResponseGen:\n",
+      "File \u001b[0;32m~/Desktop/projects/lindex/llama_index/llama_index/llms/ai21.py:104\u001b[0m, in \u001b[0;36mAI21.complete\u001b[0;34m(self, prompt, **kwargs)\u001b[0m\n\u001b[1;32m 100\u001b[0m \u001b[39mimport\u001b[39;00m \u001b[39mai21\u001b[39;00m\n\u001b[1;32m 102\u001b[0m ai21\u001b[39m.\u001b[39mapi_key \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_api_key\n\u001b[0;32m--> 104\u001b[0m response \u001b[39m=\u001b[39m ai21\u001b[39m.\u001b[39;49mCompletion\u001b[39m.\u001b[39;49mexecute(\u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mall_kwargs, prompt\u001b[39m=\u001b[39;49mprompt)\n\u001b[1;32m 106\u001b[0m \u001b[39mreturn\u001b[39;00m CompletionResponse(\n\u001b[1;32m 107\u001b[0m text\u001b[39m=\u001b[39mresponse[\u001b[39m\"\u001b[39m\u001b[39mcompletions\u001b[39m\u001b[39m\"\u001b[39m][\u001b[39m0\u001b[39m][\u001b[39m\"\u001b[39m\u001b[39mdata\u001b[39m\u001b[39m\"\u001b[39m][\u001b[39m\"\u001b[39m\u001b[39mtext\u001b[39m\u001b[39m\"\u001b[39m], raw\u001b[39m=\u001b[39mresponse\u001b[39m.\u001b[39m\u001b[39m__dict__\u001b[39m\n\u001b[1;32m 108\u001b[0m )\n",
+      "File \u001b[0;32m~/.cache/pypoetry/virtualenvs/llama-index-2x1vjWb5-py3.10/lib/python3.10/site-packages/ai21/modules/resources/nlp_task.py:22\u001b[0m, in \u001b[0;36mNLPTask.execute\u001b[0;34m(cls, **params)\u001b[0m\n\u001b[1;32m 20\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mcls\u001b[39m\u001b[39m.\u001b[39m_execute_sm(destination\u001b[39m=\u001b[39mdestination, params\u001b[39m=\u001b[39mparams)\n\u001b[1;32m 21\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(destination, AI21Destination):\n\u001b[0;32m---> 22\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mcls\u001b[39;49m\u001b[39m.\u001b[39;49m_execute_studio_api(params)\n\u001b[1;32m 24\u001b[0m \u001b[39mraise\u001b[39;00m WrongInputTypeException(key\u001b[39m=\u001b[39mDESTINATION_KEY, expected_type\u001b[39m=\u001b[39mDestination, given_type\u001b[39m=\u001b[39m\u001b[39mtype\u001b[39m(destination))\n",
+      "File \u001b[0;32m~/.cache/pypoetry/virtualenvs/llama-index-2x1vjWb5-py3.10/lib/python3.10/site-packages/ai21/modules/completion.py:69\u001b[0m, in \u001b[0;36mCompletion._execute_studio_api\u001b[0;34m(cls, params)\u001b[0m\n\u001b[1;32m 65\u001b[0m url \u001b[39m=\u001b[39m \u001b[39mf\u001b[39m\u001b[39m'\u001b[39m\u001b[39m{\u001b[39;00murl\u001b[39m}\u001b[39;00m\u001b[39m/\u001b[39m\u001b[39m{\u001b[39;00mcustom_model\u001b[39m}\u001b[39;00m\u001b[39m'\u001b[39m\n\u001b[1;32m 67\u001b[0m url \u001b[39m=\u001b[39m \u001b[39mf\u001b[39m\u001b[39m'\u001b[39m\u001b[39m{\u001b[39;00murl\u001b[39m}\u001b[39;00m\u001b[39m/\u001b[39m\u001b[39m{\u001b[39;00m\u001b[39mcls\u001b[39m\u001b[39m.\u001b[39mMODULE_NAME\u001b[39m}\u001b[39;00m\u001b[39m'\u001b[39m\n\u001b[0;32m---> 69\u001b[0m \u001b[39mreturn\u001b[39;00m execute_studio_request(task_url\u001b[39m=\u001b[39;49murl, params\u001b[39m=\u001b[39;49mparams)\n",
+      "File \u001b[0;32m~/.cache/pypoetry/virtualenvs/llama-index-2x1vjWb5-py3.10/lib/python3.10/site-packages/ai21/modules/resources/execution_utils.py:11\u001b[0m, in \u001b[0;36mexecute_studio_request\u001b[0;34m(task_url, params, method)\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mexecute_studio_request\u001b[39m(task_url: \u001b[39mstr\u001b[39m, params, method: \u001b[39mstr\u001b[39m \u001b[39m=\u001b[39m \u001b[39m'\u001b[39m\u001b[39mPOST\u001b[39m\u001b[39m'\u001b[39m):\n\u001b[1;32m 10\u001b[0m client \u001b[39m=\u001b[39m AI21StudioClient(\u001b[39m*\u001b[39m\u001b[39m*\u001b[39mparams)\n\u001b[0;32m---> 11\u001b[0m \u001b[39mreturn\u001b[39;00m client\u001b[39m.\u001b[39;49mexecute_http_request(method\u001b[39m=\u001b[39;49mmethod, url\u001b[39m=\u001b[39;49mtask_url, params\u001b[39m=\u001b[39;49mparams)\n",
+      "File \u001b[0;32m~/.cache/pypoetry/virtualenvs/llama-index-2x1vjWb5-py3.10/lib/python3.10/site-packages/ai21/ai21_studio_client.py:52\u001b[0m, in \u001b[0;36mAI21StudioClient.execute_http_request\u001b[0;34m(self, method, url, params, files)\u001b[0m\n\u001b[1;32m 51\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mexecute_http_request\u001b[39m(\u001b[39mself\u001b[39m, method: \u001b[39mstr\u001b[39m, url: \u001b[39mstr\u001b[39m, params: Optional[Dict] \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m, files\u001b[39m=\u001b[39m\u001b[39mNone\u001b[39;00m):\n\u001b[0;32m---> 52\u001b[0m response \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mhttp_client\u001b[39m.\u001b[39;49mexecute_http_request(method\u001b[39m=\u001b[39;49mmethod, url\u001b[39m=\u001b[39;49murl, params\u001b[39m=\u001b[39;49mparams, files\u001b[39m=\u001b[39;49mfiles)\n\u001b[1;32m 53\u001b[0m \u001b[39mreturn\u001b[39;00m convert_to_ai21_object(response)\n",
+      "File \u001b[0;32m~/.cache/pypoetry/virtualenvs/llama-index-2x1vjWb5-py3.10/lib/python3.10/site-packages/ai21/http_client.py:84\u001b[0m, in \u001b[0;36mHttpClient.execute_http_request\u001b[0;34m(self, method, url, params, files, auth)\u001b[0m\n\u001b[1;32m 82\u001b[0m \u001b[39mif\u001b[39;00m response\u001b[39m.\u001b[39mstatus_code \u001b[39m!=\u001b[39m \u001b[39m200\u001b[39m:\n\u001b[1;32m 83\u001b[0m log_error(\u001b[39mf\u001b[39m\u001b[39m'\u001b[39m\u001b[39mCalling \u001b[39m\u001b[39m{\u001b[39;00mmethod\u001b[39m}\u001b[39;00m\u001b[39m \u001b[39m\u001b[39m{\u001b[39;00murl\u001b[39m}\u001b[39;00m\u001b[39m failed with a non-200 response code: \u001b[39m\u001b[39m{\u001b[39;00mresponse\u001b[39m.\u001b[39mstatus_code\u001b[39m}\u001b[39;00m\u001b[39m'\u001b[39m)\n\u001b[0;32m---> 84\u001b[0m handle_non_success_response(response\u001b[39m.\u001b[39;49mstatus_code, response\u001b[39m.\u001b[39;49mtext)\n\u001b[1;32m 86\u001b[0m \u001b[39mreturn\u001b[39;00m response\u001b[39m.\u001b[39mjson()\n",
+      "File \u001b[0;32m~/.cache/pypoetry/virtualenvs/llama-index-2x1vjWb5-py3.10/lib/python3.10/site-packages/ai21/http_client.py:23\u001b[0m, in \u001b[0;36mhandle_non_success_response\u001b[0;34m(status_code, response_text)\u001b[0m\n\u001b[1;32m 21\u001b[0m \u001b[39mraise\u001b[39;00m BadRequest(details\u001b[39m=\u001b[39mresponse_text)\n\u001b[1;32m 22\u001b[0m \u001b[39mif\u001b[39;00m status_code \u001b[39m==\u001b[39m \u001b[39m401\u001b[39m:\n\u001b[0;32m---> 23\u001b[0m \u001b[39mraise\u001b[39;00m Unauthorized(details\u001b[39m=\u001b[39mresponse_text)\n\u001b[1;32m 24\u001b[0m \u001b[39mif\u001b[39;00m status_code \u001b[39m==\u001b[39m \u001b[39m422\u001b[39m:\n\u001b[1;32m 25\u001b[0m \u001b[39mraise\u001b[39;00m UnprocessableEntity(details\u001b[39m=\u001b[39mresponse_text)\n",
+      "\u001b[0;31mUnauthorized\u001b[0m: Failed with http status code: 401 (Unauthorized). Details: {\"detail\":\"Forbidden: Bad or missing API token.\"}"
+     ]
+    }
+   ],
+   "source": [
+    "from llama_index.llms import AI21\n",
+    "\n",
+    "llm_good = AI21(api_key=api_key)\n",
+    "llm_bad = AI21(model=\"j2-mid\", api_key=\"BAD_KEY\")\n",
+    "\n",
+    "resp = llm_good.complete(\"Paul Graham is \")\n",
+    "print(resp)\n",
+    "\n",
+    "resp = llm_bad.complete(\"Paul Graham is \")\n",
+    "print(resp)"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "llama-index-2x1vjWb5-py3.10",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/llama_index/llms/__init__.py b/llama_index/llms/__init__.py
index e9c27380dfd818d0a73e96a80e1189b9921c31e0..c7b675b8d52583b3024debbefc182663baf63180 100644
--- a/llama_index/llms/__init__.py
+++ b/llama_index/llms/__init__.py
@@ -1,3 +1,4 @@
+from llama_index.llms.ai21 import AI21
 from llama_index.llms.anthropic import Anthropic
 from llama_index.llms.anyscale import Anyscale
 from llama_index.llms.azure_openai import AzureOpenAI
@@ -36,6 +37,7 @@ from llama_index.llms.replicate import Replicate
 from llama_index.llms.xinference import Xinference
 
 __all__ = [
+    "AI21",
     "Anthropic",
     "Anyscale",
     "AzureOpenAI",
diff --git a/llama_index/llms/ai21.py b/llama_index/llms/ai21.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc77c164c6c53273a64a72b8205869fb16ea3d4a
--- /dev/null
+++ b/llama_index/llms/ai21.py
@@ -0,0 +1,127 @@
+from typing import Any, Dict, Optional, Sequence
+
+from llama_index.bridge.pydantic import Field, PrivateAttr
+from llama_index.callbacks import CallbackManager
+from llama_index.llms.ai21_utils import ai21_model_to_context_size
+from llama_index.llms.base import (
+    ChatMessage,
+    ChatResponse,
+    ChatResponseGen,
+    CompletionResponse,
+    CompletionResponseGen,
+    LLMMetadata,
+    llm_chat_callback,
+    llm_completion_callback,
+)
+from llama_index.llms.custom import CustomLLM
+from llama_index.llms.generic_utils import (
+    completion_to_chat_decorator,
+    get_from_param_or_env,
+)
+
+
+class AI21(CustomLLM):
+    """AI21 Labs LLM."""
+
+    model: str = Field(description="The AI21 model to use.")
+    maxTokens: int = Field(description="The maximum number of tokens to generate.")
+    temperature: float = Field(description="The temperature to use for sampling.")
+
+    additional_kwargs: Dict[str, Any] = Field(
+        default_factory=dict, description="Additional kwargs for the AI21 API."
+    )
+
+    _api_key = PrivateAttr()
+
+    def __init__(
+        self,
+        api_key: Optional[str] = None,
+        model: Optional[str] = "j2-mid",
+        maxTokens: Optional[int] = 512,
+        temperature: Optional[float] = 0.1,
+        additional_kwargs: Optional[Dict[str, Any]] = None,
+        callback_manager: Optional[CallbackManager] = None,
+    ) -> None:
+        """Initialize params."""
+        try:
+            import ai21 as _
+        except ImportError as e:
+            raise ImportError(
+                "You must install the `ai21` package to use AI21. "
+                "Please `pip install ai21`."
+            ) from e
+
+        additional_kwargs = additional_kwargs or {}
+        callback_manager = callback_manager or CallbackManager([])
+
+        api_key = get_from_param_or_env("api_key", api_key, "AI21_API_KEY")
+        self._api_key = api_key
+
+        super().__init__(
+            model=model,
+            maxTokens=maxTokens,
+            temperature=temperature,
+            additional_kwargs=additional_kwargs,
+            callback_manager=callback_manager,
+        )
+
+    @classmethod
+    def class_name(cls) -> str:
+        """Get class name."""
+        return "AI21_LLM"
+
+    @property
+    def metadata(self) -> LLMMetadata:
+        return LLMMetadata(
+            context_window=ai21_model_to_context_size(self.model),
+            num_output=self.maxTokens,
+            model_name=self.model,
+        )
+
+    @property
+    def _model_kwargs(self) -> Dict[str, Any]:
+        base_kwargs = {
+            "model": self.model,
+            "maxTokens": self.maxTokens,
+            "temperature": self.temperature,
+        }
+        return {**base_kwargs, **self.additional_kwargs}
+
+    def _get_all_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
+        return {
+            **self._model_kwargs,
+            **kwargs,
+        }
+
+    @llm_completion_callback()
+    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
+        all_kwargs = self._get_all_kwargs(**kwargs)
+
+        import ai21
+
+        ai21.api_key = self._api_key
+
+        response = ai21.Completion.execute(**all_kwargs, prompt=prompt)
+
+        return CompletionResponse(
+            text=response["completions"][0]["data"]["text"], raw=response.__dict__
+        )
+
+    @llm_completion_callback()
+    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
+        raise NotImplementedError(
+            "AI21 does not currently support streaming completion."
+        )
+
+    @llm_chat_callback()
+    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
+        all_kwargs = self._get_all_kwargs(**kwargs)
+        chat_fn = completion_to_chat_decorator(self.complete)
+
+        return chat_fn(messages, **all_kwargs)
+
+    @llm_chat_callback()
+    def stream_chat(
+        self, messages: Sequence[ChatMessage], **kwargs: Any
+    ) -> ChatResponseGen:
+        raise NotImplementedError("AI21 does not currently support streaming chat.")
diff --git a/llama_index/llms/ai21_utils.py b/llama_index/llms/ai21_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f03a856c64ee1b80660912b2aa104b19910b13d
--- /dev/null
+++ b/llama_index/llms/ai21_utils.py
@@ -0,0 +1,19 @@
+COMPLETE_MODELS = {"j2-light": 8191, "j2-mid": 8191, "j2-ultra": 8191}
+
+
+def ai21_model_to_context_size(model: str) -> int:
+    """Return the maximum context size for the given AI21 model.
+
+    Args:
+        model: The model name we want to know the context size for.
+
+    Returns:
+        The maximum context size.
+
+    """
+    token_limit = COMPLETE_MODELS.get(model, None)
+
+    if token_limit is None:
+        raise ValueError(f"Model name {model} not found in {COMPLETE_MODELS.keys()}")
+
+    return token_limit
diff --git a/tests/llms/test_ai21.py b/tests/llms/test_ai21.py
new file mode 100644
index 0000000000000000000000000000000000000000..c220bfd0261d1adb00cb7c7fb88ce65373034313
--- /dev/null
+++ b/tests/llms/test_ai21.py
@@ -0,0 +1,335 @@
+from typing import TYPE_CHECKING, Any, Union
+
+import pytest
+from llama_index.llms import ChatMessage
+from pytest import MonkeyPatch
+
+if TYPE_CHECKING:
+    from ai21.ai21_object import AI21Object
+
+try:
+    import ai21
+    from ai21.ai21_object import construct_ai21_object
+except ImportError:
+    ai21 = None  # type: ignore
+
+
+from llama_index.llms.ai21 import AI21
+
+
+def mock_completion(*args: Any, **kwargs: Any) -> Union[Any, "AI21Object"]:
+    return construct_ai21_object(
+        {
+            "id": "f6adacef-0e94-6353-244f-df8d38954b19",
+            "prompt": {
+                "text": "This is just a test",
+                "tokens": [
+                    {
+                        "generatedToken": {
+                            "token": "▁This▁is▁just",
+                            "logprob": -13.657383918762207,
+                            "raw_logprob": -13.657383918762207,
+                        },
+                        "topTokens": None,
+                        "textRange": {"start": 0, "end": 12},
+                    },
+                    {
+                        "generatedToken": {
+                            "token": "▁a▁test",
+                            "logprob": -4.080351829528809,
+                            "raw_logprob": -4.080351829528809,
+                        },
+                        "topTokens": None,
+                        "textRange": {"start": 12, "end": 19},
+                    },
+                ],
+            },
+            "completions": [
+                {
+                    "data": {
+                        "text": "\nThis is a test to see if my text is showing up correctly.",
+                        "tokens": [
+                            {
+                                "generatedToken": {
+                                    "token": "<|newline|>",
+                                    "logprob": 0,
+                                    "raw_logprob": -0.01992332935333252,
+                                },
+                                "topTokens": None,
+                                "textRange": {"start": 0, "end": 1},
+                            },
+                            {
+                                "generatedToken": {
+                                    "token": "▁This▁is▁a",
+                                    "logprob": -0.00014733182615600526,
+                                    "raw_logprob": -1.228371500968933,
+                                },
+                                "topTokens": None,
+                                "textRange": {"start": 1, "end": 10},
+                            },
+                            {
+                                "generatedToken": {
+                                    "token": "▁test",
+                                    "logprob": 0,
+                                    "raw_logprob": -0.0422857291996479,
+                                },
+                                "topTokens": None,
+                                "textRange": {"start": 10, "end": 15},
+                            },
+                            {
+                                "generatedToken": {
+                                    "token": "▁to▁see▁if",
+                                    "logprob": -0.4861462712287903,
+                                    "raw_logprob": -1.2263909578323364,
+                                },
+                                "topTokens": None,
+                                "textRange": {"start": 15, "end": 25},
+                            },
+                            {
+                                "generatedToken": {
+                                    "token": "▁my",
+                                    "logprob": -9.536738616588991e-7,
+                                    "raw_logprob": -0.8164164423942566,
+                                },
+                                "topTokens": None,
+                                "textRange": {"start": 25, "end": 28},
+                            },
+                            {
+                                "generatedToken": {
+                                    "token": "▁text",
+                                    "logprob": -0.003087161108851433,
+                                    "raw_logprob": -1.7130306959152222,
+                                },
+                                "topTokens": None,
+                                "textRange": {"start": 28, "end": 33},
+                            },
+                            {
+                                "generatedToken": {
+                                    "token": "▁is",
+                                    "logprob": -1.8836627006530762,
+                                    "raw_logprob": -0.9880049824714661,
+                                },
+                                "topTokens": None,
+                                "textRange": {"start": 33, "end": 36},
+                            },
+                            {
+                                "generatedToken": {
+                                    "token": "▁showing▁up",
+                                    "logprob": -0.00006341733387671411,
+                                    "raw_logprob": -0.954255223274231,
+                                },
+                                "topTokens": None,
+                                "textRange": {"start": 36, "end": 47},
+                            },
+                            {
+                                "generatedToken": {
+                                    "token": "▁correctly",
+                                    "logprob": -0.00022098960471339524,
+                                    "raw_logprob": -0.6004139184951782,
+                                },
+                                "topTokens": None,
+                                "textRange": {"start": 47, "end": 57},
+                            },
+                            {
+                                "generatedToken": {
+                                    "token": ".",
+                                    "logprob": 0,
+                                    "raw_logprob": -0.039214372634887695,
+                                },
+                                "topTokens": None,
+                                "textRange": {"start": 57, "end": 58},
+                            },
+                            {
+                                "generatedToken": {
+                                    "token": "<|endoftext|>",
+                                    "logprob": 0,
+                                    "raw_logprob": -0.22456447780132294,
+                                },
+                                "topTokens": None,
+                                "textRange": {"start": 58, "end": 58},
+                            },
+                        ],
+                    },
+                    "finishReason": {"reason": "endoftext"},
+                }
+            ],
+        }
+    )
+
+
+def mock_chat(*args: Any, **kwargs: Any) -> Union[Any, "AI21Object"]:
+    return construct_ai21_object(
+        {
+            "id": "f8d0cd0a-7c85-deb2-16b3-491c7ffdd4f2",
+            "prompt": {
+                "text": "user: This is just a test assistant:",
+                "tokens": [
+                    {
+                        "generatedToken": {
+                            "token": "▁user",
+                            "logprob": -13.633946418762207,
+                            "raw_logprob": -13.633946418762207,
+                        },
+                        "topTokens": None,
+                        "textRange": {"start": 0, "end": 4},
+                    },
+                    {
+                        "generatedToken": {
+                            "token": ":",
+                            "logprob": -5.545032978057861,
+                            "raw_logprob": -5.545032978057861,
+                        },
+                        "topTokens": None,
+                        "textRange": {"start": 4, "end": 5},
+                    },
+                    {
+                        "generatedToken": {
+                            "token": "▁This▁is▁just",
+                            "logprob": -10.848762512207031,
+                            "raw_logprob": -10.848762512207031,
+                        },
+                        "topTokens": None,
+                        "textRange": {"start": 5, "end": 18},
+                    },
+                    {
+                        "generatedToken": {
+                            "token": "▁a▁test",
+                            "logprob": -2.0551252365112305,
+                            "raw_logprob": -2.0551252365112305,
+                        },
+                        "topTokens": None,
+                        "textRange": {"start": 18, "end": 25},
+                    },
+                    {
+                        "generatedToken": {
+                            "token": "▁assistant",
+                            "logprob": -17.020610809326172,
+                            "raw_logprob": -17.020610809326172,
+                        },
+                        "topTokens": None,
+                        "textRange": {"start": 25, "end": 35},
+                    },
+                    {
+                        "generatedToken": {
+                            "token": ":",
+                            "logprob": -12.311965942382812,
+                            "raw_logprob": -12.311965942382812,
+                        },
+                        "topTokens": None,
+                        "textRange": {"start": 35, "end": 36},
+                    },
+                ],
+            },
+            "completions": [
+                {
+                    "data": {
+                        "text": "\nassistant:\nHow can I assist you today?",
+                        "tokens": [
+                            {
+                                "generatedToken": {
+                                    "token": "<|newline|>",
+                                    "logprob": 0,
+                                    "raw_logprob": -0.02031332440674305,
+                                },
+                                "topTokens": None,
+                                "textRange": {"start": 0, "end": 1},
+                            },
+                            {
+                                "generatedToken": {
+                                    "token": "▁assistant",
+                                    "logprob": 0,
+                                    "raw_logprob": -0.24520651996135712,
+                                },
+                                "topTokens": None,
+                                "textRange": {"start": 1, "end": 10},
+                            },
+                            {
+                                "generatedToken": {
+                                    "token": ":",
+                                    "logprob": 0,
+                                    "raw_logprob": -0.0026112052146345377,
+                                },
+                                "topTokens": None,
+                                "textRange": {"start": 10, "end": 11},
+                            },
+                            {
+                                "generatedToken": {
+                                    "token": "<|newline|>",
+                                    "logprob": 0,
+                                    "raw_logprob": -0.3382393717765808,
+                                },
+                                "topTokens": None,
+                                "textRange": {"start": 11, "end": 12},
+                            },
+                            {
+                                "generatedToken": {
+                                    "token": "▁How▁can▁I",
+                                    "logprob": -0.000008106198947643861,
+                                    "raw_logprob": -1.3073582649230957,
+                                },
+                                "topTokens": None,
+                                "textRange": {"start": 12, "end": 21},
+                            },
+                            {
+                                "generatedToken": {
+                                    "token": "▁assist▁you",
+                                    "logprob": -2.15450382232666,
+                                    "raw_logprob": -0.8163930177688599,
+                                },
+                                "topTokens": None,
+                                "textRange": {"start": 21, "end": 32},
+                            },
+                            {
+                                "generatedToken": {
+                                    "token": "▁today",
+                                    "logprob": 0,
+                                    "raw_logprob": -0.1474292278289795,
+                                },
+                                "topTokens": None,
+                                "textRange": {"start": 32, "end": 38},
+                            },
+                            {
+                                "generatedToken": {
+                                    "token": "?",
+                                    "logprob": 0,
+                                    "raw_logprob": -0.011986607685685158,
+                                },
+                                "topTokens": None,
+                                "textRange": {"start": 38, "end": 39},
+                            },
+                            {
+                                "generatedToken": {
+                                    "token": "<|endoftext|>",
+                                    "logprob": -1.1920928244535389e-7,
+                                    "raw_logprob": -0.2295214682817459,
+                                },
+                                "topTokens": None,
+                                "textRange": {"start": 39, "end": 39},
+                            },
+                        ],
+                    },
+                    "finishReason": {"reason": "endoftext"},
+                }
+            ],
+        }
+    )
+
+
+@pytest.mark.skipif(ai21 is None, reason="ai21 not installed")
+def test_completion_model_basic(monkeypatch: MonkeyPatch) -> None:
+    monkeypatch.setattr("ai21.Completion.execute", mock_completion)
+
+    mock_api_key = "fake_key"
+    llm = AI21(model="j2-mid", api_key=mock_api_key)
+
+    test_prompt = "This is just a test"
+    response = llm.complete(test_prompt)
+    assert (
+        response.text == "\nThis is a test to see if my text is showing up correctly."
+    )
+
+    monkeypatch.setattr("ai21.Completion.execute", mock_chat)
+
+    message = ChatMessage(role="user", content=test_prompt)
+    chat_response = llm.chat([message])
+    assert chat_response.message.content == "\nassistant:\nHow can I assist you today?"
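
Note: below is a minimal usage sketch of the integration added in this diff, for quick sanity checking. It assumes only what the diff introduces (the `AI21` class, the `AI21_API_KEY` environment-variable fallback read via `get_from_param_or_env`, and the `j2-*` context sizes in `COMPLETE_MODELS`); the key value shown is a placeholder.

```python
import os

from llama_index.llms import AI21, ChatMessage

# AI21.__init__ falls back to the AI21_API_KEY environment variable
# when no api_key argument is passed (via get_from_param_or_env).
os.environ["AI21_API_KEY"] = "<your-api-key>"  # placeholder, not a real key

llm = AI21(model="j2-mid")  # defaults: maxTokens=512, temperature=0.1

# metadata.context_window is derived from COMPLETE_MODELS in ai21_utils.py
print(llm.metadata.context_window)  # 8191 for all j2-* models

# chat() is layered on complete() via completion_to_chat_decorator, which
# serializes the message list into a single prompt before the API call.
resp = llm.chat([ChatMessage(role="user", content="hello there")])
print(resp.message.content)
```

Since `stream_complete` and `stream_chat` both raise `NotImplementedError`, only the synchronous `complete`/`chat` entry points should be exercised.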