Unverified Commit 6ba3f55c authored by Sean Smith, committed by GitHub
Contextual Generate model (#17913)

parent 5e37074b
Showing 413 additions and 0 deletions
poetry_requirements(
name="poetry",
)
# Contextual LLM Integration for LlamaIndex
This package provides a Contextual LLM integration for LlamaIndex.
## Installation
```bash
pip install llama-index-llms-contextual
```
## Usage
```python
from llama_index.llms.contextual import Contextual
llm = Contextual(model="v1", api_key="your_api_key")
response = llm.complete("Explain the importance of Grounded Language Models.")
```
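The `complete` method also accepts grounding-specific parameters exposed by this integration, most notably `knowledge` (a list of text snippets the model should ground its answer in). A minimal sketch, assuming a valid Contextual API key:

```python
from llama_index.llms.contextual import Contextual

llm = Contextual(model="v1", api_key="your_api_key")

# Ground the response in caller-supplied knowledge snippets
response = llm.complete(
    "What color is the sky?",
    knowledge=["The sky is blue"],
)
print(response)
```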
python_sources()
from llama_index.llms.contextual.base import Contextual
__all__ = ["Contextual"]
from typing import Any, List, Optional, Sequence

from llama_index.llms.openai_like import OpenAILike
from pydantic import Field
from llama_index.core.llms.callbacks import (
    llm_chat_callback,
    llm_completion_callback,
)
from llama_index.core.base.llms.types import (
    ChatMessage,
    ChatResponse,
    ChatResponseAsyncGen,
    ChatResponseGen,
    CompletionResponse,
    CompletionResponseAsyncGen,
    CompletionResponseGen,
    MessageRole,
)
from contextual import ContextualAI
class Contextual(OpenAILike):
"""
Generate a response using Contextual's Grounded Language Model (GLM), an LLM engineered specifically to prioritize faithfulness to in-context retrievals over parametric knowledge to reduce hallucinations in Retrieval-Augmented Generation.
The total request cannot exceed 32,000 tokens. Email glm-feedback@contextual.ai with any feedback or questions.
Examples:
`pip install llama-index-llms-contextual`
```python
from llama_index.llms.contextual import Contextual
# Set up the Contextual class with the required model and API key
llm = Contextual(model="contextual-clm", api_key="your_api_key")
# Call the complete method with a query
response = llm.complete("Explain the importance of low latency LLMs")
print(response)
```
"""
model: str = Field(
description="The model to use. Currently only supports `v1`.", default="v1"
)
api_key: str = Field(description="The API key to use.", default=None)
base_url: str = Field(
description="The base URL to use.",
default="https://api.contextual.ai/v1/generate",
)
avoid_commentary: bool = Field(
description="Flag to indicate whether the model should avoid providing additional commentary in responses. Commentary is conversational in nature and does not contain verifiable claims; therefore, commentary is not strictly grounded in available context. However, commentary may provide useful context which improves the helpfulness of responses.",
default=False,
)
client: Any = Field(default=None, description="Contextual AI Client")
def __init__(
self,
model: str,
api_key: str,
        base_url: Optional[str] = None,
avoid_commentary: bool = False,
**openai_llm_kwargs: Any,
) -> None:
super().__init__(
model=model,
api_key=api_key,
api_base=base_url,
is_chat_model=openai_llm_kwargs.pop("is_chat_model", True),
**openai_llm_kwargs,
)
        try:
            self.client = ContextualAI(api_key=api_key, base_url=base_url)
        except Exception as e:
            raise ValueError(f"Error initializing ContextualAI client: {e}") from e
@classmethod
def class_name(cls) -> str:
"""Get class name."""
return "contextual-clm"
# Synchronous Methods
@llm_completion_callback()
def complete(
self, prompt: str, knowledge: Optional[List[str]] = None, **kwargs
) -> CompletionResponse:
"""
Generate completion for the given prompt.
Args:
prompt (str): The input prompt to generate completion for.
**kwargs: Additional keyword arguments for the API request.
Returns:
str: The generated text completion.
"""
messages_list = [{"role": MessageRole.USER, "content": prompt}]
response = self._generate(
knowledge=knowledge,
messages=messages_list,
model=self.model,
system_prompt=self.system_prompt,
**kwargs,
)
return CompletionResponse(text=response)
@llm_chat_callback()
def chat(self, messages: List[ChatMessage], **kwargs) -> ChatResponse:
"""
Generate a chat response for the given messages.
"""
messages_list = [
{"role": msg.role, "content": msg.blocks[0].text} for msg in messages
]
        response = self._generate(
            knowledge=kwargs.pop("knowledge_base", None),
            messages=messages_list,
            system_prompt=kwargs.pop("system_prompt", self.system_prompt),
            **kwargs,
        )
return ChatResponse(
message=ChatMessage(role=MessageRole.ASSISTANT, content=response)
)
@llm_chat_callback()
def stream_chat(self, messages: List[ChatMessage], **kwargs) -> ChatResponseGen:
"""
Generate a chat response for the given messages.
"""
raise NotImplementedError("stream methods not implemented in Contextual")
@llm_completion_callback()
    def stream_complete(self, prompt: str, **kwargs) -> CompletionResponseGen:
        """
        Streaming completion is not supported by this integration.
        """
        raise NotImplementedError("stream methods not implemented in Contextual")
# ===== Async Endpoints =====
@llm_chat_callback()
async def achat(
self,
messages: Sequence[ChatMessage],
**kwargs: Any,
) -> ChatResponse:
raise NotImplementedError("async methods not implemented in Contextual")
@llm_chat_callback()
async def astream_chat(
self,
messages: Sequence[ChatMessage],
**kwargs: Any,
) -> ChatResponseAsyncGen:
raise NotImplementedError("async methods not implemented in Contextual")
@llm_completion_callback()
async def acomplete(
self, prompt: str, formatted: bool = False, **kwargs: Any
) -> CompletionResponse:
raise NotImplementedError("async methods not implemented in Contextual")
@llm_completion_callback()
async def astream_complete(
self, prompt: str, formatted: bool = False, **kwargs: Any
) -> CompletionResponseAsyncGen:
raise NotImplementedError("async methods not implemented in Contextual")
    def _generate(
        self, knowledge, messages, system_prompt, **kwargs
    ) -> str:
        """
        Call Contextual's /generate endpoint and return the generated text.
        """
        raw_message = self.client.generate.create(
            messages=messages,
            knowledge=knowledge or [],
            model=self.model,
            system_prompt=system_prompt,
            avoid_commentary=kwargs.get("avoid_commentary", self.avoid_commentary),
            temperature=kwargs.get("temperature", 0.0),
            max_new_tokens=kwargs.get("max_tokens", 1024),
            top_p=kwargs.get("top_p", 1),
        )
        return raw_message.response
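For reference, `_generate` delegates to the `contextual-client` SDK. The following standalone sketch mirrors the call the wrapper makes, using only the parameter names that appear above; treat it as illustrative and check the client's documentation for the authoritative API surface.

```python
from contextual import ContextualAI

# Sketch of the underlying SDK call issued by Contextual._generate
client = ContextualAI(api_key="your_api_key")
result = client.generate.create(
    model="v1",
    messages=[{"role": "user", "content": "What color is the sky?"}],
    knowledge=["The sky is blue"],
    system_prompt="You are a helpful assistant.",
    avoid_commentary=False,
    temperature=0.0,
    max_new_tokens=1024,
    top_p=1,
)
print(result.response)
```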
%% Cell type:markdown id: tags:
# Contextual GLM
%% Cell type:code id: tags:
``` python
!pip install llama-index-llms-contextual
```
%% Cell type:code id: tags:
``` python
from llama_index.llms.contextual import Contextual
from dotenv import load_dotenv
import os
# Set up the Contextual class with the required model and API key
# Store the API key in a .env file as CONTEXTUAL_API_KEY
load_dotenv()
llm = Contextual(model="v1", api_key=os.getenv("CONTEXTUAL_API_KEY"))
# Call the complete method with a query
llm.complete(
"Explain the importance of Grounded Language Models.",
temperature=0.5,
max_tokens=1024,
top_p=0.9,
avoid_commentary=False,
system_prompt="You are a helpful assistant that can answer questions and help with tasks.",
knowledge=["The sky is blue"],
)
```
%% Output
CompletionResponse(text="I apologize, but I am unable to provide information about Grounded Language Models. I am an AI assistant created by Contextual AI. I don't have relevant documentation about that topic, but feel free to ask me something else!", additional_kwargs={}, raw=None, logprobs=None, delta=None)
%% Cell type:code id: tags:
``` python
llm.complete(
"what color is the sky?",
knowledge=["The sky is blue"],
avoid_commentary=False,
temperature=0.9,
max_tokens=1,
)
```
%% Output
CompletionResponse(text='The sky is blue.', additional_kwargs={}, raw=None, logprobs=None, delta=None)
%% Cell type:code id: tags:
``` python
from llama_index.core.chat_engine.types import ChatMessage
llm.chat([ChatMessage(role="user", content="what color is the sky?")])
```
%% Output
---------------------------------------------------------------------------
AuthenticationError                       Traceback (most recent call last)
Cell In[39], line 3
      1 from llama_index.core.chat_engine.types import ChatMessage
----> 3 llm.chat([ChatMessage(role="user", content="what color is the sky?")])

[... intermediate frames elided: OpenAILike.chat -> OpenAI.chat ->
    openai chat.completions.create -> SyncAPIClient._request ...]

AuthenticationError: Error code: 401 - {'error': {'message': 'Incorrect API key provided: key-KCj3*****************************************izwo. You can find your API key at https://platform.openai.com/account/api-keys.', 'type': 'invalid_request_error', 'param': None, 'code': 'invalid_api_key'}}
[build-system]
build-backend = "poetry.core.masonry.api"
requires = ["poetry-core"]
[tool.codespell]
check-filenames = true
check-hidden = true
skip = "*.csv,*.html,*.json,*.jsonl,*.pdf,*.txt,*.ipynb"
[tool.llamahub]
contains_example = false
import_path = "llama_index.llms.contextual"
[tool.llamahub.class_authors]
Contextual = "sean-smith"
[tool.mypy]
disallow_untyped_defs = true
exclude = ["_static", "build", "examples", "notebooks", "venv"]
ignore_missing_imports = true
python_version = "3.8"
[tool.poetry]
authors = ["Sean Smith <sean.smith@contextual.ai>"]
description = "llama-index contextual integration"
exclude = ["**/BUILD"]
license = "MIT"
name = "llama-index-llms-contextual"
readme = "README.md"
version = "0.0.1"
[tool.poetry.dependencies]
python = ">=3.9,<4.0"
llama-index-llms-openai-like = "^0.3.3"
contextual-client = "^0.4.0"
[tool.poetry.group.dev.dependencies.black]
extras = ["jupyter"]
version = "<=23.9.1,>=23.7.0"
[tool.poetry.group.dev.dependencies.codespell]
extras = ["toml"]
version = ">=v2.2.6"
[[tool.poetry.packages]]
include = "llama_index/"
python_sources()
from llama_index.core.base.llms.base import BaseLLM
from llama_index.llms.contextual import Contextual
def test_llm_class():
names_of_base_classes = [b.__name__ for b in Contextual.__mro__]
assert BaseLLM.__name__ in names_of_base_classes
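A further lightweight check that could be added alongside the test above (a sketch; it only exercises the `class_name` classmethod from the source, so no API key or network access is needed):

```python
def test_class_name():
    # class_name() is a classmethod on Contextual and requires no client setup
    assert Contextual.class_name() == "contextual-clm"
```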