From e50c08737ef4bee7651c900d1388430ef339a712 Mon Sep 17 00:00:00 2001
From: Amit Kenkre <amitnmf@gmail.com>
Date: Sun, 22 Oct 2023 10:32:36 +0530
Subject: [PATCH] Add ai21 llm (#8233)

* Added AI21 Labs LLM

* added AI21 to __init__.py for llms

* Added tests for AI21 llm

* Added AI21 llm Example notebook

* Added Link to AI21 example notebook in docs

* bugfix: linting errors when ai21 is not installed

* fix lint

* wip

---------

Co-authored-by: Simon Suo <simonsdsuo@gmail.com>
---
 .../model_modules/llms/modules.md             |   9 +
 docs/examples/llm/ai21.ipynb                  | 226 ++++++++++++
 llama_index/llms/__init__.py                  |   2 +
 llama_index/llms/ai21.py                      | 131 +++++++
 llama_index/llms/ai21_utils.py                |  20 ++
 tests/llms/test_ai21.py                       | 338 ++++++++++++++++++
 6 files changed, 726 insertions(+)
 create mode 100644 docs/examples/llm/ai21.ipynb
 create mode 100644 llama_index/llms/ai21.py
 create mode 100644 llama_index/llms/ai21_utils.py
 create mode 100644 tests/llms/test_ai21.py

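A minimal usage sketch of the new surface area, for reviewers (assumes the
optional `ai21` package is installed and an API key is supplied directly or
via the AI21_API_KEY environment variable; "j2-mid" is the default model):

    from llama_index.llms import AI21, ChatMessage

    llm = AI21(model="j2-mid", maxTokens=512, temperature=0.1)
    print(llm.complete("Paul Graham is ").text)

    messages = [ChatMessage(role="user", content="hello there")]
    print(llm.chat(messages).message.content)

Streaming is not supported in this patch: stream_complete and stream_chat
raise NotImplementedError.
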
diff --git a/docs/core_modules/model_modules/llms/modules.md b/docs/core_modules/model_modules/llms/modules.md
index 616b971b94..ffa41c8b93 100644
--- a/docs/core_modules/model_modules/llms/modules.md
+++ b/docs/core_modules/model_modules/llms/modules.md
@@ -13,6 +13,15 @@ maxdepth: 1
 
 ```
 
+## AI21
+
+```{toctree}
+---
+maxdepth: 1
+---
+/examples/llm/ai21.ipynb
+```
+
 ## Anthropic
 
 ```{toctree}
diff --git a/docs/examples/llm/ai21.ipynb b/docs/examples/llm/ai21.ipynb
new file mode 100644
index 0000000000..b1bd58dcf7
--- /dev/null
+++ b/docs/examples/llm/ai21.ipynb
@@ -0,0 +1,226 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# AI21"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Basic Usage"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### Call `complete` with a prompt"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from llama_index.llms import AI21\n",
+    "\n",
+    "api_key = \"Your api key\"\n",
+    "resp = AI21(api_key=api_key).complete(\"Paul Graham is \")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "an American computer scientist, essayist, and venture capitalist. He is best known for his work on Lisp, programming language design, and entrepreneurship. Graham has written several books on these topics, including \" ANSI Common Lisp\" and \" Hackers and Painters.\" He is also the co-founder of Y Combinator, a venture capital firm that invests in early-stage technology companies.\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(resp)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### Call `chat` with a list of messages"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from llama_index.llms import ChatMessage, AI21\n",
+    "\n",
+    "messages = [\n",
+    "    ChatMessage(role=\"user\", content=\"hello there\"),\n",
+    "    ChatMessage(\n",
+    "        role=\"assistant\", content=\"Arrrr, matey! How can I help ye today?\"\n",
+    "    ),\n",
+    "    ChatMessage(role=\"user\", content=\"What is your name\"),\n",
+    "]\n",
+    "\n",
+    "resp = AI21(api_key=api_key).chat(\n",
+    "    messages, preamble_override=\"You are a pirate with a colorful personality\"\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "assistant: yer talkin' to Captain Jack Sparrow\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(resp)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Configure Model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from llama_index.llms import AI21\n",
+    "\n",
+    "llm = AI21(model=\"j2-mid\", api_key=api_key)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "resp = llm.complete(\"Paul Graham is \")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "an American computer scientist, essayist, and venture capitalist. He is best known for his work on Lisp, programming language design, and entrepreneurship. Graham has written several books on these topics, including \" ANSI Common Lisp\" and \" Hackers and Painters.\" He is also the co-founder of Y Combinator, a venture capital firm that invests in early-stage technology companies.\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(resp)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Set API Key at a per-instance level\n",
+    "If desired, you can have separate LLM instances use separate API keys."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "an American computer scientist, essayist, and venture capitalist. He is best known for his work on Lisp, programming language design, and entrepreneurship. Graham has written several books on these topics, including \"Hackers and Painters\" and \"On Lisp.\" He is also the co-founder of Y Combinator, a venture capital firm that invests in early-stage technology companies.\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Calling POST https://api.ai21.com/studio/v1/j2-mid/complete failed with a non-200 response code: 401\n"
+     ]
+    },
+    {
+     "ename": "Unauthorized",
+     "evalue": "Failed with http status code: 401 (Unauthorized). Details: {\"detail\":\"Forbidden: Bad or missing API token.\"}",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[0;31mUnauthorized\u001b[0m                              Traceback (most recent call last)",
+      "\u001b[1;32m/home/amit/Desktop/projects/lindex/llama_index/docs/examples/llm/ai21.ipynb Cell 14\u001b[0m line \u001b[0;36m9\n\u001b[1;32m      <a href='vscode-notebook-cell:/home/amit/Desktop/projects/lindex/llama_index/docs/examples/llm/ai21.ipynb#X42sZmlsZQ%3D%3D?line=5'>6</a>\u001b[0m resp \u001b[39m=\u001b[39m llm_good\u001b[39m.\u001b[39mcomplete(\u001b[39m\"\u001b[39m\u001b[39mPaul Graham is \u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m      <a href='vscode-notebook-cell:/home/amit/Desktop/projects/lindex/llama_index/docs/examples/llm/ai21.ipynb#X42sZmlsZQ%3D%3D?line=6'>7</a>\u001b[0m \u001b[39mprint\u001b[39m(resp)\n\u001b[0;32m----> <a href='vscode-notebook-cell:/home/amit/Desktop/projects/lindex/llama_index/docs/examples/llm/ai21.ipynb#X42sZmlsZQ%3D%3D?line=8'>9</a>\u001b[0m resp \u001b[39m=\u001b[39m llm_bad\u001b[39m.\u001b[39;49mcomplete(\u001b[39m\"\u001b[39;49m\u001b[39mPaul Graham is \u001b[39;49m\u001b[39m\"\u001b[39;49m)\n\u001b[1;32m     <a href='vscode-notebook-cell:/home/amit/Desktop/projects/lindex/llama_index/docs/examples/llm/ai21.ipynb#X42sZmlsZQ%3D%3D?line=9'>10</a>\u001b[0m \u001b[39mprint\u001b[39m(resp)\n",
+      "File \u001b[0;32m~/Desktop/projects/lindex/llama_index/llama_index/llms/base.py:312\u001b[0m, in \u001b[0;36mllm_completion_callback.<locals>.wrap.<locals>.wrapped_llm_predict\u001b[0;34m(_self, *args, **kwargs)\u001b[0m\n\u001b[1;32m    302\u001b[0m \u001b[39mwith\u001b[39;00m wrapper_logic(_self) \u001b[39mas\u001b[39;00m callback_manager:\n\u001b[1;32m    303\u001b[0m     event_id \u001b[39m=\u001b[39m callback_manager\u001b[39m.\u001b[39mon_event_start(\n\u001b[1;32m    304\u001b[0m         CBEventType\u001b[39m.\u001b[39mLLM,\n\u001b[1;32m    305\u001b[0m         payload\u001b[39m=\u001b[39m{\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    309\u001b[0m         },\n\u001b[1;32m    310\u001b[0m     )\n\u001b[0;32m--> 312\u001b[0m     f_return_val \u001b[39m=\u001b[39m f(_self, \u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m    313\u001b[0m     \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(f_return_val, Generator):\n\u001b[1;32m    314\u001b[0m         \u001b[39m# intercept the generator and add a callback to the end\u001b[39;00m\n\u001b[1;32m    315\u001b[0m         \u001b[39mdef\u001b[39;00m \u001b[39mwrapped_gen\u001b[39m() \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m CompletionResponseGen:\n",
+      "File \u001b[0;32m~/Desktop/projects/lindex/llama_index/llama_index/llms/ai21.py:104\u001b[0m, in \u001b[0;36mAI21.complete\u001b[0;34m(self, prompt, **kwargs)\u001b[0m\n\u001b[1;32m    100\u001b[0m \u001b[39mimport\u001b[39;00m \u001b[39mai21\u001b[39;00m\n\u001b[1;32m    102\u001b[0m ai21\u001b[39m.\u001b[39mapi_key \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_api_key\n\u001b[0;32m--> 104\u001b[0m response \u001b[39m=\u001b[39m ai21\u001b[39m.\u001b[39;49mCompletion\u001b[39m.\u001b[39;49mexecute(\u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mall_kwargs, prompt\u001b[39m=\u001b[39;49mprompt)\n\u001b[1;32m    106\u001b[0m \u001b[39mreturn\u001b[39;00m CompletionResponse(\n\u001b[1;32m    107\u001b[0m     text\u001b[39m=\u001b[39mresponse[\u001b[39m\"\u001b[39m\u001b[39mcompletions\u001b[39m\u001b[39m\"\u001b[39m][\u001b[39m0\u001b[39m][\u001b[39m\"\u001b[39m\u001b[39mdata\u001b[39m\u001b[39m\"\u001b[39m][\u001b[39m\"\u001b[39m\u001b[39mtext\u001b[39m\u001b[39m\"\u001b[39m], raw\u001b[39m=\u001b[39mresponse\u001b[39m.\u001b[39m\u001b[39m__dict__\u001b[39m\n\u001b[1;32m    108\u001b[0m )\n",
+      "File \u001b[0;32m~/.cache/pypoetry/virtualenvs/llama-index-2x1vjWb5-py3.10/lib/python3.10/site-packages/ai21/modules/resources/nlp_task.py:22\u001b[0m, in \u001b[0;36mNLPTask.execute\u001b[0;34m(cls, **params)\u001b[0m\n\u001b[1;32m     20\u001b[0m     \u001b[39mreturn\u001b[39;00m \u001b[39mcls\u001b[39m\u001b[39m.\u001b[39m_execute_sm(destination\u001b[39m=\u001b[39mdestination, params\u001b[39m=\u001b[39mparams)\n\u001b[1;32m     21\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(destination, AI21Destination):\n\u001b[0;32m---> 22\u001b[0m     \u001b[39mreturn\u001b[39;00m \u001b[39mcls\u001b[39;49m\u001b[39m.\u001b[39;49m_execute_studio_api(params)\n\u001b[1;32m     24\u001b[0m \u001b[39mraise\u001b[39;00m WrongInputTypeException(key\u001b[39m=\u001b[39mDESTINATION_KEY, expected_type\u001b[39m=\u001b[39mDestination, given_type\u001b[39m=\u001b[39m\u001b[39mtype\u001b[39m(destination))\n",
+      "File \u001b[0;32m~/.cache/pypoetry/virtualenvs/llama-index-2x1vjWb5-py3.10/lib/python3.10/site-packages/ai21/modules/completion.py:69\u001b[0m, in \u001b[0;36mCompletion._execute_studio_api\u001b[0;34m(cls, params)\u001b[0m\n\u001b[1;32m     65\u001b[0m     url \u001b[39m=\u001b[39m \u001b[39mf\u001b[39m\u001b[39m'\u001b[39m\u001b[39m{\u001b[39;00murl\u001b[39m}\u001b[39;00m\u001b[39m/\u001b[39m\u001b[39m{\u001b[39;00mcustom_model\u001b[39m}\u001b[39;00m\u001b[39m'\u001b[39m\n\u001b[1;32m     67\u001b[0m url \u001b[39m=\u001b[39m \u001b[39mf\u001b[39m\u001b[39m'\u001b[39m\u001b[39m{\u001b[39;00murl\u001b[39m}\u001b[39;00m\u001b[39m/\u001b[39m\u001b[39m{\u001b[39;00m\u001b[39mcls\u001b[39m\u001b[39m.\u001b[39mMODULE_NAME\u001b[39m}\u001b[39;00m\u001b[39m'\u001b[39m\n\u001b[0;32m---> 69\u001b[0m \u001b[39mreturn\u001b[39;00m execute_studio_request(task_url\u001b[39m=\u001b[39;49murl, params\u001b[39m=\u001b[39;49mparams)\n",
+      "File \u001b[0;32m~/.cache/pypoetry/virtualenvs/llama-index-2x1vjWb5-py3.10/lib/python3.10/site-packages/ai21/modules/resources/execution_utils.py:11\u001b[0m, in \u001b[0;36mexecute_studio_request\u001b[0;34m(task_url, params, method)\u001b[0m\n\u001b[1;32m      9\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mexecute_studio_request\u001b[39m(task_url: \u001b[39mstr\u001b[39m, params, method: \u001b[39mstr\u001b[39m \u001b[39m=\u001b[39m \u001b[39m'\u001b[39m\u001b[39mPOST\u001b[39m\u001b[39m'\u001b[39m):\n\u001b[1;32m     10\u001b[0m     client \u001b[39m=\u001b[39m AI21StudioClient(\u001b[39m*\u001b[39m\u001b[39m*\u001b[39mparams)\n\u001b[0;32m---> 11\u001b[0m     \u001b[39mreturn\u001b[39;00m client\u001b[39m.\u001b[39;49mexecute_http_request(method\u001b[39m=\u001b[39;49mmethod, url\u001b[39m=\u001b[39;49mtask_url, params\u001b[39m=\u001b[39;49mparams)\n",
+      "File \u001b[0;32m~/.cache/pypoetry/virtualenvs/llama-index-2x1vjWb5-py3.10/lib/python3.10/site-packages/ai21/ai21_studio_client.py:52\u001b[0m, in \u001b[0;36mAI21StudioClient.execute_http_request\u001b[0;34m(self, method, url, params, files)\u001b[0m\n\u001b[1;32m     51\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mexecute_http_request\u001b[39m(\u001b[39mself\u001b[39m, method: \u001b[39mstr\u001b[39m, url: \u001b[39mstr\u001b[39m, params: Optional[Dict] \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m, files\u001b[39m=\u001b[39m\u001b[39mNone\u001b[39;00m):\n\u001b[0;32m---> 52\u001b[0m     response \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mhttp_client\u001b[39m.\u001b[39;49mexecute_http_request(method\u001b[39m=\u001b[39;49mmethod, url\u001b[39m=\u001b[39;49murl, params\u001b[39m=\u001b[39;49mparams, files\u001b[39m=\u001b[39;49mfiles)\n\u001b[1;32m     53\u001b[0m     \u001b[39mreturn\u001b[39;00m convert_to_ai21_object(response)\n",
+      "File \u001b[0;32m~/.cache/pypoetry/virtualenvs/llama-index-2x1vjWb5-py3.10/lib/python3.10/site-packages/ai21/http_client.py:84\u001b[0m, in \u001b[0;36mHttpClient.execute_http_request\u001b[0;34m(self, method, url, params, files, auth)\u001b[0m\n\u001b[1;32m     82\u001b[0m \u001b[39mif\u001b[39;00m response\u001b[39m.\u001b[39mstatus_code \u001b[39m!=\u001b[39m \u001b[39m200\u001b[39m:\n\u001b[1;32m     83\u001b[0m     log_error(\u001b[39mf\u001b[39m\u001b[39m'\u001b[39m\u001b[39mCalling \u001b[39m\u001b[39m{\u001b[39;00mmethod\u001b[39m}\u001b[39;00m\u001b[39m \u001b[39m\u001b[39m{\u001b[39;00murl\u001b[39m}\u001b[39;00m\u001b[39m failed with a non-200 response code: \u001b[39m\u001b[39m{\u001b[39;00mresponse\u001b[39m.\u001b[39mstatus_code\u001b[39m}\u001b[39;00m\u001b[39m'\u001b[39m)\n\u001b[0;32m---> 84\u001b[0m     handle_non_success_response(response\u001b[39m.\u001b[39;49mstatus_code, response\u001b[39m.\u001b[39;49mtext)\n\u001b[1;32m     86\u001b[0m \u001b[39mreturn\u001b[39;00m response\u001b[39m.\u001b[39mjson()\n",
+      "File \u001b[0;32m~/.cache/pypoetry/virtualenvs/llama-index-2x1vjWb5-py3.10/lib/python3.10/site-packages/ai21/http_client.py:23\u001b[0m, in \u001b[0;36mhandle_non_success_response\u001b[0;34m(status_code, response_text)\u001b[0m\n\u001b[1;32m     21\u001b[0m     \u001b[39mraise\u001b[39;00m BadRequest(details\u001b[39m=\u001b[39mresponse_text)\n\u001b[1;32m     22\u001b[0m \u001b[39mif\u001b[39;00m status_code \u001b[39m==\u001b[39m \u001b[39m401\u001b[39m:\n\u001b[0;32m---> 23\u001b[0m     \u001b[39mraise\u001b[39;00m Unauthorized(details\u001b[39m=\u001b[39mresponse_text)\n\u001b[1;32m     24\u001b[0m \u001b[39mif\u001b[39;00m status_code \u001b[39m==\u001b[39m \u001b[39m422\u001b[39m:\n\u001b[1;32m     25\u001b[0m     \u001b[39mraise\u001b[39;00m UnprocessableEntity(details\u001b[39m=\u001b[39mresponse_text)\n",
+      "\u001b[0;31mUnauthorized\u001b[0m: Failed with http status code: 401 (Unauthorized). Details: {\"detail\":\"Forbidden: Bad or missing API token.\"}"
+     ]
+    }
+   ],
+   "source": [
+    "from llama_index.llms import AI21\n",
+    "\n",
+    "llm_good = AI21(api_key=api_key)\n",
+    "llm_bad = AI21(model=\"j2-mid\", api_key=\"BAD_KEY\")\n",
+    "\n",
+    "resp = llm_good.complete(\"Paul Graham is \")\n",
+    "print(resp)\n",
+    "\n",
+    "resp = llm_bad.complete(\"Paul Graham is \")\n",
+    "print(resp)"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "llama-index-2x1vjWb5-py3.10",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/llama_index/llms/__init__.py b/llama_index/llms/__init__.py
index e9c27380df..c7b675b8d5 100644
--- a/llama_index/llms/__init__.py
+++ b/llama_index/llms/__init__.py
@@ -1,3 +1,4 @@
+from llama_index.llms.ai21 import AI21
 from llama_index.llms.anthropic import Anthropic
 from llama_index.llms.anyscale import Anyscale
 from llama_index.llms.azure_openai import AzureOpenAI
@@ -36,6 +37,7 @@ from llama_index.llms.replicate import Replicate
 from llama_index.llms.xinference import Xinference
 
 __all__ = [
+    "AI21",
     "Anthropic",
     "Anyscale",
     "AzureOpenAI",
diff --git a/llama_index/llms/ai21.py b/llama_index/llms/ai21.py
new file mode 100644
index 0000000000..bc77c164c6
--- /dev/null
+++ b/llama_index/llms/ai21.py
@@ -0,0 +1,131 @@
+from typing import Any, Dict, Optional, Sequence
+
+from llama_index.bridge.pydantic import Field, PrivateAttr
+from llama_index.callbacks import CallbackManager
+from llama_index.llms.ai21_utils import ai21_model_to_context_size
+from llama_index.llms.base import (
+    ChatMessage,
+    ChatResponse,
+    ChatResponseGen,
+    CompletionResponse,
+    CompletionResponseGen,
+    LLMMetadata,
+    llm_chat_callback,
+    llm_completion_callback,
+)
+from llama_index.llms.custom import CustomLLM
+from llama_index.llms.generic_utils import (
+    completion_to_chat_decorator,
+    get_from_param_or_env,
+)
+
+
+class AI21(CustomLLM):
+    """AI21 Labs LLM."""
+
+    model: str = Field(description="The AI21 model to use.")
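+    # maxTokens is camelCase on purpose: it mirrors the AI21 REST parameter name.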
+    maxTokens: int = Field(description="The maximum number of tokens to generate.")
+    temperature: float = Field(description="The temperature to use for sampling.")
+
+    additional_kwargs: Dict[str, Any] = Field(
+        default_factory=dict, description="Additional kwargs for the AI21 API."
+    )
+
+    _api_key = PrivateAttr()
+
+    def __init__(
+        self,
+        api_key: Optional[str] = None,
+        model: str = "j2-mid",
+        maxTokens: int = 512,
+        temperature: float = 0.1,
+        additional_kwargs: Optional[Dict[str, Any]] = None,
+        callback_manager: Optional[CallbackManager] = None,
+    ) -> None:
+        """Initialize params."""
+        try:
+            import ai21 as _
+        except ImportError as e:
+            raise ImportError(
+                "You must install the `ai21` package to use AI21."
+                "Please `pip install ai21`"
+            ) from e
+
+        additional_kwargs = additional_kwargs or {}
+        callback_manager = callback_manager or CallbackManager([])
+
+        api_key = get_from_param_or_env("api_key", api_key, "AI21_API_KEY")
+        self._api_key = api_key
+
+        super().__init__(
+            model=model,
+            maxTokens=maxTokens,
+            temperature=temperature,
+            additional_kwargs=additional_kwargs,
+            callback_manager=callback_manager,
+        )
+
+    @classmethod
+    def class_name(cls) -> str:
+        """Get class name."""
+        return "AI21_LLM"
+
+    @property
+    def metadata(self) -> LLMMetadata:
+        return LLMMetadata(
+            context_window=ai21_model_to_context_size(self.model),
+            num_output=self.maxTokens,
+            model_name=self.model,
+        )
+
+    @property
+    def _model_kwargs(self) -> Dict[str, Any]:
+        base_kwargs = {
+            "model": self.model,
+            "maxTokens": self.maxTokens,
+            "temperature": self.temperature,
+        }
+        return {**base_kwargs, **self.additional_kwargs}
+
+    def _get_all_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
+        return {
+            **self._model_kwargs,
+            **kwargs,
+        }
+
+    @llm_completion_callback()
+    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
+        all_kwargs = self._get_all_kwargs(**kwargs)
+
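+        # Import here (not at module level) so ai21 stays an optional dependency.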
+        import ai21
+
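+        # The ai21 SDK reads the API key from module-level state, so set it per call.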
+        ai21.api_key = self._api_key
+
+        response = ai21.Completion.execute(**all_kwargs, prompt=prompt)
+
+        return CompletionResponse(
+            text=response["completions"][0]["data"]["text"], raw=response.__dict__
+        )
+
+    @llm_completion_callback()
+    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
+        raise NotImplementedError(
+            "AI21 does not currently support streaming completion."
+        )
+
+    @llm_chat_callback()
+    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
+        all_kwargs = self._get_all_kwargs(**kwargs)
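+        # Emulate chat by flattening the message history into a single completion prompt.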
+        chat_fn = completion_to_chat_decorator(self.complete)
+
+        return chat_fn(messages, **all_kwargs)
+
+    @llm_chat_callback()
+    def stream_chat(
+        self, messages: Sequence[ChatMessage], **kwargs: Any
+    ) -> ChatResponseGen:
+        raise NotImplementedError("AI21 does not Currently Support Streaming Chat.")
diff --git a/llama_index/llms/ai21_utils.py b/llama_index/llms/ai21_utils.py
new file mode 100644
index 0000000000..0f03a856c6
--- /dev/null
+++ b/llama_index/llms/ai21_utils.py
@@ -0,0 +1,20 @@
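+# Jurassic-2 completion models, mapped to context window size in tokens.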
+COMPLETE_MODELS = {"j2-light": 8191, "j2-mid": 8191, "j2-ultra": 8191}
+
+
+def ai21_model_to_context_size(model: str) -> int:
+    """Look up the context window size for a given AI21 model.
+
+    Args:
+        model: The model name to look up the context size for.
+
+    Returns:
+        The maximum context size, in tokens.
+
+    """
+    token_limit = COMPLETE_MODELS.get(model, None)
+
+    if token_limit is None:
+        raise ValueError(f"Model name {model} not found in {COMPLETE_MODELS.keys()}")
+
+    return token_limit
diff --git a/tests/llms/test_ai21.py b/tests/llms/test_ai21.py
new file mode 100644
index 0000000000..c220bfd026
--- /dev/null
+++ b/tests/llms/test_ai21.py
@@ -0,0 +1,338 @@
+from typing import TYPE_CHECKING, Any, Union
+
+import pytest
+from llama_index.llms import ChatMessage
+from pytest import MonkeyPatch
+
+if TYPE_CHECKING:
+    from ai21.ai21_object import AI21Object
+
+try:
+    import ai21
+    from ai21.ai21_object import construct_ai21_object
+except ImportError:
+    ai21 = None  # type: ignore
+
+
+from llama_index.llms.ai21 import AI21
+
+
+def mock_completion(*args: Any, **kwargs: Any) -> Union[Any, "AI21Object"]:
+    return construct_ai21_object(
+        {
+            "id": "f6adacef-0e94-6353-244f-df8d38954b19",
+            "prompt": {
+                "text": "This is just a test",
+                "tokens": [
+                    {
+                        "generatedToken": {
+                            "token": "▁This▁is▁just",
+                            "logprob": -13.657383918762207,
+                            "raw_logprob": -13.657383918762207,
+                        },
+                        "topTokens": None,
+                        "textRange": {"start": 0, "end": 12},
+                    },
+                    {
+                        "generatedToken": {
+                            "token": "▁a▁test",
+                            "logprob": -4.080351829528809,
+                            "raw_logprob": -4.080351829528809,
+                        },
+                        "topTokens": None,
+                        "textRange": {"start": 12, "end": 19},
+                    },
+                ],
+            },
+            "completions": [
+                {
+                    "data": {
+                        "text": "\nThis is a test to see if my text is showing up correctly.",
+                        "tokens": [
+                            {
+                                "generatedToken": {
+                                    "token": "<|newline|>",
+                                    "logprob": 0,
+                                    "raw_logprob": -0.01992332935333252,
+                                },
+                                "topTokens": None,
+                                "textRange": {"start": 0, "end": 1},
+                            },
+                            {
+                                "generatedToken": {
+                                    "token": "▁This▁is▁a",
+                                    "logprob": -0.00014733182615600526,
+                                    "raw_logprob": -1.228371500968933,
+                                },
+                                "topTokens": None,
+                                "textRange": {"start": 1, "end": 10},
+                            },
+                            {
+                                "generatedToken": {
+                                    "token": "▁test",
+                                    "logprob": 0,
+                                    "raw_logprob": -0.0422857291996479,
+                                },
+                                "topTokens": None,
+                                "textRange": {"start": 10, "end": 15},
+                            },
+                            {
+                                "generatedToken": {
+                                    "token": "▁to▁see▁if",
+                                    "logprob": -0.4861462712287903,
+                                    "raw_logprob": -1.2263909578323364,
+                                },
+                                "topTokens": None,
+                                "textRange": {"start": 15, "end": 25},
+                            },
+                            {
+                                "generatedToken": {
+                                    "token": "▁my",
+                                    "logprob": -9.536738616588991e-7,
+                                    "raw_logprob": -0.8164164423942566,
+                                },
+                                "topTokens": None,
+                                "textRange": {"start": 25, "end": 28},
+                            },
+                            {
+                                "generatedToken": {
+                                    "token": "▁text",
+                                    "logprob": -0.003087161108851433,
+                                    "raw_logprob": -1.7130306959152222,
+                                },
+                                "topTokens": None,
+                                "textRange": {"start": 28, "end": 33},
+                            },
+                            {
+                                "generatedToken": {
+                                    "token": "▁is",
+                                    "logprob": -1.8836627006530762,
+                                    "raw_logprob": -0.9880049824714661,
+                                },
+                                "topTokens": None,
+                                "textRange": {"start": 33, "end": 36},
+                            },
+                            {
+                                "generatedToken": {
+                                    "token": "▁showing▁up",
+                                    "logprob": -0.00006341733387671411,
+                                    "raw_logprob": -0.954255223274231,
+                                },
+                                "topTokens": None,
+                                "textRange": {"start": 36, "end": 47},
+                            },
+                            {
+                                "generatedToken": {
+                                    "token": "▁correctly",
+                                    "logprob": -0.00022098960471339524,
+                                    "raw_logprob": -0.6004139184951782,
+                                },
+                                "topTokens": None,
+                                "textRange": {"start": 47, "end": 57},
+                            },
+                            {
+                                "generatedToken": {
+                                    "token": ".",
+                                    "logprob": 0,
+                                    "raw_logprob": -0.039214372634887695,
+                                },
+                                "topTokens": None,
+                                "textRange": {"start": 57, "end": 58},
+                            },
+                            {
+                                "generatedToken": {
+                                    "token": "<|endoftext|>",
+                                    "logprob": 0,
+                                    "raw_logprob": -0.22456447780132294,
+                                },
+                                "topTokens": None,
+                                "textRange": {"start": 58, "end": 58},
+                            },
+                        ],
+                    },
+                    "finishReason": {"reason": "endoftext"},
+                }
+            ],
+        }
+    )
+
+
+def mock_chat(*args: Any, **kwargs: Any) -> Union[Any, "AI21Object"]:
+    return construct_ai21_object(
+        {
+            "id": "f8d0cd0a-7c85-deb2-16b3-491c7ffdd4f2",
+            "prompt": {
+                "text": "user: This is just a test assistant:",
+                "tokens": [
+                    {
+                        "generatedToken": {
+                            "token": "▁user",
+                            "logprob": -13.633946418762207,
+                            "raw_logprob": -13.633946418762207,
+                        },
+                        "topTokens": None,
+                        "textRange": {"start": 0, "end": 4},
+                    },
+                    {
+                        "generatedToken": {
+                            "token": ":",
+                            "logprob": -5.545032978057861,
+                            "raw_logprob": -5.545032978057861,
+                        },
+                        "topTokens": None,
+                        "textRange": {"start": 4, "end": 5},
+                    },
+                    {
+                        "generatedToken": {
+                            "token": "▁This▁is▁just",
+                            "logprob": -10.848762512207031,
+                            "raw_logprob": -10.848762512207031,
+                        },
+                        "topTokens": None,
+                        "textRange": {"start": 5, "end": 18},
+                    },
+                    {
+                        "generatedToken": {
+                            "token": "▁a▁test",
+                            "logprob": -2.0551252365112305,
+                            "raw_logprob": -2.0551252365112305,
+                        },
+                        "topTokens": None,
+                        "textRange": {"start": 18, "end": 25},
+                    },
+                    {
+                        "generatedToken": {
+                            "token": "▁assistant",
+                            "logprob": -17.020610809326172,
+                            "raw_logprob": -17.020610809326172,
+                        },
+                        "topTokens": None,
+                        "textRange": {"start": 25, "end": 35},
+                    },
+                    {
+                        "generatedToken": {
+                            "token": ":",
+                            "logprob": -12.311965942382812,
+                            "raw_logprob": -12.311965942382812,
+                        },
+                        "topTokens": None,
+                        "textRange": {"start": 35, "end": 36},
+                    },
+                ],
+            },
+            "completions": [
+                {
+                    "data": {
+                        "text": "\nassistant:\nHow can I assist you today?",
+                        "tokens": [
+                            {
+                                "generatedToken": {
+                                    "token": "<|newline|>",
+                                    "logprob": 0,
+                                    "raw_logprob": -0.02031332440674305,
+                                },
+                                "topTokens": None,
+                                "textRange": {"start": 0, "end": 1},
+                            },
+                            {
+                                "generatedToken": {
+                                    "token": "▁assistant",
+                                    "logprob": 0,
+                                    "raw_logprob": -0.24520651996135712,
+                                },
+                                "topTokens": None,
+                                "textRange": {"start": 1, "end": 10},
+                            },
+                            {
+                                "generatedToken": {
+                                    "token": ":",
+                                    "logprob": 0,
+                                    "raw_logprob": -0.0026112052146345377,
+                                },
+                                "topTokens": None,
+                                "textRange": {"start": 10, "end": 11},
+                            },
+                            {
+                                "generatedToken": {
+                                    "token": "<|newline|>",
+                                    "logprob": 0,
+                                    "raw_logprob": -0.3382393717765808,
+                                },
+                                "topTokens": None,
+                                "textRange": {"start": 11, "end": 12},
+                            },
+                            {
+                                "generatedToken": {
+                                    "token": "▁How▁can▁I",
+                                    "logprob": -0.000008106198947643861,
+                                    "raw_logprob": -1.3073582649230957,
+                                },
+                                "topTokens": None,
+                                "textRange": {"start": 12, "end": 21},
+                            },
+                            {
+                                "generatedToken": {
+                                    "token": "▁assist▁you",
+                                    "logprob": -2.15450382232666,
+                                    "raw_logprob": -0.8163930177688599,
+                                },
+                                "topTokens": None,
+                                "textRange": {"start": 21, "end": 32},
+                            },
+                            {
+                                "generatedToken": {
+                                    "token": "▁today",
+                                    "logprob": 0,
+                                    "raw_logprob": -0.1474292278289795,
+                                },
+                                "topTokens": None,
+                                "textRange": {"start": 32, "end": 38},
+                            },
+                            {
+                                "generatedToken": {
+                                    "token": "?",
+                                    "logprob": 0,
+                                    "raw_logprob": -0.011986607685685158,
+                                },
+                                "topTokens": None,
+                                "textRange": {"start": 38, "end": 39},
+                            },
+                            {
+                                "generatedToken": {
+                                    "token": "<|endoftext|>",
+                                    "logprob": -1.1920928244535389e-7,
+                                    "raw_logprob": -0.2295214682817459,
+                                },
+                                "topTokens": None,
+                                "textRange": {"start": 39, "end": 39},
+                            },
+                        ],
+                    },
+                    "finishReason": {"reason": "endoftext"},
+                }
+            ],
+        }
+    )
+
+
+@pytest.mark.skipif(ai21 is None, reason="ai21 not installed")
+def test_completion_model_basic(monkeypatch: MonkeyPatch) -> None:
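+    # Patch the SDK entry point so the test runs without network access or a real key.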
+    monkeypatch.setattr("ai21.Completion.execute", mock_completion)
+
+    mock_api_key = "fake_key"
+    llm = AI21(model="j2-mid", api_key=mock_api_key)
+
+    test_prompt = "This is just a test"
+    response = llm.complete(test_prompt)
+    assert (
+        response.text == "\nThis is a test to see if my text is showing up correctly."
+    )
+
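+    # chat() delegates to complete(), so patching Completion.execute covers chat too.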
+    monkeypatch.setattr("ai21.Completion.execute", mock_chat)
+
+    message = ChatMessage(role="user", content=test_prompt)
+    chat_response = llm.chat([message])
+    print(chat_response.message.content)
+    assert chat_response.message.content == "\nassistant:\nHow can I assist you today?"
-- 
GitLab