Gemini Live Integration

This commit is contained in:
Ankur Duggal 2025-03-25 17:01:34 -07:00
parent e8c4cd8c1a
commit 877e8b0498
5 changed files with 255 additions and 2 deletions

View file

@ -10,10 +10,12 @@ from litellm.types.router import GenericLiteLLMParams
from ..litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
from ..llms.azure.realtime.handler import AzureOpenAIRealtime
from ..llms.openai.realtime.handler import OpenAIRealtime
from ..llms.vertex_ai.multimodal_live.handler import GeminiLive # Import the new handler
from ..utils import client as wrapper_client
azure_realtime = AzureOpenAIRealtime()
openai_realtime = OpenAIRealtime()
gemini_realtime = GeminiLive() # Instantiate the Gemini handler
@wrapper_client
@ -112,10 +114,28 @@ async def _arealtime(
client=None,
timeout=timeout,
)
elif _custom_llm_provider == "gemini" or _custom_llm_provider == "vertex_ai_beta" or _custom_llm_provider == "vertex_ai": # Add the Gemini case
api_base = (
dynamic_api_base
or litellm_params.api_base
or litellm.api_base
# default base for vertex
or get_secret_str("GEMINI_API_BASE") or "https://us-central1-aiplatform.googleapis.com"
)
await gemini_realtime.async_realtime(
model=model,
websocket=websocket,
api_base=api_base,
client=None,
timeout=timeout,
logging_obj=litellm_logging_obj,
vertex_location="us-central1", # Add default vertex location
optional_params={}, # Add empty optional params
custom_llm_provider=_custom_llm_provider, # Add custom llm provider
)
else:
raise ValueError(f"Unsupported model: {model}")
async def _realtime_health_check(
model: str,
custom_llm_provider: str,
@ -151,8 +171,16 @@ async def _realtime_health_check(
url = openai_realtime._construct_url(
api_base=api_base or "https://api.openai.com/", model=model
)
elif custom_llm_provider == "gemini": # Add Gemini case
url = gemini_realtime._construct_url(
api_base=api_base or "https://generativelanguage.googleapis.com",
)
else:
raise ValueError(f"Unsupported model: {model}")
if url is None:
raise ValueError("Failed to construct WebSocket URL")
async with websockets.connect( # type: ignore
url,
extra_headers={