# Mirror of https://github.com/BerriAI/litellm.git
# (synced 2025-04-25 02:34:29 +00:00; 189 lines, 6 KiB, Python)
"""Abstraction function for OpenAI's realtime API"""
|
|
|
|
from typing import Any, Optional
|
|
|
|
import litellm
|
|
from litellm import get_llm_provider
|
|
from litellm.secret_managers.main import get_secret_str
|
|
from litellm.types.router import GenericLiteLLMParams
|
|
|
|
from ..litellm_core_utils.get_litellm_params import get_litellm_params
|
|
from ..litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
|
|
from ..llms.azure.realtime.handler import AzureOpenAIRealtime
|
|
from ..llms.openai.realtime.handler import OpenAIRealtime
|
|
from ..llms.vertex_ai.multimodal_live.handler import GeminiLive # Import the new handler
|
|
from ..utils import client as wrapper_client
|
|
|
|
# Provider handler singletons, created once at import time and reused by
# `_arealtime` and `_realtime_health_check` below.
azure_realtime: AzureOpenAIRealtime = AzureOpenAIRealtime()
openai_realtime: OpenAIRealtime = OpenAIRealtime()
gemini_realtime: GeminiLive = GeminiLive()  # Gemini / Vertex AI live handler
|
|
|
|
|
|
@wrapper_client
async def _arealtime(
    model: str,
    websocket: Any,  # fastapi websocket
    api_base: Optional[str] = None,
    api_key: Optional[str] = None,
    api_version: Optional[str] = None,
    azure_ad_token: Optional[str] = None,
    client: Optional[Any] = None,
    timeout: Optional[float] = None,
    **kwargs,
):
    """
    Private function to handle the realtime API call.

    For PROXY use only.

    Resolves the provider for ``model`` via ``get_llm_provider`` and hands the
    websocket session to the matching handler: Azure OpenAI, OpenAI, or the
    Gemini live handler (used for ``vertex_ai`` / ``vertex_ai_beta``).

    Args:
        model: model name, optionally provider-prefixed.
        websocket: the client-side websocket (a FastAPI websocket in the proxy).
        api_base: explicit API base. It is used when ``get_llm_provider`` does
            not return one, and takes priority over global/env defaults.
        api_key: explicit API key for Azure/OpenAI. It is used when
            ``get_llm_provider`` does not return a dynamic key, and takes
            priority over global/env defaults.
        api_version: Azure API version; defaults to ``"2024-10-01-preview"``.
        azure_ad_token: forwarded to the Azure handler.
        client: accepted for interface compatibility. Handlers are always
            called with ``client=None``.
        timeout: forwarded to the provider handler.
        **kwargs: must include ``litellm_logging_obj`` (its
            ``update_environment_variables`` is called unconditionally).
            The full kwargs are also parsed by ``GenericLiteLLMParams`` and
            ``get_litellm_params``.

    Raises:
        ValueError: for an unsupported provider, or when the Gemini/Vertex
            handler fails (the original exception is chained as the cause).
    """
    litellm_logging_obj: LiteLLMLogging = kwargs.get("litellm_logging_obj")  # type: ignore
    user = kwargs.get("user", None)
    litellm_params = GenericLiteLLMParams(**kwargs)

    litellm_params_dict = get_litellm_params(**kwargs)

    model, _custom_llm_provider, dynamic_api_key, dynamic_api_base = get_llm_provider(
        model=model,
        api_base=api_base,
        api_key=api_key,
    )

    litellm_logging_obj.update_environment_variables(
        model=model,
        user=user,
        optional_params={},
        litellm_params=litellm_params_dict,
        custom_llm_provider=_custom_llm_provider,
    )

    if _custom_llm_provider == "azure":
        api_base = (
            dynamic_api_base
            or api_base
            or litellm_params.api_base
            or litellm.api_base
            or get_secret_str("AZURE_API_BASE")
        )
        # set API KEY. `get_llm_provider` does not necessarily echo the
        # caller's key back as `dynamic_api_key`, so the explicit `api_key`
        # must be consulted before falling back to global/env defaults.
        api_key = (
            dynamic_api_key
            or api_key
            or litellm.api_key
            or litellm.openai_key
            or get_secret_str("AZURE_API_KEY")
        )

        await azure_realtime.async_realtime(
            model=model,
            websocket=websocket,
            api_base=api_base,
            api_key=api_key,
            # Honour a caller-supplied version; keep the previous default.
            api_version=api_version or "2024-10-01-preview",
            azure_ad_token=azure_ad_token,
            client=None,
            timeout=timeout,
            logging_obj=litellm_logging_obj,
        )
    elif _custom_llm_provider == "openai":
        api_base = (
            dynamic_api_base
            or api_base
            or litellm_params.api_base
            or litellm.api_base
            or "https://api.openai.com/"
        )
        # set API KEY (see the Azure branch for why `api_key` is consulted).
        api_key = (
            dynamic_api_key
            or api_key
            or litellm.api_key
            or litellm.openai_key
            or get_secret_str("OPENAI_API_KEY")
        )

        await openai_realtime.async_realtime(
            model=model,
            websocket=websocket,
            logging_obj=litellm_logging_obj,
            api_base=api_base,
            api_key=api_key,
            client=None,
            timeout=timeout,
        )
    elif _custom_llm_provider in ("vertex_ai_beta", "vertex_ai"):
        api_base = (
            dynamic_api_base
            or api_base
            or litellm_params.api_base
            or litellm.api_base
            or get_secret_str("GEMINI_API_BASE")
            # default base for Vertex AI
            or "https://us-central1-aiplatform.googleapis.com"
        )

        # `str(None)` would produce the bogus path "None"; pass None through
        # so the handler can tell that no credentials were configured.
        # NOTE(review): if vertex_credentials can be a dict, str() yields a
        # Python repr rather than a path/JSON - confirm what GeminiLive expects.
        vertex_credentials = litellm_params.vertex_credentials
        vertex_credentials_path = (
            str(vertex_credentials) if vertex_credentials is not None else None
        )

        # NOTE(review): no api_key is forwarded to the Gemini handler; auth
        # appears to rely on the vertex_* params - confirm.
        try:
            await gemini_realtime.async_realtime(
                model=model,
                websocket=websocket,
                api_base=api_base,
                timeout=timeout,
                optional_params={},
                logging_obj=litellm_logging_obj,
                vertex_location=litellm_params.vertex_location,
                vertex_credentials_path=vertex_credentials_path,
                vertex_project=litellm_params.vertex_project,
                custom_llm_provider=_custom_llm_provider,
            )
        except Exception as e:
            raise ValueError(f"Failed to connect to Gemini realtime API: {e}") from e
    else:
        raise ValueError(f"Unsupported model: {model}")
|
|
|
|
|
|
def _build_realtime_health_check_headers(
    custom_llm_provider: str, api_key: Optional[str]
) -> dict:
    """
    Build the websocket handshake headers for `_realtime_health_check`.

    OpenAI's realtime endpoint authenticates with ``Authorization: Bearer <key>``
    and requires ``OpenAI-Beta: realtime=v1``. The other providers handled here
    (Azure, Gemini) get the key in an ``api-key`` header. No auth header is
    added when ``api_key`` is None or empty, rather than sending a null value.
    """
    if custom_llm_provider == "openai":
        headers = {"OpenAI-Beta": "realtime=v1"}
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"
        return headers
    return {"api-key": api_key} if api_key else {}


async def _realtime_health_check(
    model: str,
    custom_llm_provider: str,
    api_key: Optional[str],
    api_base: Optional[str] = None,
    api_version: Optional[str] = None,
):
    """
    Health check for realtime API - tries connection to the realtime API websocket

    Args:
        model: str - model name
        api_base: str - api base
        api_version: Optional[str] - api version (Azure only; defaults to
            "2024-10-01-preview")
        api_key: str - api key
        custom_llm_provider: str - one of "azure", "openai", "gemini"

    Returns:
        bool - True once the websocket handshake succeeds. This function never
        returns False: any failure propagates as an exception.

    Raises:
        ValueError - if the provider is unsupported or no URL could be built
        Exception - whatever the websocket client raises if the connection fails
    """
    url: Optional[str] = None
    if custom_llm_provider == "azure":
        url = azure_realtime._construct_url(
            api_base=api_base or "",
            model=model,
            api_version=api_version or "2024-10-01-preview",
        )
    elif custom_llm_provider == "openai":
        url = openai_realtime._construct_url(
            api_base=api_base or "https://api.openai.com/", model=model
        )
    elif custom_llm_provider == "gemini":
        url = gemini_realtime._construct_url(
            api_base=api_base or "https://generativelanguage.googleapis.com",
        )
    else:
        raise ValueError(f"Unsupported model: {model}")

    if url is None:
        raise ValueError("Failed to construct WebSocket URL")

    # Imported lazily: only needed once we actually open a connection.
    import websockets

    # NOTE(review): `extra_headers` is the legacy websockets client kwarg
    # (newer clients use `additional_headers`) - confirm against the pinned
    # websockets version.
    async with websockets.connect(  # type: ignore
        url,
        extra_headers=_build_realtime_health_check_headers(
            custom_llm_provider=custom_llm_provider, api_key=api_key
        ),
    ):
        return True
|