Formatting Messages for Gemini Live

This commit is contained in:
Ankur Duggal 2025-03-26 10:27:02 -07:00
parent 9f3426a42d
commit 6d77f38bb1
2 changed files with 277 additions and 51 deletions

View file

@ -1,7 +1,7 @@
import base64
import json
import asyncio
from typing import Any, Optional, Dict, Union
from typing import Any, Optional, Dict, Union, List
import websockets
@ -31,6 +31,10 @@ class GeminiRealTimeStreaming(RealTimeStreaming):
self.system_instruction = system_instruction
self.tools = tools
# Track connection state manually
self.client_ws_open = True
self.backend_ws_open = True
if self.vertex_project and self.vertex_location:
self.model_resource_name = f"projects/{self.vertex_project}/locations/{self.vertex_location}/publishers/google/models/{self.model_id}"
else:
@ -47,6 +51,10 @@ class GeminiRealTimeStreaming(RealTimeStreaming):
if self.system_instruction:
setup_payload["system_instruction"] = self.system_instruction
# Add tools to the setup payload if they exist
if self.tools and len(self.tools) > 0:
setup_payload["tools"] = self.tools
setup_message = {"setup": setup_payload}
print(f"Gemini Setup Message: {json.dumps(setup_message)}")
@ -56,71 +64,289 @@ class GeminiRealTimeStreaming(RealTimeStreaming):
async def wait_for_setup_complete(self):
"""Waits for the setupComplete message from the Gemini backend."""
try:
await self.backend_ws.recv()
response = await self.backend_ws.recv()
print(f"Setup response: {response}")
# Parse response to check if it's a valid setup completion
if isinstance(response, str):
response_data = json.loads(response)
if "setupComplete" not in response_data:
print(f"WARNING: Unexpected setup response format: {response}")
return True
except websockets.exceptions.ConnectionClosed as e:
print(f"Connection closed while waiting for Gemini setup complete: {e}")
if not self.websocket.closed:
await self.websocket.close(code=e.code, reason=f"Backend connection closed during setup: {e.reason}")
await self.safely_close_websocket(self.websocket, code=e.code, reason=f"Backend connection closed during setup: {e.reason}")
self.backend_ws_open = False
return False
except json.JSONDecodeError as e:
print(f"Failed to decode JSON during setup: {e}")
await self.websocket.close(code=1011, reason="Invalid JSON received during setup")
await self.backend_ws.close(code=1011, reason="Invalid JSON received during setup")
await self.safely_close_websocket(self.websocket, code=1011, reason="Invalid JSON received during setup")
await self.safely_close_websocket(self.backend_ws, code=1011, reason="Invalid JSON received during setup")
self.client_ws_open = False
self.backend_ws_open = False
return False
except Exception as e:
print(f"An unexpected error occurred during Gemini setup: {e}")
if not self.websocket.closed:
await self.websocket.close(code=1011, reason=f"Unexpected setup error: {e}")
if not self.backend_ws.closed:
await self.backend_ws.close(code=1011, reason=f"Unexpected setup error: {e}")
await self.safely_close_websocket(self.websocket, code=1011, reason=f"Unexpected setup error: {e}")
await self.safely_close_websocket(self.backend_ws, code=1011, reason=f"Unexpected setup error: {e}")
self.client_ws_open = False
self.backend_ws_open = False
return False
async def safely_close_websocket(self, ws, code=1000, reason="Normal closure"):
"""Safely close a websocket without relying on .closed attribute"""
try:
await ws.close(code=code, reason=reason)
except Exception as e:
print(f"Error closing websocket: {e}")
finally:
if ws == self.websocket:
self.client_ws_open = False
elif ws == self.backend_ws:
self.backend_ws_open = False
async def is_websocket_open(self, ws):
"""Check if websocket is open by trying a simple operation"""
try:
# For some websocket implementations, we can use an attribute
if hasattr(ws, 'closed'):
return not ws.closed
# For others, a state check might work
if hasattr(ws, 'state') and hasattr(ws.state, 'name'):
return ws.state.name == 'OPEN'
# Default: assume it's open if it's in our tracking variables
return (ws == self.websocket and self.client_ws_open) or \
(ws == self.backend_ws and self.backend_ws_open)
except Exception:
# If we can't determine, assume it's closed for safety
return False
async def backend_to_client_send_messages(self):
"""Forwards messages received from Gemini backend to the client."""
"""Receives messages from Gemini, transforms them to LiteLLM format, and forwards to the client."""
try:
while True:
while self.backend_ws_open and self.client_ws_open:
message = await self.backend_ws.recv()
if isinstance(message, bytes):
await self.websocket.send_bytes(message)
else:
await self.websocket.send_text(message)
# Log the raw message received from Gemini for debugging
print(f"Received raw from Gemini backend: {message}")
# Store the original raw message for logging purposes
self.store_message(message)
transformed_message_str = None
try:
if isinstance(message, str):
gemini_data = json.loads(message)
# --- Transformation Logic ---
# Assume Gemini response structure (adjust if different)
# Example: {"candidates": [{"content": {"role": "model", "parts": [{"text": "..."}]}}]}
extracted_text = ""
lite_llm_role = "assistant" # Default LiteLLM role for model output
candidates = gemini_data.get("candidates")
if isinstance(candidates, list) and len(candidates) > 0:
# Process the first candidate
candidate = candidates[0]
content = candidate.get("content")
if isinstance(content, dict):
# Map Gemini's 'model' role to LiteLLM's 'assistant' role
if content.get("role") == "model":
lite_llm_role = "assistant"
else:
# Handle other potential roles if needed, or keep default
lite_llm_role = content.get("role", "assistant")
parts = content.get("parts")
if isinstance(parts, list) and len(parts) > 0:
# Concatenate text from all parts (or just take the first?)
# For simplicity, let's concatenate text parts.
text_parts = [part.get("text", "") for part in parts if isinstance(part, dict) and "text" in part]
extracted_text = "".join(text_parts)
# Add other potential extraction paths if Gemini's format varies
# For example, sometimes streaming responses might be simpler:
# elif "text" in gemini_data:
# extracted_text = gemini_data["text"]
# lite_llm_role = "assistant" # Assume model role for simple text
if extracted_text:
# Construct the LiteLLM standard message format using 'parts'
lite_llm_message = {
"role": lite_llm_role,
"parts": [{"text": extracted_text}]
# Alternatively, if your client prefers 'content':
# "content": extracted_text
}
transformed_message_str = json.dumps(lite_llm_message)
else:
# Handle non-content messages (e.g., metadata, finish reasons)
# Option 1: Forward them raw if the client needs them
# transformed_message_str = message
# Option 2: Skip forwarding non-content messages
print(f"No text content extracted from Gemini message, skipping forward: {message}")
continue # Skip to next message
elif isinstance(message, bytes):
# If Gemini sends bytes, decide how to handle (e.g., forward raw)
print("Received bytes from Gemini, forwarding directly.")
if await self.is_websocket_open(self.websocket):
await self.websocket.send_bytes(message)
continue # Skip JSON transformation for bytes
else:
print(f"Received unexpected message type from Gemini: {type(message)}")
continue # Skip unknown types
except json.JSONDecodeError:
print(f"Failed to decode JSON from Gemini: {message}")
# Decide how to handle non-JSON messages (e.g., forward raw string)
if isinstance(message, str) and await self.is_websocket_open(self.websocket):
print("Forwarding non-JSON string message raw.")
await self.websocket.send_text(message)
continue # Skip processing if not JSON
except Exception as e:
print(f"Error processing/transforming message from Gemini: {e}")
# Decide how to handle errors (e.g., skip message)
continue # Skip message on transformation error
# Send the transformed message (if available) to the client
if transformed_message_str and await self.is_websocket_open(self.websocket):
print(f"Sending transformed to client: {transformed_message_str}")
await self.websocket.send_text(transformed_message_str)
except websockets.exceptions.ConnectionClosed as e:
print(f"Gemini backend connection closed: {e.code} {e.reason}")
if not self.websocket.closed:
await self.websocket.close(code=e.code, reason=e.reason)
self.backend_ws_open = False
if await self.is_websocket_open(self.websocket):
await self.safely_close_websocket(self.websocket, code=e.code, reason=e.reason)
except Exception as e:
print(f"Error receiving from Gemini backend or sending to client: {e}")
if not self.websocket.closed:
await self.websocket.close(code=1011, reason=f"Error forwarding message: {e}")
if not self.backend_ws.closed:
await self.backend_ws.close(code=1011, reason=f"Error forwarding message: {e}")
if await self.is_websocket_open(self.websocket):
await self.safely_close_websocket(self.websocket, code=1011, reason=f"Error forwarding message: {e}")
if await self.is_websocket_open(self.backend_ws):
await self.safely_close_websocket(self.backend_ws, code=1011, reason=f"Error forwarding message: {e}")
finally:
await self.log_messages()
# Log accumulated messages if needed (self.log_messages() might need adjustment
# if it relies on the stored messages being in a specific format)
# await self.log_messages()
print("Backend-to-client message forwarding stopped.")
async def client_to_backend_send_messages(self):
"""Forwards messages received from the client to the Gemini backend."""
try:
while True:
message = await self.websocket.receive_text()
self.store_input(message=message)
await self.backend_ws.send(message)
async def client_to_backend_send_messages(self):
"""Receives messages from the client (e.g., Twilio app), formats them
correctly for Gemini, and forwards to the Gemini backend."""
try:
while self.client_ws_open and self.backend_ws_open:
message_text = await self.websocket.receive_text()
print(f"Received from client: {message_text}")
# Parse the message to check if it has the correct format
try:
message_data = json.loads(message_text)
final_message_to_send = None # Store the final formatted message dict
# --- START AUDIO HANDLING ---
# Check if the message is audio input from the client (Twilio script)
if "audio" in message_data and "type" in message_data and message_data["type"] == "input_audio_buffer.append":
# Construct the message using the realtimeInput format
# based on the provided AudioInputMessage example.
audio_data_base64 = message_data["audio"]
# Determine MIME type based on Twilio setup.
# twilio_example_litellm.py sets "input_audio_format": "g711_ulaw".
# The standard MIME type for G.711 µ-law is audio/mulaw or audio/pcmu.
# It typically runs at 8000 Hz.
sample_rate = 8000
# Use the format from the example: audio/pcm;rate=... or audio/mulaw
# Let's try the specific mulaw type first.
# mime_type = f"audio/pcm;rate={sample_rate}" # As per example
mime_type = "audio/mulaw" # Standard for G.711 µ-law
# Structure according to the RealtimeInput/MediaChunk model example
# Use the exact field names 'mimeType' and 'mediaChunks'
media_chunk = {
"mimeType": mime_type,
"data": audio_data_base64
}
realtime_input_payload = {
"mediaChunks": [media_chunk]
}
final_message_to_send = {"realtimeInput": realtime_input_payload}
# --- END AUDIO HANDLING ---
# --- START OTHER MESSAGE HANDLING (Text, Tools, etc.) ---
# Handle text/history messages potentially coming in {"contents": [...]} format
elif "contents" in message_data and isinstance(message_data.get("contents"), list):
# Adapt the incoming 'contents' list (assumed to be turns) to the 'clientContent' format
content_turns = message_data["contents"]
valid_turns = []
for turn in content_turns:
if isinstance(turn, dict) and "role" in turn and "parts" in turn:
valid_turns.append(turn)
else:
print(f"WARNING: Skipping invalid turn structure in 'contents': {turn}")
if valid_turns:
is_complete = message_data.get("turn_complete", True)
final_message_to_send = {"clientContent": {"turns": valid_turns, "turn_complete": is_complete}}
else:
print(f"WARNING: No valid turns found in 'contents', cannot send message.")
continue
# For tool response messages, assume client sends correct format
elif "toolResponse" in message_data:
final_message_to_send = message_data # Pass through directly
# Handle potential cleanup of unsupported fields if message wasn't reformatted
elif final_message_to_send is None:
# If it wasn't audio, text, or tool response, check for and remove common unsupported fields
# before potentially forwarding (or deciding not to forward)
unsupported_fields = ["type", "session"]
cleaned_data = {k: v for k, v in message_data.items() if k not in unsupported_fields}
if cleaned_data != message_data:
print(f"WARNING: Removed unsupported fields. Result: {cleaned_data}")
# Decide if this cleaned_data is a valid Gemini message (e.g., setup?)
# For now, let's assume if it wasn't handled above, it's not valid to send.
# If you need to handle other types like 'setup', add specific elif blocks.
print(f"WARNING: Message from client is not a recognized/handled format: {message_data}")
continue # Skip sending unrecognized formats
# --- END OTHER MESSAGE HANDLING ---
# Convert the final formatted message back to string for sending
if final_message_to_send:
message_to_send_str = json.dumps(final_message_to_send)
else:
# Should not happen if logic is correct, but as a fallback
print(f"ERROR: final_message_to_send is None after processing.")
continue
except json.JSONDecodeError:
print(f"WARNING: Received non-JSON message from client, cannot process for Gemini: {message_text}")
continue # Skip non-JSON messages
except Exception as e:
print(f"ERROR: Failed processing client message: {e}")
continue # Skip sending this message on processing error
self.store_input(message=final_message_to_send)
print(f"Sending to Gemini backend: {message_to_send_str}")
await self.backend_ws.send(message_to_send_str)
except websockets.exceptions.ConnectionClosed as e:
print(f"Client connection closed: {e.code} {e.reason}")
if not self.backend_ws.closed:
await self.backend_ws.close(code=e.code, reason=e.reason)
self.client_ws_open = False
if await self.is_websocket_open(self.backend_ws):
await self.safely_close_websocket(self.backend_ws, code=e.code, reason=e.reason)
except Exception as e:
print(f"Error receiving from client or sending to Gemini backend: {e}")
if not self.websocket.closed:
await self.websocket.close(code=1011, reason=f"Error forwarding message: {e}")
if not self.backend_ws.closed:
await self.backend_ws.close(code=1011, reason=f"Error forwarding message: {e}")
if await self.is_websocket_open(self.websocket):
await self.safely_close_websocket(self.websocket, code=1011, reason=f"Error forwarding message: {e}")
if await self.is_websocket_open(self.backend_ws):
await self.safely_close_websocket(self.backend_ws, code=1011, reason=f"Error forwarding message: {e}")
finally:
print("Client-to-backend message forwarding stopped.")
async def bidirectional_forward(self):
"""Orchestrates the Gemini WebSocket session: setup and message forwarding."""
try:
@ -152,21 +378,21 @@ class GeminiRealTimeStreaming(RealTimeStreaming):
except websockets.exceptions.ConnectionClosed as e:
print(f"A connection closed unexpectedly during bidirectional forward setup or task management: {e}")
if not self.websocket.closed:
await self.websocket.close(code=e.code, reason=f"Peer connection closed: {e.reason}")
if not self.backend_ws.closed:
await self.backend_ws.close(code=e.code, reason=f"Peer connection closed: {e.reason}")
if await self.is_websocket_open(self.websocket):
await self.safely_close_websocket(self.websocket, code=e.code, reason=f"Peer connection closed: {e.reason}")
if await self.is_websocket_open(self.backend_ws):
await self.safely_close_websocket(self.backend_ws, code=e.code, reason=f"Peer connection closed: {e.reason}")
except Exception as e:
print(f"An unexpected error occurred in bidirectional_forward: {e}")
if not self.websocket.closed:
await self.websocket.close(code=1011, reason=f"Forwarding error: {e}")
if not self.backend_ws.closed:
await self.backend_ws.close(code=1011, reason=f"Forwarding error: {e}")
if await self.is_websocket_open(self.websocket):
await self.safely_close_websocket(self.websocket, code=1011, reason=f"Forwarding error: {e}")
if await self.is_websocket_open(self.backend_ws):
await self.safely_close_websocket(self.backend_ws, code=1011, reason=f"Forwarding error: {e}")
finally:
if not self.websocket.closed:
if await self.is_websocket_open(self.websocket):
print("Closing client websocket in finally block.")
await self.websocket.close()
if not self.backend_ws.closed:
await self.safely_close_websocket(self.websocket)
if await self.is_websocket_open(self.backend_ws):
print("Closing backend websocket in finally block.")
await self.backend_ws.close()
await self.safely_close_websocket(self.backend_ws)
print("bidirectional_forward cleanup complete.")

View file

@ -5,6 +5,7 @@ This requires websockets, and is currently only supported on LiteLLM Proxy.
"""
import json
from typing import Any, Optional
import os
from ....litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
from ....litellm_core_utils.gemini_realtime_streaming import GeminiRealTimeStreaming
@ -59,7 +60,6 @@ class GeminiLive(VertexLLM):
import websockets
except ImportError:
raise ImportError("Websockets package not installed. Please install it with `pip install websockets`")
if api_base is None:
raise ValueError("api_base is required for Gemini calls")
@ -98,7 +98,7 @@ class GeminiLive(VertexLLM):
url,
extra_headers=headers,
) as backend_ws:
await self._send_setup_message(backend_ws, model, config)
# await self._send_setup_message(backend_ws, model, config)
realtime_streaming = GeminiRealTimeStreaming(
websocket, backend_ws, model, config, logging_obj, vertex_location, vertex_project)
await realtime_streaming.bidirectional_forward()