From 6bdc95b7f3883247ba9b437a3fe8f557947bc0ff Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=AB=98=E7=92=9F=E7=90=A6?= <mrgao.ary@gmail.com>
Date: Fri, 15 Mar 2024 11:00:10 +0800
Subject: [PATCH] Add Solar chat model (#11710)

---
 docs/examples/llm/solar.ipynb                 |  85 ++++++
 docs/module_guides/models/llms/modules.md     |   9 +
 .../llms/llama-index-llms-solar/.gitignore    | 153 +++++++++++
 .../llms/llama-index-llms-solar/BUILD         |   3 +
 .../llms/llama-index-llms-solar/Makefile      |  17 ++
 .../llms/llama-index-llms-solar/README.md     |   1 +
 .../llama_index/llms/solar/BUILD              |   1 +
 .../llama_index/llms/solar/__init__.py        |   3 +
 .../llama_index/llms/solar/base.py            | 241 ++++++++++++++++++
 .../llama-index-llms-solar/pyproject.toml     |  64 +++++
 10 files changed, 577 insertions(+)
 create mode 100644 docs/examples/llm/solar.ipynb
 create mode 100644 llama-index-integrations/llms/llama-index-llms-solar/.gitignore
 create mode 100644 llama-index-integrations/llms/llama-index-llms-solar/BUILD
 create mode 100644 llama-index-integrations/llms/llama-index-llms-solar/Makefile
 create mode 100644 llama-index-integrations/llms/llama-index-llms-solar/README.md
 create mode 100644 llama-index-integrations/llms/llama-index-llms-solar/llama_index/llms/solar/BUILD
 create mode 100644 llama-index-integrations/llms/llama-index-llms-solar/llama_index/llms/solar/__init__.py
 create mode 100644 llama-index-integrations/llms/llama-index-llms-solar/llama_index/llms/solar/base.py
 create mode 100644 llama-index-integrations/llms/llama-index-llms-solar/pyproject.toml

diff --git a/docs/examples/llm/solar.ipynb b/docs/examples/llm/solar.ipynb
new file mode 100644
index 0000000000..b0ea595833
--- /dev/null
+++ b/docs/examples/llm/solar.ipynb
@@ -0,0 +1,85 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "cae1b4a8",
+   "metadata": {},
+   "source": [
+    "# Solar LLM"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "715d392e",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "!pip install llama-index-llms-solar"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "1fdc2dc3-1454-41e9-8862-9dfd75b5b61f",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "\n",
+    "os.environ[\"SOLAR_API_KEY\"] = \"SOLAR_API_KEY\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "26b168b8-9ebf-479d-ac53-28bc952da354",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "assistant: Mother also went into the room.\n"
+     ]
+    }
+   ],
+   "source": [
+    "# from llama_index.llms import\n",
+    "from llama_index.llms.solar import Solar\n",
+    "from llama_index.core.base.llms.types import ChatMessage, MessageRole\n",
+    "\n",
+    "llm = Solar(model=\"solar-1-mini-chat\", is_chat_model=True)\n",
+    "response = llm.chat(\n",
+    "    messages=[\n",
+    "        ChatMessage(role=\"user\", content=\"아버지가방에들어가셨다\"),\n",
+    "        ChatMessage(role=\"assistant\", content=\"Father went into his room\"),\n",
+    "        ChatMessage(role=\"user\", content=\"엄마도들어가셨다\"),\n",
+    "    ]\n",
+    ")\n",
+    "\n",
+    "print(response)"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/docs/module_guides/models/llms/modules.md b/docs/module_guides/models/llms/modules.md
index 2e169ba263..bcbb6645ed 100644
--- a/docs/module_guides/models/llms/modules.md
+++ b/docs/module_guides/models/llms/modules.md
@@ -353,6 +353,15 @@ maxdepth: 1
 /examples/llm/sagemaker_endpoint_llm.ipynb
 ```
 
+## Solar
+
+```{toctree}
+---
+maxdepth: 1
+---
+/examples/llm/solar.ipynb
+```
+
 ## Together.ai
 
 ```{toctree}
diff --git a/llama-index-integrations/llms/llama-index-llms-solar/.gitignore b/llama-index-integrations/llms/llama-index-llms-solar/.gitignore
new file mode 100644
index 0000000000..990c18de22
--- /dev/null
+++ b/llama-index-integrations/llms/llama-index-llms-solar/.gitignore
@@ -0,0 +1,153 @@
+llama_index/_static
+.DS_Store
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+bin/
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+etc/
+include/
+lib/
+lib64/
+parts/
+sdist/
+share/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+#  Usually these files are written by a python script from a template
+#  before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+.ruff_cache
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+notebooks/
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# pipenv
+#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+#   However, in case of collaboration, if having platform-specific dependencies or dependencies
+#   having no cross-platform support, pipenv may install dependencies that don't work, or not
+#   install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+pyvenv.cfg
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# Jetbrains
+.idea
+modules/
+*.swp
+
+# VsCode
+.vscode
+
+# pipenv
+Pipfile
+Pipfile.lock
+
+# pyright
+pyrightconfig.json
diff --git a/llama-index-integrations/llms/llama-index-llms-solar/BUILD b/llama-index-integrations/llms/llama-index-llms-solar/BUILD
new file mode 100644
index 0000000000..0896ca890d
--- /dev/null
+++ b/llama-index-integrations/llms/llama-index-llms-solar/BUILD
@@ -0,0 +1,3 @@
+poetry_requirements(
+    name="poetry",
+)
diff --git a/llama-index-integrations/llms/llama-index-llms-solar/Makefile b/llama-index-integrations/llms/llama-index-llms-solar/Makefile
new file mode 100644
index 0000000000..b9eab05aa3
--- /dev/null
+++ b/llama-index-integrations/llms/llama-index-llms-solar/Makefile
@@ -0,0 +1,17 @@
+GIT_ROOT ?= $(shell git rev-parse --show-toplevel)
+
+help:	## Show all Makefile targets.
+	@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[33m%-30s\033[0m %s\n", $$1, $$2}'
+
+format:	## Run code autoformatters (black).
+	pre-commit install
+	git ls-files | xargs pre-commit run black --files
+
+lint:	## Run linters: pre-commit (black, ruff, codespell) and mypy
+	pre-commit install && git ls-files | xargs pre-commit run --show-diff-on-failure --files
+
+test:	## Run tests via pytest.
+	pytest tests
+
+watch-docs:	## Build and watch documentation.
+	sphinx-autobuild docs/ docs/_build/html --open-browser --watch $(GIT_ROOT)/llama_index/
diff --git a/llama-index-integrations/llms/llama-index-llms-solar/README.md b/llama-index-integrations/llms/llama-index-llms-solar/README.md
new file mode 100644
index 0000000000..74c256c1f6
--- /dev/null
+++ b/llama-index-integrations/llms/llama-index-llms-solar/README.md
@@ -0,0 +1 @@
+# LlamaIndex Llms Integration: Solar
diff --git a/llama-index-integrations/llms/llama-index-llms-solar/llama_index/llms/solar/BUILD b/llama-index-integrations/llms/llama-index-llms-solar/llama_index/llms/solar/BUILD
new file mode 100644
index 0000000000..db46e8d6c9
--- /dev/null
+++ b/llama-index-integrations/llms/llama-index-llms-solar/llama_index/llms/solar/BUILD
@@ -0,0 +1 @@
+python_sources()
diff --git a/llama-index-integrations/llms/llama-index-llms-solar/llama_index/llms/solar/__init__.py b/llama-index-integrations/llms/llama-index-llms-solar/llama_index/llms/solar/__init__.py
new file mode 100644
index 0000000000..b5a36ee495
--- /dev/null
+++ b/llama-index-integrations/llms/llama-index-llms-solar/llama_index/llms/solar/__init__.py
@@ -0,0 +1,3 @@
+from llama_index.llms.solar.base import Solar
+
+__all__ = ["Solar"]
diff --git a/llama-index-integrations/llms/llama-index-llms-solar/llama_index/llms/solar/base.py b/llama-index-integrations/llms/llama-index-llms-solar/llama_index/llms/solar/base.py
new file mode 100644
index 0000000000..c954d051ad
--- /dev/null
+++ b/llama-index-integrations/llms/llama-index-llms-solar/llama_index/llms/solar/base.py
@@ -0,0 +1,241 @@
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Optional,
+    Sequence,
+    Union,
+    Tuple,
+)
+from llama_index.core.base.llms.generic_utils import get_from_param_or_env
+from llama_index.core.base.llms.types import (
+    ChatMessage,
+    ChatResponse,
+    ChatResponseAsyncGen,
+    ChatResponseGen,
+    CompletionResponse,
+    CompletionResponseAsyncGen,
+    CompletionResponseGen,
+    LLMMetadata,
+)
+import httpx
+from llama_index.core.bridge.pydantic import Field
+from llama_index.core.constants import DEFAULT_CONTEXT_WINDOW
+from llama_index.core.base.llms.generic_utils import (
+    async_stream_completion_response_to_chat_response,
+    completion_response_to_chat_response,
+    stream_completion_response_to_chat_response,
+)
+from llama_index.core.types import BaseOutputParser, PydanticProgramMode
+from llama_index.core.bridge.pydantic import Field
+from llama_index.core.callbacks import CallbackManager
+from llama_index.llms.openai.base import OpenAI, Tokenizer
+from transformers import AutoTokenizer
+
+DEFAULT_SOLAR_API_BASE = "https://api.upstage.ai/v1/solar"
+DEFAULT_SOLAR_MODEL = "solar-1-mini-chat"
+
+
+class Solar(OpenAI):
+    api_key: str = Field(default=None, description="The SOLAR API key.")
+    api_base: str = Field(default="", description="The base URL for SOLAR API.")
+
+    model: str = Field(
+        default="solar-1-mini-chat", description="The SOLAR model to use."
+    )
+
+    context_window: int = Field(
+        default=DEFAULT_CONTEXT_WINDOW,
+        description=LLMMetadata.__fields__["context_window"].field_info.description,
+    )
+    is_chat_model: bool = Field(
+        default=False,
+        description=LLMMetadata.__fields__["is_chat_model"].field_info.description,
+    )
+    is_function_calling_model: bool = Field(
+        default=False,
+        description=LLMMetadata.__fields__[
+            "is_function_calling_model"
+        ].field_info.description,
+    )
+    tokenizer: Union[Tokenizer, str, None] = Field(
+        default=None,
+        description=(
+            "An instance of a tokenizer object that has an encode method, or the name"
+            " of a tokenizer model from Hugging Face. If left as None, then this"
+            " disables inference of max_tokens."
+        ),
+    )
+
+    def __init__(
+        self,
+        model: str = DEFAULT_SOLAR_MODEL,
+        temperature: float = 0.1,
+        max_tokens: Optional[int] = None,
+        additional_kwargs: Optional[Dict[str, Any]] = None,
+        max_retries: int = 3,
+        timeout: float = 60.0,
+        reuse_client: bool = True,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+        api_version: Optional[str] = None,
+        callback_manager: Optional[CallbackManager] = None,
+        default_headers: Optional[Dict[str, str]] = None,
+        http_client: Optional[httpx.Client] = None,
+        # base class
+        system_prompt: Optional[str] = None,
+        messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None,
+        completion_to_prompt: Optional[Callable[[str], str]] = None,
+        pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
+        output_parser: Optional[BaseOutputParser] = None,
+        **kwargs: Any,
+    ) -> None:
+        api_key, api_base = resolve_solar_credentials(
+            api_key=api_key,
+            api_base=api_base,
+        )
+
+        super().__init__(
+            model=model,
+            temperature=temperature,
+            max_tokens=max_tokens,
+            additional_kwargs=additional_kwargs,
+            max_retries=max_retries,
+            callback_manager=callback_manager,
+            api_key=api_key,
+            api_version=api_version,
+            api_base=api_base,
+            timeout=timeout,
+            reuse_client=reuse_client,
+            default_headers=default_headers,
+            system_prompt=system_prompt,
+            messages_to_prompt=messages_to_prompt,
+            completion_to_prompt=completion_to_prompt,
+            pydantic_program_mode=pydantic_program_mode,
+            output_parser=output_parser,
+            **kwargs,
+        )
+
+    @property
+    def metadata(self) -> LLMMetadata:
+        return LLMMetadata(
+            context_window=self.context_window,
+            num_output=self.max_tokens or -1,
+            is_chat_model=self.is_chat_model,
+            is_function_calling_model=self.is_function_calling_model,
+            model_name=self.model,
+        )
+
+    @property
+    def _tokenizer(self) -> Optional[Tokenizer]:
+        if isinstance(self.tokenizer, str):
+            return AutoTokenizer.from_pretrained(self.tokenizer)
+        return self.tokenizer
+
+    @classmethod
+    def class_name(cls) -> str:
+        return "Solar"
+
+    def complete(
+        self, prompt: str, formatted: bool = False, **kwargs: Any
+    ) -> CompletionResponse:
+        """Complete the prompt."""
+        if not formatted:
+            prompt = self.completion_to_prompt(prompt)
+
+        return super().complete(prompt, **kwargs)
+
+    def stream_complete(
+        self, prompt: str, formatted: bool = False, **kwargs: Any
+    ) -> CompletionResponseGen:
+        """Stream complete the prompt."""
+        if not formatted:
+            prompt = self.completion_to_prompt(prompt)
+
+        return super().stream_complete(prompt, **kwargs)
+
+    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
+        """Chat with the model."""
+        if not self.metadata.is_chat_model:
+            prompt = self.messages_to_prompt(messages)
+            completion_response = self.complete(prompt, formatted=True, **kwargs)
+            return completion_response_to_chat_response(completion_response)
+
+        return super().chat(messages, **kwargs)
+
+    def stream_chat(
+        self, messages: Sequence[ChatMessage], **kwargs: Any
+    ) -> ChatResponseGen:
+        if not self.metadata.is_chat_model:
+            prompt = self.messages_to_prompt(messages)
+            completion_response = self.stream_complete(prompt, formatted=True, **kwargs)
+            return stream_completion_response_to_chat_response(completion_response)
+
+        return super().stream_chat(messages, **kwargs)
+
+    # -- Async methods --
+
+    async def acomplete(
+        self, prompt: str, formatted: bool = False, **kwargs: Any
+    ) -> CompletionResponse:
+        """Complete the prompt."""
+        if not formatted:
+            prompt = self.completion_to_prompt(prompt)
+
+        return await super().acomplete(prompt, **kwargs)
+
+    async def astream_complete(
+        self, prompt: str, formatted: bool = False, **kwargs: Any
+    ) -> CompletionResponseAsyncGen:
+        """Stream complete the prompt."""
+        if not formatted:
+            prompt = self.completion_to_prompt(prompt)
+
+        return await super().astream_complete(prompt, **kwargs)
+
+    async def achat(
+        self, messages: Sequence[ChatMessage], **kwargs: Any
+    ) -> ChatResponse:
+        """Chat with the model."""
+        if not self.metadata.is_chat_model:
+            prompt = self.messages_to_prompt(messages)
+            completion_response = await self.acomplete(prompt, formatted=True, **kwargs)
+            return completion_response_to_chat_response(completion_response)
+
+        return await super().achat(messages, **kwargs)
+
+    async def astream_chat(
+        self, messages: Sequence[ChatMessage], **kwargs: Any
+    ) -> ChatResponseAsyncGen:
+        if not self.metadata.is_chat_model:
+            prompt = self.messages_to_prompt(messages)
+            completion_response = await self.astream_complete(
+                prompt, formatted=True, **kwargs
+            )
+            return async_stream_completion_response_to_chat_response(
+                completion_response
+            )
+
+        return await super().astream_chat(messages, **kwargs)
+
+
+def resolve_solar_credentials(
+    api_key: Optional[str] = None,
+    api_base: Optional[str] = None,
+) -> Tuple[Optional[str], str]:
+    """Resolve SOLAR credentials.
+
+    The order of precedence is:
+    1. param
+    2. env
+    3. solar module
+    4. default
+    """
+    # resolve from param or env
+    api_key = get_from_param_or_env("api_key", api_key, "SOLAR_API_KEY", "")
+    api_base = get_from_param_or_env("api_base", api_base, "SOLAR_API_BASE", "")
+
+    final_api_key = api_key or ""
+    final_api_base = api_base or DEFAULT_SOLAR_API_BASE
+
+    return final_api_key, str(final_api_base)
diff --git a/llama-index-integrations/llms/llama-index-llms-solar/pyproject.toml b/llama-index-integrations/llms/llama-index-llms-solar/pyproject.toml
new file mode 100644
index 0000000000..619bc2e8be
--- /dev/null
+++ b/llama-index-integrations/llms/llama-index-llms-solar/pyproject.toml
@@ -0,0 +1,64 @@
+[build-system]
+build-backend = "poetry.core.masonry.api"
+requires = ["poetry-core"]
+
+[tool.codespell]
+check-filenames = true
+check-hidden = true
+skip = "*.csv,*.html,*.json,*.jsonl,*.pdf,*.txt,*.ipynb"
+
+[tool.llamahub]
+contains_example = false
+import_path = "llama_index.llms.solar"
+
+[tool.llamahub.class_authors]
+Solar = "llama-index"
+
+[tool.mypy]
+disallow_untyped_defs = true
+exclude = ["_static", "build", "examples", "notebooks", "venv"]
+ignore_missing_imports = true
+python_version = "3.8"
+
+[tool.poetry]
+authors = ["Your Name <you@example.com>"]
+description = "llama-index llms solar integration"
+exclude = ["**/BUILD"]
+license = "MIT"
+name = "llama-index-llms-solar"
+readme = "README.md"
+version = "0.1.3"
+
+[tool.poetry.dependencies]
+python = ">=3.8.1,<4.0"
+llama-index-core = "^0.10.1"
+llama-index-llms-openai = "^0.1.1"
+transformers = "^4.37.0"
+
+[tool.poetry.group.dev.dependencies]
+ipython = "8.10.0"
+jupyter = "^1.0.0"
+mypy = "0.991"
+pre-commit = "3.2.0"
+pylint = "2.15.10"
+pytest = "7.2.1"
+pytest-mock = "3.11.1"
+ruff = "0.0.292"
+tree-sitter-languages = "^1.8.0"
+types-Deprecated = ">=0.1.0"
+types-PyYAML = "^6.0.12.12"
+types-protobuf = "^4.24.0.4"
+types-redis = "4.5.5.0"
+types-requests = "2.28.11.8"
+types-setuptools = "67.1.0.0"
+
+[tool.poetry.group.dev.dependencies.black]
+extras = ["jupyter"]
+version = "<=23.9.1,>=23.7.0"
+
+[tool.poetry.group.dev.dependencies.codespell]
+extras = ["toml"]
+version = ">=v2.2.6"
+
+[[tool.poetry.packages]]
+include = "llama_index/"
-- 
GitLab
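
For anyone reviewing or trying the branch locally, here is a minimal usage sketch (not part of the patch itself). It assumes `llama-index-llms-solar` is installed from this change and that a real `SOLAR_API_KEY` is exported; the model name, environment variables, and default API base all come from `base.py`, and the call pattern mirrors the notebook added above.

```python
# Minimal usage sketch for the Solar integration added by this patch.
# Assumes `pip install llama-index-llms-solar` and a valid SOLAR_API_KEY in the
# environment; resolve_solar_credentials() falls back to
# https://api.upstage.ai/v1/solar when SOLAR_API_BASE is unset.
from llama_index.core.base.llms.types import ChatMessage
from llama_index.llms.solar import Solar

# is_chat_model=True sends the messages to the chat endpoint directly; with the
# default False, Solar.chat() flattens the messages via messages_to_prompt()
# and goes through the completion path instead.
llm = Solar(model="solar-1-mini-chat", is_chat_model=True)

response = llm.chat(
    messages=[ChatMessage(role="user", content="What is the capital of Korea?")]
)
print(response)
```

Because the class subclasses the OpenAI integration, the same object also exposes the `complete()`, streaming, and async variants overridden in `base.py`.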