diff --git a/docs/examples/llm/maritalk.ipynb b/docs/examples/llm/maritalk.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..b5d18c895955efbe21a36b0a47ca06a6764d2575
--- /dev/null
+++ b/docs/examples/llm/maritalk.ipynb
@@ -0,0 +1,122 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "<a href=\"https://colab.research.google.com/github/run-llama/llama_index/blob/main/docs/examples/llm/maritalk.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>\n",
+    "\n",
+    "# Maritalk\n",
+    "\n",
+    "## Introduction\n",
+    "\n",
+    "MariTalk is an assistant developed by the Brazilian company [Maritaca AI](https://www.maritaca.ai).\n",
+    "MariTalk is based on language models that have been specially trained to understand Portuguese well.\n",
+    "\n",
+    "This notebook demonstrates how to use MariTalk with LlamaIndex through a simple example."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Installation\n",
+    "If you're opening this notebook on Colab, you will probably need to install LlamaIndex."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%pip install llama-index-llms-maritalk\n",
+    "!pip install llama-index"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## API Key\n",
+    "You will need an API key, which can be obtained from chat.maritaca.ai (\"Chaves da API\" section)."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Usage\n",
+    "\n",
+    "### Chat"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from llama_index.core.llms import ChatMessage\n",
+    "from llama_index.llms.maritalk import Maritalk\n",
+    "\n",
+    "# To customize your API key, set it here;\n",
+    "# otherwise it will be looked up from the MARITALK_API_KEY environment variable.\n",
+    "# llm = Maritalk(api_key=\"<your_maritalk_api_key>\")\n",
+    "\n",
+    "llm = Maritalk()\n",
+    "\n",
+    "# Call chat with a list of messages\n",
+    "messages = [\n",
+    "    ChatMessage(\n",
+    "        role=\"system\",\n",
+    "        content=\"You are an assistant specialized in suggesting pet names. Given the animal, you must suggest 4 names.\",\n",
+    "    ),\n",
+    "    ChatMessage(role=\"user\", content=\"I have a dog.\"),\n",
+    "]\n",
+    "\n",
+    "response = llm.chat(messages)\n",
+    "print(response)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Few-shot examples\n",
+    "\n",
+    "We recommend using the `llm.complete()` method when using the model with few-shot examples."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "prompt = \"\"\"Classifique a resenha de filme como \"positiva\" ou \"negativa\".\n",
+    "\n",
+    "Resenha: Gostei muito do filme, é o melhor do ano!\n",
+    "Classe: positiva\n",
+    "\n",
+    "Resenha: O filme deixa muito a desejar.\n",
+    "Classe: negativa\n",
+    "\n",
+    "Resenha: Apesar de longo, valeu o ingresso.\n",
+    "Classe:\"\"\"\n",
+    "\n",
+    "response = llm.complete(prompt, stopping_tokens=[\"\\n\"])\n",
+    "print(response)"
+   ]
+  }
+ ],
+ "metadata": {
+  "language_info": {
+   "name": "python"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
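Beyond the notebook's synchronous examples, the `Maritalk` class defined later in this diff also exposes async entry points (`achat`/`acomplete`). A minimal sketch of calling them, assuming `MARITALK_API_KEY` is set in the environment and using an illustrative prompt:

```python
import asyncio

from llama_index.llms.maritalk import Maritalk

llm = Maritalk()  # picks up MARITALK_API_KEY from the environment


async def main() -> None:
    # achat/acomplete currently delegate to the synchronous implementations,
    # so the HTTP request blocks the event loop while it runs.
    response = await llm.acomplete("Quem escreveu Dom Casmurro?")
    print(response.text)


asyncio.run(main())
```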
diff --git a/docs/module_guides/models/llms/modules.md b/docs/module_guides/models/llms/modules.md
index 7221308b59f5c80e1f10417f65d9020ff1bd9aaf..71b6f3fc8cf1e125163f9515562d9a2a2ef5865a 100644
--- a/docs/module_guides/models/llms/modules.md
+++ b/docs/module_guides/models/llms/modules.md
@@ -150,6 +150,15 @@ maxdepth: 1
 /examples/llm/localai.ipynb
 ```
 
+## MariTalk
+
+```{toctree}
+---
+maxdepth: 1
+---
+/examples/llm/maritalk.ipynb
+```
+
 ## MistralAI
 
 ```{toctree}
diff --git a/llama-index-integrations/llms/llama-index-llms-maritalk/.gitignore b/llama-index-integrations/llms/llama-index-llms-maritalk/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..990c18de229088f55c6c514fd0f2d49981d1b0e7
--- /dev/null
+++ b/llama-index-integrations/llms/llama-index-llms-maritalk/.gitignore
@@ -0,0 +1,153 @@
+llama_index/_static
+.DS_Store
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+bin/
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+etc/
+include/
+lib/
+lib64/
+parts/
+sdist/
+share/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+.ruff_cache
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+notebooks/
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+pyvenv.cfg
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# Jetbrains
+.idea
+modules/
+*.swp
+
+# VsCode
+.vscode
+
+# pipenv
+Pipfile
+Pipfile.lock
+
+# pyright
+pyrightconfig.json
diff --git a/llama-index-integrations/llms/llama-index-llms-maritalk/BUILD b/llama-index-integrations/llms/llama-index-llms-maritalk/BUILD
new file mode 100644
index 0000000000000000000000000000000000000000..0896ca890d8bffd60a44fa824f8d57fecd73ee53
--- /dev/null
+++ b/llama-index-integrations/llms/llama-index-llms-maritalk/BUILD
@@ -0,0 +1,3 @@
+poetry_requirements(
+    name="poetry",
+)
diff --git a/llama-index-integrations/llms/llama-index-llms-maritalk/Makefile b/llama-index-integrations/llms/llama-index-llms-maritalk/Makefile
new file mode 100644
index 0000000000000000000000000000000000000000..b9eab05aa370629a4a3de75df3ff64cd53887b68
--- /dev/null
+++ b/llama-index-integrations/llms/llama-index-llms-maritalk/Makefile
@@ -0,0 +1,17 @@
+GIT_ROOT ?= $(shell git rev-parse --show-toplevel)
+
+help:	## Show all Makefile targets.
+	@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[33m%-30s\033[0m %s\n", $$1, $$2}'
+
+format:	## Run code autoformatters (black).
+	pre-commit install
+	git ls-files | xargs pre-commit run black --files
+
+lint:	## Run linters: pre-commit (black, ruff, codespell) and mypy.
+	pre-commit install && git ls-files | xargs pre-commit run --show-diff-on-failure --files
+
+test:	## Run tests via pytest.
+	pytest tests
+
+watch-docs:	## Build and watch documentation.
+	sphinx-autobuild docs/ docs/_build/html --open-browser --watch $(GIT_ROOT)/llama_index/
diff --git a/llama-index-integrations/llms/llama-index-llms-maritalk/README.md b/llama-index-integrations/llms/llama-index-llms-maritalk/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..bf3413f76d27ba4b00e98c5845a8bfd4f4f44bef
--- /dev/null
+++ b/llama-index-integrations/llms/llama-index-llms-maritalk/README.md
@@ -0,0 +1 @@
+# LlamaIndex Llms Integration: Maritalk
diff --git a/llama-index-integrations/llms/llama-index-llms-maritalk/llama_index/llms/maritalk/BUILD b/llama-index-integrations/llms/llama-index-llms-maritalk/llama_index/llms/maritalk/BUILD
new file mode 100644
index 0000000000000000000000000000000000000000..db46e8d6c978c67e301dd6c47bee08c1b3fd141c
--- /dev/null
+++ b/llama-index-integrations/llms/llama-index-llms-maritalk/llama_index/llms/maritalk/BUILD
@@ -0,0 +1 @@
+python_sources()
diff --git a/llama-index-integrations/llms/llama-index-llms-maritalk/llama_index/llms/maritalk/__init__.py b/llama-index-integrations/llms/llama-index-llms-maritalk/llama_index/llms/maritalk/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f690c9aac6c4ac689786927708425782729d8ec0
--- /dev/null
+++ b/llama-index-integrations/llms/llama-index-llms-maritalk/llama_index/llms/maritalk/__init__.py
@@ -0,0 +1,4 @@
+from llama_index.llms.maritalk.base import Maritalk
+
+
+__all__ = ["Maritalk"]
diff --git a/llama-index-integrations/llms/llama-index-llms-maritalk/llama_index/llms/maritalk/base.py b/llama-index-integrations/llms/llama-index-llms-maritalk/llama_index/llms/maritalk/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c48794561c6bce6a4aad8642c2f5f92c787dfe4
--- /dev/null
+++ b/llama-index-integrations/llms/llama-index-llms-maritalk/llama_index/llms/maritalk/base.py
@@ -0,0 +1,195 @@
+from typing import Any, Optional, Sequence
+from llama_index.core.base.llms.types import (
+    ChatMessage,
+    ChatResponse,
+    ChatResponseAsyncGen,
+    ChatResponseGen,
+    CompletionResponse,
+    CompletionResponseAsyncGen,
+    CompletionResponseGen,
+    LLMMetadata,
+    MessageRole,
+)
+from llama_index.core.bridge.pydantic import Field, PrivateAttr
+from llama_index.core.llms.llm import LLM
+from llama_index.core.llms.callbacks import (
+    llm_chat_callback,
+    llm_completion_callback,
+)
+import requests
+import os
+
+
+class Maritalk(LLM):
+    api_key: Optional[str] = Field(default=None, description="Your MariTalk API key.")
+    temperature: float = Field(
+        default=0.7,
+        ge=0.0,
+        le=1.0,
+        description="Run inference with this temperature. Must be in the"
+        " closed interval [0.0, 1.0].",
+    )
+    max_tokens: int = Field(
+        default=512,
+        gt=0,
+        description="The maximum number of tokens to"
+        " generate in the reply.",
+    )
+    do_sample: bool = Field(
+        default=True,
+        description="Whether or not to use sampling; use `True` to enable.",
+    )
+    top_p: float = Field(
+        default=0.95,
+        gt=0.0,
+        lt=1.0,
+        description="Nucleus sampling parameter controlling the size of"
+        " the probability mass considered for sampling.",
+    )
+    system_message_workaround: bool = Field(
+        default=True,
+        description="Whether to include a workaround for system"
+        " message by adding it as a user message.",
+    )
+
+    _endpoint: str = PrivateAttr("https://chat.maritaca.ai/api/chat/inference")
+
+    def __init__(self, **kwargs: Any) -> None:
+        super().__init__(**kwargs)
+        # If an API key is not provided during instantiation,
+        # fall back to the MARITALK_API_KEY environment variable.
+        self.api_key = self.api_key or os.getenv("MARITALK_API_KEY")
+        if not self.api_key:
+            raise ValueError(
+                "An API key must be provided or set in the "
+                "'MARITALK_API_KEY' environment variable."
+            )
+
+    @classmethod
+    def class_name(cls) -> str:
+        return "Maritalk"
+
+    @property
+    def metadata(self) -> LLMMetadata:
+        # max_tokens is reported as a conservative stand-in for the context window.
+        return LLMMetadata(
+            model_name="maritalk",
+            context_window=self.max_tokens,
+            is_chat_model=True,
+        )
+
+    @llm_chat_callback()
+    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
+        # Prepare the message payload for the Maritalk API. The endpoint only
+        # accepts "user"/"assistant" roles, so when `system_message_workaround`
+        # is enabled, system messages are forwarded as user messages.
+        formatted_messages = []
+        for msg in messages:
+            role = "user" if msg.role == MessageRole.USER else "assistant"
+            if msg.role == MessageRole.SYSTEM:
+                if not self.system_message_workaround:
+                    continue
+                role = "user"
+            formatted_messages.append({"role": role, "content": msg.content})
+
+        data = {
+            "messages": formatted_messages,
+            "do_sample": self.do_sample,
+            "max_tokens": self.max_tokens,
+            "temperature": self.temperature,
+            "top_p": self.top_p,
+        }
+
+        # Update data payload with additional kwargs if any
+        data.update(kwargs)
+
+        headers = {"authorization": f"Key {self.api_key}"}
+
+        response = requests.post(self._endpoint, json=data, headers=headers)
+        if response.status_code == 429:
+            return ChatResponse(
+                message=ChatMessage(
+                    role=MessageRole.SYSTEM,
+                    content="Rate limited, please try again soon",
+                ),
+                raw=response.text,
+            )
+        elif response.ok:
+            answer = response.json()["answer"]
+            return ChatResponse(
+                message=ChatMessage(role=MessageRole.ASSISTANT, content=answer),
+                raw=response.json(),
+            )
+        else:
+            response.raise_for_status()  # noqa: RET503
+
+    @llm_completion_callback()
+    def complete(
+        self, prompt: str, formatted: bool = False, **kwargs: Any
+    ) -> CompletionResponse:
+        # Prepare the data payload for the Maritalk API
+        data = {
+            "messages": prompt,
+            "do_sample": self.do_sample,
+            "max_tokens": self.max_tokens,
+            "temperature": self.temperature,
+            "top_p": self.top_p,
+            "chat_mode": False,
+        }
+
+        # Update data payload with additional kwargs if any
+        data.update(kwargs)
+
+        headers = {"authorization": f"Key {self.api_key}"}
+
+        response = requests.post(self._endpoint, json=data, headers=headers)
+        if response.status_code == 429:
+            return CompletionResponse(
+                text="Rate limited, please try again soon",
+                raw=response.text,
+            )
+        elif response.ok:
+            answer = response.json()["answer"]
+            return CompletionResponse(
+                text=answer,
+                raw=response.json(),
+            )
+        else:
+            response.raise_for_status()  # noqa: RET503
+
+    def stream_chat(
+        self, messages: Sequence[ChatMessage], **kwargs: Any
+    ) -> ChatResponseGen:
+        raise NotImplementedError(
+            "Maritalk does not currently support streaming chat."
+        )
+
+    def stream_complete(
+        self, prompt: str, formatted: bool = False, **kwargs: Any
+    ) -> CompletionResponseGen:
+        raise NotImplementedError(
+            "Maritalk does not currently support streaming completion."
+        )
+
+    @llm_chat_callback()
+    async def achat(
+        self, messages: Sequence[ChatMessage], **kwargs: Any
+    ) -> ChatResponse:
+        # No native async client yet; delegate to the synchronous implementation.
+        return self.chat(messages, **kwargs)
+
+    @llm_completion_callback()
+    async def acomplete(
+        self, prompt: str, formatted: bool = False, **kwargs: Any
+    ) -> CompletionResponse:
+        # No native async client yet; delegate to the synchronous implementation.
+        return self.complete(prompt, formatted, **kwargs)
+
+    @llm_chat_callback()
+    async def astream_chat(
+        self, messages: Sequence[ChatMessage], **kwargs: Any
+    ) -> ChatResponseAsyncGen:
+        raise NotImplementedError(
+            "Maritalk does not currently support streaming chat."
+        )
+
+    @llm_completion_callback()
+    async def astream_complete(
+        self, prompt: str, formatted: bool = False, **kwargs: Any
+    ) -> CompletionResponseAsyncGen:
+        raise NotImplementedError(
+            "Maritalk does not currently support streaming completion."
+        )
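Both `chat` and `complete` fold any extra keyword arguments into the request payload via `data.update(kwargs)`; that is the path by which the notebook's `stopping_tokens` argument reaches the endpoint. Note also that HTTP 429 is mapped to an ordinary response object rather than an exception. A minimal sketch, assuming a valid `MARITALK_API_KEY` and an illustrative prompt:

```python
from llama_index.llms.maritalk import Maritalk

llm = Maritalk()

# Extra kwargs such as stopping_tokens are merged into the JSON payload
# by data.update(kwargs) and forwarded verbatim to the API.
response = llm.complete("A capital do Brasil é", stopping_tokens=["\n"])

# A rate-limited call (HTTP 429) returns a normal CompletionResponse whose
# text is a "Rate limited" notice instead of raising an exception.
print(response.text)
```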
diff --git a/llama-index-integrations/llms/llama-index-llms-maritalk/pyproject.toml b/llama-index-integrations/llms/llama-index-llms-maritalk/pyproject.toml
new file mode 100644
index 0000000000000000000000000000000000000000..6c1909dbc5827df9b0933007dc4f7267633d107f
--- /dev/null
+++ b/llama-index-integrations/llms/llama-index-llms-maritalk/pyproject.toml
@@ -0,0 +1,56 @@
+[build-system]
+build-backend = "poetry.core.masonry.api"
+requires = ["poetry-core"]
+
+[tool.codespell]
+check-filenames = true
+check-hidden = true
+# Feel free to un-skip examples, and experimental, you will just need to
+# work through many typos (--write-changes and --interactive will help)
+skip = "*.csv,*.html,*.json,*.jsonl,*.pdf,*.txt,*.ipynb"
+
+[tool.llamahub]
+contains_example = false
+import_path = "llama_index.llms.maritalk"
+
+[tool.llamahub.class_authors]
+Maritalk = "rodrigo-f-nogueira"
+
+[tool.mypy]
+disallow_untyped_defs = true
+# Remove venv skip when integrated with pre-commit
+exclude = ["_static", "build", "examples", "notebooks", "venv"]
+ignore_missing_imports = true
+python_version = "3.8"
+
+[tool.poetry]
+authors = ["Your Name <you@example.com>"]
+description = "llama-index llms maritalk integration"
+license = "MIT"
+name = "llama-index-llms-maritalk"
+packages = [{include = "llama_index/"}]
+readme = "README.md"
+version = "0.1.0"
+
+[tool.poetry.dependencies]
+python = ">=3.8.1,<3.12"
+llama-index-core = "^0.10.0"
+
+[tool.poetry.group.dev.dependencies]
+black = {extras = ["jupyter"], version = "<=23.9.1,>=23.7.0"}
+codespell = {extras = ["toml"], version = ">=v2.2.6"}
+ipython = "8.10.0"
+jupyter = "^1.0.0"
+mypy = "0.991"
+pre-commit = "3.2.0"
+pylint = "2.15.10"
+pytest = "7.2.1"
+pytest-mock = "3.11.1"
+ruff = "0.0.292"
+tree-sitter-languages = "^1.8.0"
+types-Deprecated = ">=0.1.0"
+types-PyYAML = "^6.0.12.12"
+types-protobuf = "^4.24.0.4"
+types-redis = "4.5.5.0"
+types-requests = "2.28.11.8"  # TODO: unpin when mypy>0.991
+types-setuptools = "67.1.0.0"
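Finally, because `Maritalk` is a standard LlamaIndex `LLM`, it can be installed as the global default model. A minimal sketch, assuming `llama-index-core` and this package are installed (`pip install llama-index-llms-maritalk`) and the API key is configured:

```python
from llama_index.core import Settings
from llama_index.llms.maritalk import Maritalk

# Any LlamaIndex component that resolves Settings.llm (query engines,
# chat engines, etc.) will now route its calls through MariTalk.
Settings.llm = Maritalk(temperature=0.5, max_tokens=256)

print(Settings.llm.metadata)
```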