Skip to content
Snippets Groups Projects
Commit c40fb281 authored by João Galego's avatar João Galego
Browse files

Added support for multimodal inputs and model-specific inference params

parent a0f61923
No related branches found
No related tags found
No related merge requests found
......@@ -17,7 +17,7 @@ Classes:
"""
import json
from typing import List, Optional, Any
from typing import Dict, List, Optional, Any, Union
import os
from time import sleep
import tiktoken
......@@ -138,11 +138,12 @@ class BedrockEncoder(DenseEncoder):
) from err
return bedrock_client
def __call__(self, docs: List[str]) -> List[List[float]]:
def __call__(self, docs: List[Union[str, Dict]], model_kwargs: Optional[Dict] = None) -> List[List[float]]:
"""Generates embeddings for the given documents.
Args:
docs: A list of strings representing the documents to embed.
model_kwargs: A dictionary of model-specific inference parameters.
Returns:
A list of lists, where each inner list contains the embedding values for a
......@@ -168,11 +169,25 @@ class BedrockEncoder(DenseEncoder):
embeddings = []
if self.name and "amazon" in self.name:
for doc in docs:
embedding_body = json.dumps(
{
"inputText": doc,
}
)
embedding_body = {}
if isinstance(doc, dict):
embedding_body['inputText'] = doc.get('text')
embedding_body['inputImage'] = doc.get('image') # expects a base64-encoded image
else:
embedding_body['inputText'] = doc
# Add model-specific inference parameters
if model_kwargs:
embedding_body = embedding_body | model_kwargs
# Clean up null values
embedding_body = {k: v for k, v in embedding_body.items() if v}
# Format payload
embedding_body = json.dumps(embedding_body)
response = self.client.invoke_model(
body=embedding_body,
modelId=self.name,
......@@ -184,9 +199,19 @@ class BedrockEncoder(DenseEncoder):
elif self.name and "cohere" in self.name:
chunked_docs = self.chunk_strings(docs)
for chunk in chunked_docs:
chunk = json.dumps(
{"texts": chunk, "input_type": self.input_type}
)
chunk = {
'texts': chunk,
'input_type': self.input_type
}
# Add model-specific inference parameters
# Note: if specified, input_type will be overwritten by model_kwargs
if model_kwargs:
chunk = chunk | model_kwargs
# Format payload
chunk = json.dumps(chunk)
response = self.client.invoke_model(
body=chunk,
modelId=self.name,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment