This commit is contained in:
Ankur Duggal 2025-04-24 00:56:58 -07:00 committed by GitHub
commit ffb42e3973
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 559 additions and 1 deletions

View file

@ -11,10 +11,12 @@ from ..litellm_core_utils.get_litellm_params import get_litellm_params
from ..litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
from ..llms.azure.realtime.handler import AzureOpenAIRealtime
from ..llms.openai.realtime.handler import OpenAIRealtime
from ..llms.vertex_ai.multimodal_live.handler import GeminiLive # Import the new handler
from ..utils import client as wrapper_client
azure_realtime = AzureOpenAIRealtime()
openai_realtime = OpenAIRealtime()
gemini_realtime = GeminiLive() # Instantiate the Gemini handler
@wrapper_client
@ -104,6 +106,31 @@ async def _arealtime(
client=None,
timeout=timeout,
)
elif _custom_llm_provider == "vertex_ai_beta" or _custom_llm_provider == "vertex_ai": # Add the Gemini case
api_base = (
dynamic_api_base
or litellm_params.api_base
or litellm.api_base
or get_secret_str("GEMINI_API_BASE")
# default base for vertexs
or "https://us-central1-aiplatform.googleapis.com"
)
try:
await gemini_realtime.async_realtime(
model=model,
websocket=websocket,
api_base=api_base,
timeout=timeout,
optional_params={},
logging_obj=litellm_logging_obj,
vertex_location=litellm_params.vertex_location, # Add default vertex location
vertex_credentials_path=str(litellm_params.vertex_credentials), # Add default vertex credentials
vertex_project=litellm_params.vertex_project, # Add default vertex project
custom_llm_provider=_custom_llm_provider, # Add custom llm provider
)
except Exception as e:
raise ValueError(f"Failed to connect to Gemini realtime API: {e}")
else:
raise ValueError(f"Unsupported model: {model}")
@ -143,8 +170,16 @@ async def _realtime_health_check(
url = openai_realtime._construct_url(
api_base=api_base or "https://api.openai.com/", model=model
)
elif custom_llm_provider == "gemini": # Add Gemini case
url = gemini_realtime._construct_url(
api_base=api_base or "https://generativelanguage.googleapis.com",
)
else:
raise ValueError(f"Unsupported model: {model}")
if url is None:
raise ValueError("Failed to construct WebSocket URL")
async with websockets.connect( # type: ignore
url,
extra_headers={