Unverified commit b6069d6a, authored by Haotian Zhang and committed by GitHub

Fix Stream Chat for Gemini MM LLM (#9616)

parent f283456b
@@ -7,6 +7,7 @@ from llama_index.bridge.pydantic import Field, PrivateAttr
 from llama_index.callbacks import CallbackManager
 from llama_index.constants import DEFAULT_NUM_OUTPUTS, DEFAULT_TEMPERATURE
 from llama_index.llms.gemini_utils import (
+    ROLES_FROM_GEMINI,
     chat_from_gemini_response,
     chat_message_to_gemini,
     completion_from_gemini_response,
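
The newly imported ROLES_FROM_GEMINI maps the role strings Gemini attaches to each streamed candidate back onto llama_index message roles. The mapping itself lives in gemini_utils and is not shown in this diff; as an assumption about its shape, it is roughly:

from llama_index.llms.types import MessageRole

# Assumed shape of ROLES_FROM_GEMINI in llama_index.llms.gemini_utils
# (not part of this diff): Gemini reports roles as "user" and "model".
ROLES_FROM_GEMINI = {
    "user": MessageRole.USER,
    "model": MessageRole.ASSISTANT,
}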
@@ -178,8 +179,28 @@ class GeminiMultiModal(MultiModalLLM):
     ) -> ChatResponseGen:
         *history, next_msg = map(chat_message_to_gemini, messages)
         chat = self._model.start_chat(history=history)
-        it = chat.send_message(next_msg, stream=True)
-        yield from map(chat_from_gemini_response, it)
+        response = chat.send_message(next_msg, stream=True)
+
+        def gen() -> ChatResponseGen:
+            content = ""
+            for r in response:
+                top_candidate = r.candidates[0]
+                content_delta = top_candidate.content.parts[0].text
+                role = ROLES_FROM_GEMINI[top_candidate.content.role]
+                raw = {
+                    **(type(top_candidate).to_dict(top_candidate)),
+                    **(
+                        type(response.prompt_feedback).to_dict(response.prompt_feedback)
+                    ),
+                }
+                content += content_delta
+                yield ChatResponse(
+                    message=ChatMessage(role=role, content=content),
+                    delta=content_delta,
+                    raw=raw,
+                )
+
+        return gen()
 
     async def acomplete(
         self, prompt: str, image_documents: Sequence[ImageDocument], **kwargs: Any
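
With this change, each streamed ChatResponse carries the incremental text in delta and the accumulated message in message.content, plus the raw candidate and prompt-feedback dicts. A minimal consumption sketch, assuming the llama_index v0.9 import paths and a GOOGLE_API_KEY in the environment (the prompt text is illustrative, not part of this commit):

from llama_index.llms.types import ChatMessage, MessageRole
from llama_index.multi_modal_llms.gemini import GeminiMultiModal

llm = GeminiMultiModal()  # assumes GOOGLE_API_KEY is set in the environment
messages = [ChatMessage(role=MessageRole.USER, content="Describe a sunset.")]

for chunk in llm.stream_chat(messages):
    # chunk.delta holds only the new text; chunk.message.content is the running total.
    print(chunk.delta, end="", flush=True)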
@@ -215,10 +236,25 @@ class GeminiMultiModal(MultiModalLLM):
     ) -> ChatResponseAsyncGen:
         *history, next_msg = map(chat_message_to_gemini, messages)
         chat = self._model.start_chat(history=history)
-        ait = await chat.send_message_async(next_msg, stream=True)
+        response = await chat.send_message_async(next_msg, stream=True)
 
         async def gen() -> ChatResponseAsyncGen:
-            async for comp in ait:
-                yield chat_from_gemini_response(comp)
+            content = ""
+            async for r in response:
+                top_candidate = r.candidates[0]
+                content_delta = top_candidate.content.parts[0].text
+                role = ROLES_FROM_GEMINI[top_candidate.content.role]
+                raw = {
+                    **(type(top_candidate).to_dict(top_candidate)),
+                    **(
+                        type(response.prompt_feedback).to_dict(response.prompt_feedback)
+                    ),
+                }
+                content += content_delta
+                yield ChatResponse(
+                    message=ChatMessage(role=role, content=content),
+                    delta=content_delta,
+                    raw=raw,
+                )
 
         return gen()
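
The async variant mirrors the sync generator, with two wrinkles: astream_chat is itself a coroutine that returns the async generator, so it must be awaited before iteration, and the awaited streaming response is consumed with async for. A sketch under the same assumptions as the sync example above:

import asyncio

from llama_index.llms.types import ChatMessage, MessageRole
from llama_index.multi_modal_llms.gemini import GeminiMultiModal

async def main() -> None:
    llm = GeminiMultiModal()  # assumes GOOGLE_API_KEY is set
    messages = [ChatMessage(role=MessageRole.USER, content="Describe a sunset.")]
    # Awaiting astream_chat yields the async generator; iterate it with async for.
    stream = await llm.astream_chat(messages)
    async for chunk in stream:
        print(chunk.delta, end="", flush=True)

asyncio.run(main())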