Commit 1106a9a9 authored by Rodrigo Nogueira, committed by GitHub

add maritalk (#10925)

parent 5fb18fc7
Showing 561 additions and 0 deletions
%% Cell type:markdown id: tags:
<a href="https://colab.research.google.com/github/run-llama/llama_index/blob/main/docs/examples/llm/maritalk.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>
# Maritalk
## Introduction
MariTalk is an assistant developed by the Brazilian company [Maritaca AI](https://www.maritaca.ai).
MariTalk is based on language models that have been specially trained to understand Portuguese well.
This notebook demonstrates how to use MariTalk with llama-index through a simple example.
%% Cell type:markdown id: tags:
## Installation
If you're opening this notebook on Colab, you will probably need to install LlamaIndex.
%% Cell type:code id: tags:
```
%pip install llama-index-llms-maritalk
!pip install llama-index
```
%% Cell type:markdown id: tags:
## API Key
You will need an API key, which can be obtained from chat.maritaca.ai (under "Chaves da API", i.e. API Keys).
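%% Cell type:markdown id: tags:
One simple way to provide the key is through the environment variable the client reads (a minimal sketch; alternatively, pass `api_key=` directly when constructing `Maritalk`):
%% Cell type:code id: tags:
```
import os

# Placeholder: replace with the key copied from chat.maritaca.ai
os.environ["MARITALK_API_KEY"] = "<your_maritalk_api_key>"
```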
%% Cell type:markdown id: tags:
## Usage
### Chat
%% Cell type:code id: tags:
```
from llama_index.core.llms import ChatMessage
from llama_index.llms.maritalk import Maritalk

# To pass your API key explicitly, uncomment the line below;
# otherwise it is looked up from the MARITALK_API_KEY environment variable.
# llm = Maritalk(api_key="<your_maritalk_api_key>")
llm = Maritalk()

# Call chat with a list of messages
messages = [
    ChatMessage(
        role="system",
        content="You are an assistant specialized in suggesting pet names. Given the animal, you must suggest 4 names.",
    ),
    ChatMessage(role="user", content="I have a dog."),
]

response = llm.chat(messages)
print(response)
```
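%% Cell type:markdown id: tags:
`llm.chat()` returns a `ChatResponse` object; besides printing it whole, you can inspect its fields (a small sketch based on the response type this integration returns):
%% Cell type:code id: tags:
```
# The generated reply and the raw API response, respectively
print(response.message.content)
print(response.raw)
```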
%% Cell type:markdown id: tags:
### Few-shot examples
We recommend using the `llm.complete()` method when prompting the model with few-shot examples:
%% Cell type:code id: tags:
```
prompt = """Classifique a resenha de filme como "positiva" ou "negativa".
Resenha: Gostei muito do filme, é o melhor do ano!
Classe: positiva
Resenha: O filme deixa muito a desejar.
Classe: negativa
Resenha: Apesar de longo, valeu o ingresso..
Classe:"""
response = llm.complete(prompt, stopping_tokens=["\n"])
print(response)
```
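%% Cell type:markdown id: tags:
The integration also exposes async entry points (`achat`/`acomplete`); in the current implementation they simply delegate to the synchronous calls. A minimal sketch:
%% Cell type:code id: tags:
```
# Notebook cells support top-level await
response = await llm.acomplete(prompt, stopping_tokens=["\n"])
print(response)
```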
@@ -150,6 +150,15 @@ maxdepth: 1
/examples/llm/localai.ipynb
```
## MariTalk
```{toctree}
---
maxdepth: 1
---
/examples/llm/maritalk.ipynb
```
## MistralAI
```{toctree}
......
llama_index/_static
.DS_Store
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
bin/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
etc/
include/
lib/
lib64/
parts/
sdist/
share/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
.ruff_cache
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
notebooks/
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
pyvenv.cfg
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# Jetbrains
.idea
modules/
*.swp
# VsCode
.vscode
# pipenv
Pipfile
Pipfile.lock
# pyright
pyrightconfig.json
poetry_requirements(
    name="poetry",
)
GIT_ROOT ?= $(shell git rev-parse --show-toplevel)

help: ## Show all Makefile targets.
	@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[33m%-30s\033[0m %s\n", $$1, $$2}'

format: ## Run code autoformatters (black).
	pre-commit install
	git ls-files | xargs pre-commit run black --files

lint: ## Run linters: pre-commit (black, ruff, codespell) and mypy.
	pre-commit install && git ls-files | xargs pre-commit run --show-diff-on-failure --files

test: ## Run tests via pytest.
	pytest tests

watch-docs: ## Build and watch documentation.
	sphinx-autobuild docs/ docs/_build/html --open-browser --watch $(GIT_ROOT)/llama_index/
# LlamaIndex Llms Integration: Maritalk
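
## Installation

```
pip install llama-index-llms-maritalk
```

## Usage

A minimal sketch, mirroring the example notebook (assumes the `MARITALK_API_KEY` environment variable is set; the prompt is just an illustration):

```
from llama_index.llms.maritalk import Maritalk

# Reads MARITALK_API_KEY from the environment when api_key is not passed
llm = Maritalk()
print(llm.complete("Quem foi Santos Dumont?"))
```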
python_sources()
from llama_index.llms.maritalk.base import Maritalk
__all__ = ["Maritalk"]
from typing import Any, Optional, Sequence

from llama_index.core.base.llms.types import (
    ChatMessage,
    ChatResponse,
    ChatResponseAsyncGen,
    ChatResponseGen,
    CompletionResponse,
    CompletionResponseAsyncGen,
    CompletionResponseGen,
    LLMMetadata,
    MessageRole,
)
from llama_index.core.bridge.pydantic import Field, PrivateAttr
from llama_index.core.llms.llm import LLM
from llama_index.core.llms.callbacks import (
    llm_chat_callback,
    llm_completion_callback,
)
import requests
import os
class Maritalk(LLM):
    api_key: Optional[str] = Field(default=None, description="Your MariTalk API key.")
    temperature: float = Field(
        default=0.7,
        gt=0.0,
        lt=1.0,
        description="Run inference with this temperature. Must be in the"
        " open interval (0.0, 1.0).",
    )
    max_tokens: int = Field(
        default=512,
        gt=0,
        description="The maximum number of tokens to generate in the reply.",
    )
    do_sample: bool = Field(
        default=True,
        description="Whether or not to use sampling; use `True` to enable.",
    )
    top_p: float = Field(
        default=0.95,
        gt=0.0,
        lt=1.0,
        description="Nucleus sampling parameter controlling the size of"
        " the probability mass considered for sampling.",
    )
    system_message_workaround: bool = Field(
        default=True,
        description="Whether to include a workaround for system"
        " messages by sending them as user messages.",
    )

    _endpoint: str = PrivateAttr("https://chat.maritaca.ai/api/chat/inference")

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        # If an API key is not provided during instantiation,
        # fall back to the MARITALK_API_KEY environment variable.
        self.api_key = self.api_key or os.getenv("MARITALK_API_KEY")
        if not self.api_key:
            raise ValueError(
                "An API key must be provided or set in the "
                "'MARITALK_API_KEY' environment variable."
            )

    @classmethod
    def class_name(cls) -> str:
        return "Maritalk"

    @property
    def metadata(self) -> LLMMetadata:
        return LLMMetadata(
            model_name="maritalk",
            context_window=self.max_tokens,
            is_chat_model=True,
        )
    @llm_chat_callback()
    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        # Prepare the message payload for the MariTalk API. The API accepts
        # only "user" and "assistant" roles; when enabled, system messages
        # are passed through as user messages as a workaround.
        formatted_messages = []
        for msg in messages:
            if msg.role == MessageRole.SYSTEM and self.system_message_workaround:
                role = "user"
            else:
                role = "user" if msg.role == MessageRole.USER else "assistant"
            formatted_messages.append({"role": role, "content": msg.content})

        data = {
            "messages": formatted_messages,
            "do_sample": self.do_sample,
            "max_tokens": self.max_tokens,
            "temperature": self.temperature,
            "top_p": self.top_p,
        }

        # Merge in any additional keyword arguments (e.g. stopping_tokens)
        data.update(kwargs)

        headers = {"authorization": f"Key {self.api_key}"}
        response = requests.post(self._endpoint, json=data, headers=headers)

        if response.status_code == 429:
            return ChatResponse(
                message=ChatMessage(
                    role=MessageRole.SYSTEM,
                    content="Rate limited, please try again soon",
                ),
                raw=response.text,
            )
        elif response.ok:
            answer = response.json()["answer"]
            return ChatResponse(
                message=ChatMessage(role=MessageRole.ASSISTANT, content=answer),
                raw=response.json(),
            )
        else:
            response.raise_for_status()  # noqa: RET503
    @llm_completion_callback()
    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        # Prepare the payload for the MariTalk API; `chat_mode: False`
        # sends the raw prompt for plain text completion.
        data = {
            "messages": prompt,
            "do_sample": self.do_sample,
            "max_tokens": self.max_tokens,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "chat_mode": False,
        }

        # Merge in any additional keyword arguments (e.g. stopping_tokens)
        data.update(kwargs)

        headers = {"authorization": f"Key {self.api_key}"}
        response = requests.post(self._endpoint, json=data, headers=headers)

        if response.status_code == 429:
            return CompletionResponse(
                text="Rate limited, please try again soon",
                raw=response.text,
            )
        elif response.ok:
            answer = response.json()["answer"]
            return CompletionResponse(
                text=answer,
                raw=response.json(),
            )
        else:
            response.raise_for_status()  # noqa: RET503
    def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        raise NotImplementedError(
            "Maritalk does not currently support streaming chat."
        )

    def stream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseGen:
        raise NotImplementedError(
            "Maritalk does not currently support streaming completion."
        )

    @llm_chat_callback()
    async def achat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponse:
        # Async variant; currently delegates to the synchronous implementation.
        return self.chat(messages, **kwargs)

    @llm_completion_callback()
    async def acomplete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        # Async variant; currently delegates to the synchronous implementation.
        return self.complete(prompt, formatted, **kwargs)

    @llm_chat_callback()
    async def astream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseAsyncGen:
        raise NotImplementedError(
            "Maritalk does not currently support streaming chat."
        )

    @llm_completion_callback()
    async def astream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseAsyncGen:
        raise NotImplementedError(
            "Maritalk does not currently support streaming completion."
        )
[build-system]
build-backend = "poetry.core.masonry.api"
requires = ["poetry-core"]
[tool.codespell]
check-filenames = true
check-hidden = true
# Feel free to un-skip examples, and experimental, you will just need to
# work through many typos (--write-changes and --interactive will help)
skip = "*.csv,*.html,*.json,*.jsonl,*.pdf,*.txt,*.ipynb"
[tool.llamahub]
contains_example = false
import_path = "llama_index.llms.maritalk"
[tool.llamahub.class_authors]
Maritalk = "rodrigo-f-nogueira"
[tool.mypy]
disallow_untyped_defs = true
# Remove venv skip when integrated with pre-commit
exclude = ["_static", "build", "examples", "notebooks", "venv"]
ignore_missing_imports = true
python_version = "3.8"
[tool.poetry]
authors = ["Your Name <you@example.com>"]
description = "llama-index llms maritalk integration"
license = "MIT"
name = "llama-index-llms-maritalk"
packages = [{include = "llama_index/"}]
readme = "README.md"
version = "0.1.0"
[tool.poetry.dependencies]
python = ">=3.8.1,<3.12"
llama-index-core = "^0.10.0"
[tool.poetry.group.dev.dependencies]
black = {extras = ["jupyter"], version = "<=23.9.1,>=23.7.0"}
codespell = {extras = ["toml"], version = ">=v2.2.6"}
ipython = "8.10.0"
jupyter = "^1.0.0"
mypy = "0.991"
pre-commit = "3.2.0"
pylint = "2.15.10"
pytest = "7.2.1"
pytest-mock = "3.11.1"
ruff = "0.0.292"
tree-sitter-languages = "^1.8.0"
types-Deprecated = ">=0.1.0"
types-PyYAML = "^6.0.12.12"
types-protobuf = "^4.24.0.4"
types-redis = "4.5.5.0"
types-requests = "2.28.11.8" # TODO: unpin when mypy>0.991
types-setuptools = "67.1.0.0"