Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 03:04:13 +00:00)
Merge 6d77f38bb1 into b82af5b826
This commit is contained in: commit ffb42e3973
4 changed files with 559 additions and 1 deletion
398	litellm/litellm_core_utils/gemini_realtime_streaming.py	(Normal file)

@@ -0,0 +1,398 @@
import base64
import json
import asyncio
from typing import Any, Optional, Dict, Union, List

import websockets

from litellm.litellm_core_utils.realtime_streaming import (
    RealTimeStreaming,
)
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging


class GeminiRealTimeStreaming(RealTimeStreaming):
    def __init__(
        self,
        websocket: Any,
        backend_ws: Any,
        model: str,
        config: dict,
        logging_obj: Optional[LiteLLMLogging] = None,
        vertex_location: Optional[str] = None,
        vertex_project: Optional[str] = None,
        system_instruction: Optional[Dict] = None,
        tools: Optional[list] = None,
    ):
        super().__init__(websocket, backend_ws, logging_obj)
        self.model_id = model
        self.config = config
        self.vertex_location = vertex_location
        self.vertex_project = vertex_project
        self.system_instruction = system_instruction
        self.tools = tools

        # Track connection state manually
        self.client_ws_open = True
        self.backend_ws_open = True

        if self.vertex_project and self.vertex_location:
            self.model_resource_name = f"projects/{self.vertex_project}/locations/{self.vertex_location}/publishers/google/models/{self.model_id}"
        else:
            self.model_resource_name = self.model_id
            print(f"Warning: vertex_project or vertex_location not provided. Using model_id directly: {self.model_resource_name}")

    async def send_initial_setup(self):
        """Sends the initial setup message required by the Gemini API."""
        setup_payload: Dict[str, Any] = {
            "model": self.model_resource_name,
        }
        if self.config:
            setup_payload["generation_config"] = self.config
        if self.system_instruction:
            setup_payload["system_instruction"] = self.system_instruction

        # Add tools to the setup payload if they exist
        if self.tools and len(self.tools) > 0:
            setup_payload["tools"] = self.tools

        setup_message = {"setup": setup_payload}

        print(f"Gemini Setup Message: {json.dumps(setup_message)}")
        await self.backend_ws.send(json.dumps(setup_message))
        print("Gemini setup message sent.")

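    # Illustrative sketch (not from the Google docs) of the JSON produced by
    # send_initial_setup(), assuming a hypothetical project/location and a config
    # like the one built in handler.py further below:
    #
    #   {
    #     "setup": {
    #       "model": "projects/my-project/locations/us-central1/publishers/google/models/gemini-2.0-flash-exp",
    #       "generation_config": {
    #         "response_modalities": ["AUDIO"],
    #         "speech_config": {"voice_config": {"prebuilt_voice_config": {"voice_name": "Aoede"}}}
    #       }
    #     }
    #   }
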
    async def wait_for_setup_complete(self):
        """Waits for the setupComplete message from the Gemini backend."""
        try:
            response = await self.backend_ws.recv()
            print(f"Setup response: {response}")
            # Parse response to check if it's a valid setup completion
            if isinstance(response, str):
                response_data = json.loads(response)
                if "setupComplete" not in response_data:
                    print(f"WARNING: Unexpected setup response format: {response}")
            return True
        except websockets.exceptions.ConnectionClosed as e:
            print(f"Connection closed while waiting for Gemini setup complete: {e}")
            await self.safely_close_websocket(self.websocket, code=e.code, reason=f"Backend connection closed during setup: {e.reason}")
            self.backend_ws_open = False
            return False
        except json.JSONDecodeError as e:
            print(f"Failed to decode JSON during setup: {e}")
            await self.safely_close_websocket(self.websocket, code=1011, reason="Invalid JSON received during setup")
            await self.safely_close_websocket(self.backend_ws, code=1011, reason="Invalid JSON received during setup")
            self.client_ws_open = False
            self.backend_ws_open = False
            return False
        except Exception as e:
            print(f"An unexpected error occurred during Gemini setup: {e}")
            await self.safely_close_websocket(self.websocket, code=1011, reason=f"Unexpected setup error: {e}")
            await self.safely_close_websocket(self.backend_ws, code=1011, reason=f"Unexpected setup error: {e}")
            self.client_ws_open = False
            self.backend_ws_open = False
            return False

    async def safely_close_websocket(self, ws, code=1000, reason="Normal closure"):
        """Safely close a websocket without relying on .closed attribute"""
        try:
            await ws.close(code=code, reason=reason)
        except Exception as e:
            print(f"Error closing websocket: {e}")
        finally:
            if ws == self.websocket:
                self.client_ws_open = False
            elif ws == self.backend_ws:
                self.backend_ws_open = False

    async def is_websocket_open(self, ws):
        """Check if websocket is open by trying a simple operation"""
        try:
            # For some websocket implementations, we can use an attribute
            if hasattr(ws, 'closed'):
                return not ws.closed

            # For others, a state check might work
            if hasattr(ws, 'state') and hasattr(ws.state, 'name'):
                return ws.state.name == 'OPEN'

            # Default: assume it's open if it's in our tracking variables
            return (ws == self.websocket and self.client_ws_open) or \
                (ws == self.backend_ws and self.backend_ws_open)
        except Exception:
            # If we can't determine, assume it's closed for safety
            return False

    async def backend_to_client_send_messages(self):
        """Receives messages from Gemini, transforms them to LiteLLM format, and forwards to the client."""
        try:
            while self.backend_ws_open and self.client_ws_open:
                message = await self.backend_ws.recv()
                # Log the raw message received from Gemini for debugging
                print(f"Received raw from Gemini backend: {message}")

                # Store the original raw message for logging purposes
                self.store_message(message)

                transformed_message_str = None
                try:
                    if isinstance(message, str):
                        gemini_data = json.loads(message)

                        # --- Transformation Logic ---
                        # Assume Gemini response structure (adjust if different)
                        # Example: {"candidates": [{"content": {"role": "model", "parts": [{"text": "..."}]}}]}
                        extracted_text = ""
                        lite_llm_role = "assistant"  # Default LiteLLM role for model output

                        candidates = gemini_data.get("candidates")
                        if isinstance(candidates, list) and len(candidates) > 0:
                            # Process the first candidate
                            candidate = candidates[0]
                            content = candidate.get("content")
                            if isinstance(content, dict):
                                # Map Gemini's 'model' role to LiteLLM's 'assistant' role
                                if content.get("role") == "model":
                                    lite_llm_role = "assistant"
                                else:
                                    # Handle other potential roles if needed, or keep default
                                    lite_llm_role = content.get("role", "assistant")

                                parts = content.get("parts")
                                if isinstance(parts, list) and len(parts) > 0:
                                    # Concatenate text from all parts (or just take the first?)
                                    # For simplicity, concatenate the text parts.
                                    text_parts = [part.get("text", "") for part in parts if isinstance(part, dict) and "text" in part]
                                    extracted_text = "".join(text_parts)

                        # Add other potential extraction paths if Gemini's format varies
                        # For example, sometimes streaming responses might be simpler:
                        # elif "text" in gemini_data:
                        #     extracted_text = gemini_data["text"]
                        #     lite_llm_role = "assistant"  # Assume model role for simple text

                        if extracted_text:
                            # Construct the LiteLLM standard message format using 'parts'
                            lite_llm_message = {
                                "role": lite_llm_role,
                                "parts": [{"text": extracted_text}]
                                # Alternatively, if your client prefers 'content':
                                # "content": extracted_text
                            }
                            transformed_message_str = json.dumps(lite_llm_message)
                        else:
                            # Handle non-content messages (e.g., metadata, finish reasons)
                            # Option 1: Forward them raw if the client needs them
                            # transformed_message_str = message
                            # Option 2: Skip forwarding non-content messages
                            print(f"No text content extracted from Gemini message, skipping forward: {message}")
                            continue  # Skip to next message

                    elif isinstance(message, bytes):
                        # If Gemini sends bytes, decide how to handle (e.g., forward raw)
                        print("Received bytes from Gemini, forwarding directly.")
                        if await self.is_websocket_open(self.websocket):
                            await self.websocket.send_bytes(message)
                        continue  # Skip JSON transformation for bytes

                    else:
                        print(f"Received unexpected message type from Gemini: {type(message)}")
                        continue  # Skip unknown types

                except json.JSONDecodeError:
                    print(f"Failed to decode JSON from Gemini: {message}")
                    # Decide how to handle non-JSON messages (e.g., forward raw string)
                    if isinstance(message, str) and await self.is_websocket_open(self.websocket):
                        print("Forwarding non-JSON string message raw.")
                        await self.websocket.send_text(message)
                    continue  # Skip processing if not JSON
                except Exception as e:
                    print(f"Error processing/transforming message from Gemini: {e}")
                    # Decide how to handle errors (e.g., skip message)
                    continue  # Skip message on transformation error

                # Send the transformed message (if available) to the client
                if transformed_message_str and await self.is_websocket_open(self.websocket):
                    print(f"Sending transformed to client: {transformed_message_str}")
                    await self.websocket.send_text(transformed_message_str)

        except websockets.exceptions.ConnectionClosed as e:
            print(f"Gemini backend connection closed: {e.code} {e.reason}")
            self.backend_ws_open = False
            if await self.is_websocket_open(self.websocket):
                await self.safely_close_websocket(self.websocket, code=e.code, reason=e.reason)
        except Exception as e:
            print(f"Error receiving from Gemini backend or sending to client: {e}")
            if await self.is_websocket_open(self.websocket):
                await self.safely_close_websocket(self.websocket, code=1011, reason=f"Error forwarding message: {e}")
            if await self.is_websocket_open(self.backend_ws):
                await self.safely_close_websocket(self.backend_ws, code=1011, reason=f"Error forwarding message: {e}")
        finally:
            # Log accumulated messages if needed (self.log_messages() might need adjustment
            # if it relies on the stored messages being in a specific format)
            # await self.log_messages()
            print("Backend-to-client message forwarding stopped.")

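    # Illustrative example of the transformation above (message contents are
    # hypothetical): a backend chunk such as
    #   {"candidates": [{"content": {"role": "model", "parts": [{"text": "Hello"}]}}]}
    # is forwarded to the client as
    #   {"role": "assistant", "parts": [{"text": "Hello"}]}
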
    async def client_to_backend_send_messages(self):
        """Receives messages from the client (e.g., Twilio app), formats them
        correctly for Gemini, and forwards to the Gemini backend."""
        try:
            while self.client_ws_open and self.backend_ws_open:
                message_text = await self.websocket.receive_text()
                print(f"Received from client: {message_text}")

                # Parse the message to check if it has the correct format
                try:
                    message_data = json.loads(message_text)
                    final_message_to_send = None  # Store the final formatted message dict

                    # --- START AUDIO HANDLING ---
                    # Check if the message is audio input from the client (Twilio script)
                    if "audio" in message_data and "type" in message_data and message_data["type"] == "input_audio_buffer.append":
                        # Construct the message using the realtimeInput format
                        # based on the provided AudioInputMessage example.

                        audio_data_base64 = message_data["audio"]

                        # Determine MIME type based on Twilio setup.
                        # twilio_example_litellm.py sets "input_audio_format": "g711_ulaw".
                        # The standard MIME type for G.711 µ-law is audio/mulaw or audio/pcmu.
                        # It typically runs at 8000 Hz.
                        sample_rate = 8000
                        # Use the format from the example: audio/pcm;rate=... or audio/mulaw
                        # Try the specific mulaw type first.
                        # mime_type = f"audio/pcm;rate={sample_rate}"  # As per example
                        mime_type = "audio/mulaw"  # Standard for G.711 µ-law

                        # Structure according to the RealtimeInput/MediaChunk model example
                        # Use the exact field names 'mimeType' and 'mediaChunks'
                        media_chunk = {
                            "mimeType": mime_type,
                            "data": audio_data_base64
                        }
                        realtime_input_payload = {
                            "mediaChunks": [media_chunk]
                        }
                        final_message_to_send = {"realtimeInput": realtime_input_payload}
                    # --- END AUDIO HANDLING ---

                    # --- START OTHER MESSAGE HANDLING (Text, Tools, etc.) ---
                    # Handle text/history messages potentially coming in {"contents": [...]} format
                    elif "contents" in message_data and isinstance(message_data.get("contents"), list):
                        # Adapt the incoming 'contents' list (assumed to be turns) to the 'clientContent' format
                        content_turns = message_data["contents"]
                        valid_turns = []
                        for turn in content_turns:
                            if isinstance(turn, dict) and "role" in turn and "parts" in turn:
                                valid_turns.append(turn)
                            else:
                                print(f"WARNING: Skipping invalid turn structure in 'contents': {turn}")

                        if valid_turns:
                            is_complete = message_data.get("turn_complete", True)
                            final_message_to_send = {"clientContent": {"turns": valid_turns, "turn_complete": is_complete}}
                        else:
                            print("WARNING: No valid turns found in 'contents', cannot send message.")
                            continue

                    # For tool response messages, assume client sends correct format
                    elif "toolResponse" in message_data:
                        final_message_to_send = message_data  # Pass through directly

                    # Handle potential cleanup of unsupported fields if message wasn't reformatted
                    elif final_message_to_send is None:
                        # If it wasn't audio, text, or tool response, check for and remove common unsupported fields
                        # before potentially forwarding (or deciding not to forward)
                        unsupported_fields = ["type", "session"]
                        cleaned_data = {k: v for k, v in message_data.items() if k not in unsupported_fields}
                        if cleaned_data != message_data:
                            print(f"WARNING: Removed unsupported fields. Result: {cleaned_data}")
                        # Decide if this cleaned_data is a valid Gemini message (e.g., setup?)
                        # For now, assume that if it wasn't handled above, it's not valid to send.
                        # If you need to handle other types like 'setup', add specific elif blocks.
                        print(f"WARNING: Message from client is not a recognized/handled format: {message_data}")
                        continue  # Skip sending unrecognized formats
                    # --- END OTHER MESSAGE HANDLING ---

                    # Convert the final formatted message back to string for sending
                    if final_message_to_send:
                        message_to_send_str = json.dumps(final_message_to_send)
                    else:
                        # Should not happen if logic is correct, but as a fallback
                        print("ERROR: final_message_to_send is None after processing.")
                        continue

                except json.JSONDecodeError:
                    print(f"WARNING: Received non-JSON message from client, cannot process for Gemini: {message_text}")
                    continue  # Skip non-JSON messages
                except Exception as e:
                    print(f"ERROR: Failed processing client message: {e}")
                    continue  # Skip sending this message on processing error

                self.store_input(message=final_message_to_send)
                print(f"Sending to Gemini backend: {message_to_send_str}")
                await self.backend_ws.send(message_to_send_str)

        except websockets.exceptions.ConnectionClosed as e:
            print(f"Client connection closed: {e.code} {e.reason}")
            self.client_ws_open = False
            if await self.is_websocket_open(self.backend_ws):
                await self.safely_close_websocket(self.backend_ws, code=e.code, reason=e.reason)
        except Exception as e:
            print(f"Error receiving from client or sending to Gemini backend: {e}")
            if await self.is_websocket_open(self.websocket):
                await self.safely_close_websocket(self.websocket, code=1011, reason=f"Error forwarding message: {e}")
            if await self.is_websocket_open(self.backend_ws):
                await self.safely_close_websocket(self.backend_ws, code=1011, reason=f"Error forwarding message: {e}")
        finally:
            print("Client-to-backend message forwarding stopped.")

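    # Illustrative example of the audio path above (payload values are hypothetical):
    # a client frame such as
    #   {"type": "input_audio_buffer.append", "audio": "<base64 G.711 mulaw>"}
    # is rewritten for the Gemini backend as
    #   {"realtimeInput": {"mediaChunks": [{"mimeType": "audio/mulaw", "data": "<base64 G.711 mulaw>"}]}}
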
    async def bidirectional_forward(self):
        """Orchestrates the Gemini WebSocket session: setup and message forwarding."""
        try:
            await self.send_initial_setup()

            setup_ok = await self.wait_for_setup_complete()
            if not setup_ok:
                print("Gemini setup failed. Aborting bidirectional forward.")
                return

            print("Gemini setup successful. Starting bidirectional message forwarding.")

            client_to_backend_task = asyncio.create_task(self.client_to_backend_send_messages())
            backend_to_client_task = asyncio.create_task(self.backend_to_client_send_messages())

            done, pending = await asyncio.wait(
                [client_to_backend_task, backend_to_client_task],
                return_when=asyncio.FIRST_COMPLETED,
            )

            for task in pending:
                task.cancel()
                try:
                    await task
                except asyncio.CancelledError:
                    pass

            print("Bidirectional forwarding finished.")

        except websockets.exceptions.ConnectionClosed as e:
            print(f"A connection closed unexpectedly during bidirectional forward setup or task management: {e}")
            if await self.is_websocket_open(self.websocket):
                await self.safely_close_websocket(self.websocket, code=e.code, reason=f"Peer connection closed: {e.reason}")
            if await self.is_websocket_open(self.backend_ws):
                await self.safely_close_websocket(self.backend_ws, code=e.code, reason=f"Peer connection closed: {e.reason}")
        except Exception as e:
            print(f"An unexpected error occurred in bidirectional_forward: {e}")
            if await self.is_websocket_open(self.websocket):
                await self.safely_close_websocket(self.websocket, code=1011, reason=f"Forwarding error: {e}")
            if await self.is_websocket_open(self.backend_ws):
                await self.safely_close_websocket(self.backend_ws, code=1011, reason=f"Forwarding error: {e}")
        finally:
            if await self.is_websocket_open(self.websocket):
                print("Closing client websocket in finally block.")
                await self.safely_close_websocket(self.websocket)
            if await self.is_websocket_open(self.backend_ws):
                print("Closing backend websocket in finally block.")
                await self.safely_close_websocket(self.backend_ws)
            print("bidirectional_forward cleanup complete.")
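
Usage sketch (illustrative, not part of this diff): the Vertex AI handler below drives this class; a minimal stand-alone wiring, assuming an already-accepted client WebSocket `client_ws` and a valid wss URL plus auth headers for the BidiGenerateContent endpoint, might look like this:

import websockets
from litellm.litellm_core_utils.gemini_realtime_streaming import GeminiRealTimeStreaming

async def forward_session(client_ws, url: str, headers: dict):
    # Open the backend connection and hand both sockets to the streaming helper,
    # which sends the setup message and then forwards traffic in both directions.
    # Depending on the installed `websockets` version, this kwarg may instead be
    # named `additional_headers`.
    async with websockets.connect(url, extra_headers=headers) as backend_ws:
        streaming = GeminiRealTimeStreaming(
            websocket=client_ws,
            backend_ws=backend_ws,
            model="gemini-2.0-flash-exp",      # hypothetical model id
            config={"response_modalities": ["AUDIO"]},
            vertex_project="my-project",       # hypothetical project
            vertex_location="us-central1",     # hypothetical location
        )
        await streaming.bidirectional_forward()
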
125	litellm/llms/vertex_ai/multimodal_live/handler.py	(Normal file)

@@ -0,0 +1,125 @@
"""
|
||||||
|
This file contains the calling Azure OpenAI's `/openai/realtime` endpoint.
|
||||||
|
|
||||||
|
This requires websockets, and is currently only supported on LiteLLM Proxy.
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
from typing import Any, Optional
|
||||||
|
import os
|
||||||
|
|
||||||
|
from ....litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
|
||||||
|
from ....litellm_core_utils.gemini_realtime_streaming import GeminiRealTimeStreaming
|
||||||
|
from ..gemini.vertex_and_google_ai_studio_gemini import VertexLLM, VertexGeminiConfig
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
|
||||||
|
class GeminiLive(VertexLLM):

    def __init__(self) -> None:
        super().__init__()

    def _construct_url(self, api_base: str) -> str:
        """
        Returns the websocket URL for the Gemini BidiGenerateContent service.

        Example output:
        "wss://us-central1-aiplatform.googleapis.com/ws/google.cloud.aiplatform.v1beta1.LlmBidiService/BidiGenerateContent"
        """
        api_base = api_base.replace("https://", "wss://")
        api_base = api_base.replace("http://", "ws://")
        return f"{api_base}/ws/google.cloud.aiplatform.v1beta1.LlmBidiService/BidiGenerateContent"

    async def _send_setup_message(self, ws: Any, model: str, config: dict):
        """
        Sends the initial setup message required by the Gemini realtime endpoint.
        """
        setup_payload = {
            "setup": {
                "model": model,
                "generation_config": config,
            }
        }
        await ws.send(json.dumps(setup_payload))

    async def async_realtime(
        self,
        model: str,
        websocket: Any,
        logging_obj: LiteLLMLogging,
        vertex_location: Optional[str],
        optional_params: dict,
        vertex_credentials_path: str,
        custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
        api_base: Optional[str] = None,
        client: Optional[Any] = None,
        timeout: Optional[float] = None,
        voice_name: Optional[str] = "Aoede",
        vertex_project: Optional[str] = None,
        extra_headers: Optional[dict] = None,
    ):
        try:
            import websockets
        except ImportError:
            raise ImportError("Websockets package not installed. Please install it with `pip install websockets`")
        if api_base is None:
            raise ValueError("api_base is required for Gemini calls")

        url = self._construct_url(api_base)

        config = {
            "response_modalities": ["AUDIO"],
            "speech_config": {
                "voice_config": {"prebuilt_voice_config": {"voice_name": voice_name}}
            },
        }

        vertex_location = self.get_vertex_region(vertex_region=vertex_location)

        try:
            with open(vertex_credentials_path, 'r') as f:
                vertex_credentials = json.load(f)
        except Exception as e:
            raise Exception(f"Failed to load credentials: {str(e)}")

        _auth_header, vertex_project = await self._ensure_access_token_async(
            credentials=vertex_credentials,
            project_id=vertex_project,
            custom_llm_provider=custom_llm_provider,
        )

        headers = VertexGeminiConfig().validate_environment(
            api_key=_auth_header,
            headers=extra_headers,
            model=model,
            messages=[],
            optional_params=optional_params,
        )
        try:
            async with websockets.connect(  # type: ignore
                url,
                extra_headers=headers,
            ) as backend_ws:
                # await self._send_setup_message(backend_ws, model, config)
                realtime_streaming = GeminiRealTimeStreaming(
                    websocket, backend_ws, model, config, logging_obj, vertex_location, vertex_project)
                await realtime_streaming.bidirectional_forward()

        except websockets.exceptions.InvalidStatusCode as e:  # type: ignore
            await websocket.close(code=e.status_code, reason=str(e))
        except Exception as e:
            try:
                await websocket.close(
                    code=1011, reason=f"Internal server error: {str(e)}"
                )
            except RuntimeError as close_error:
                if "already completed" in str(close_error) or "websocket.close" in str(
                    close_error
                ):
                    # The WebSocket is already closed or the response is completed, so we can
                    # ignore this error
                    pass
                else:
                    # If it's a different RuntimeError, we might want to log it or handle it
                    # differently
                    raise Exception(
                        f"Unexpected error while closing WebSocket: {close_error}"
                    )
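
Invocation sketch (illustrative, not part of this diff): the routing change below calls this handler roughly as follows; every literal value is a placeholder, and `client_ws` / `litellm_logging_obj` are assumed to already exist:

handler = GeminiLive()
await handler.async_realtime(
    model="gemini-2.0-flash-exp",                      # hypothetical model id
    websocket=client_ws,                               # accepted client WebSocket
    logging_obj=litellm_logging_obj,                   # LiteLLMLogging instance
    vertex_location="us-central1",
    optional_params={},
    vertex_credentials_path="/path/to/service_account.json",
    custom_llm_provider="vertex_ai",
    api_base="https://us-central1-aiplatform.googleapis.com",
    vertex_project="my-project",
)
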
@@ -11,10 +11,12 @@ from ..litellm_core_utils.get_litellm_params import get_litellm_params
from ..litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
from ..llms.azure.realtime.handler import AzureOpenAIRealtime
from ..llms.openai.realtime.handler import OpenAIRealtime
from ..llms.vertex_ai.multimodal_live.handler import GeminiLive  # Import the new handler
from ..utils import client as wrapper_client

azure_realtime = AzureOpenAIRealtime()
openai_realtime = OpenAIRealtime()
gemini_realtime = GeminiLive()  # Instantiate the Gemini handler


@wrapper_client
@@ -104,6 +106,31 @@ async def _arealtime(
            client=None,
            timeout=timeout,
        )
    elif _custom_llm_provider == "vertex_ai_beta" or _custom_llm_provider == "vertex_ai":  # Add the Gemini case
        api_base = (
            dynamic_api_base
            or litellm_params.api_base
            or litellm.api_base
            or get_secret_str("GEMINI_API_BASE")
            # default base for Vertex AI
            or "https://us-central1-aiplatform.googleapis.com"
        )

        try:
            await gemini_realtime.async_realtime(
                model=model,
                websocket=websocket,
                api_base=api_base,
                timeout=timeout,
                optional_params={},
                logging_obj=litellm_logging_obj,
                vertex_location=litellm_params.vertex_location,  # Add default vertex location
                vertex_credentials_path=str(litellm_params.vertex_credentials),  # Add default vertex credentials
                vertex_project=litellm_params.vertex_project,  # Add default vertex project
                custom_llm_provider=_custom_llm_provider,  # Add custom llm provider
            )
        except Exception as e:
            raise ValueError(f"Failed to connect to Gemini realtime API: {e}")
    else:
        raise ValueError(f"Unsupported model: {model}")

@@ -143,8 +170,16 @@ async def _realtime_health_check(
        url = openai_realtime._construct_url(
            api_base=api_base or "https://api.openai.com/", model=model
        )
    elif custom_llm_provider == "gemini":  # Add Gemini case
        url = gemini_realtime._construct_url(
            api_base=api_base or "https://generativelanguage.googleapis.com",
        )
    else:
        raise ValueError(f"Unsupported model: {model}")

    if url is None:
        raise ValueError("Failed to construct WebSocket URL")

    async with websockets.connect(  # type: ignore
        url,
        extra_headers={
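
For reference (derived from the `GeminiLive._construct_url` logic above, not stated in the diff itself), the default API bases used in this change resolve to the following WebSocket URLs:

# _construct_url("https://us-central1-aiplatform.googleapis.com") ->
#   wss://us-central1-aiplatform.googleapis.com/ws/google.cloud.aiplatform.v1beta1.LlmBidiService/BidiGenerateContent
# _construct_url("https://generativelanguage.googleapis.com") ->
#   wss://generativelanguage.googleapis.com/ws/google.cloud.aiplatform.v1beta1.LlmBidiService/BidiGenerateContent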
|
Loading…
Add table
Add a link
Reference in a new issue