diff --git a/docs/my-website/docs/proxy/health.md b/docs/my-website/docs/proxy/health.md index 585e2ff750..0da4716dcb 100644 --- a/docs/my-website/docs/proxy/health.md +++ b/docs/my-website/docs/proxy/health.md @@ -168,6 +168,20 @@ Expected Response } ``` +### Realtime Models + +To run realtime health checks, specify the mode as "realtime" in your config for the relevant model. + +```yaml +model_list: + - model_name: openai/gpt-4o-realtime-audio + litellm_params: + model: openai/gpt-4o-realtime-audio + api_key: os.environ/OPENAI_API_KEY + model_info: + mode: realtime +``` + ## Background Health Checks You can enable model health checks being run in the background, to prevent each model from being queried too frequently via `/health`. diff --git a/litellm/llms/azure/azure.py b/litellm/llms/azure/azure.py index f7110210d3..72dcd59abf 100644 --- a/litellm/llms/azure/azure.py +++ b/litellm/llms/azure/azure.py @@ -1585,6 +1585,18 @@ class AzureChatCompletion(BaseLLM): ) elif mode == "batch": completion = await client.batches.with_raw_response.list(limit=1) # type: ignore + elif mode == "realtime": + from litellm.realtime_api.main import _realtime_health_check + + # create a websocket connection + await _realtime_health_check( + model=model or "", + api_key=api_key, + api_base=api_base, + api_version=api_version, + custom_llm_provider="azure", + ) + return {} else: raise Exception("mode not set") response = {} diff --git a/litellm/llms/openai/openai.py b/litellm/llms/openai/openai.py index 31973796ef..f0045d9aa4 100644 --- a/litellm/llms/openai/openai.py +++ b/litellm/llms/openai/openai.py @@ -1374,6 +1374,17 @@ class OpenAIChatCompletion(BaseLLM): input=prompt, # type: ignore voice="alloy", ) + elif mode == "realtime": + from litellm.realtime_api.main import _realtime_health_check + + # create a websocket connection + await _realtime_health_check( + model=model or "", + api_key=api_key, + api_base=api_base or "https://api.openai.com/", + custom_llm_provider="openai", + 
) + return {} else: raise ValueError("mode not set, passed in mode: " + mode) response = {} diff --git a/litellm/main.py b/litellm/main.py index 39a9873cf7..6b0f7c026f 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -5153,7 +5153,13 @@ async def ahealth_check( # noqa: PLR0915 model_params: dict, mode: Optional[ Literal[ - "completion", "embedding", "image_generation", "chat", "batch", "rerank" + "completion", + "embedding", + "image_generation", + "chat", + "batch", + "rerank", + "realtime", ] ] = None, prompt: Optional[str] = None, diff --git a/litellm/realtime_api/main.py b/litellm/realtime_api/main.py index 268351874d..7afd526c45 100644 --- a/litellm/realtime_api/main.py +++ b/litellm/realtime_api/main.py @@ -114,3 +114,45 @@ async def _arealtime( ) else: raise ValueError(f"Unsupported model: {model}") + + +async def _realtime_health_check( + model: str, + api_base: str, + custom_llm_provider: str, + api_key: Optional[str], + api_version: Optional[str] = None, +): + """ + Health check for realtime API - tries connection to the realtime API websocket + + Args: + model: str - model name + api_base: str - api base + api_version: Optional[str] - api version + api_key: Optional[str] - api key + custom_llm_provider: str - custom llm provider + + Returns: + bool - True if the websocket connection succeeds; an exception is raised otherwise + Raises: + Exception - if the connection is not successful + """ + import websockets + + url: Optional[str] = None + if custom_llm_provider == "azure": + url = azure_realtime._construct_url( + api_base=api_base, + model=model, + api_version=api_version or "2024-10-01-preview", + ) + elif custom_llm_provider == "openai": + url = openai_realtime._construct_url(api_base=api_base, model=model) + async with websockets.connect( # type: ignore + url, + extra_headers={ + "api-key": api_key, # type: ignore + }, + ): + return True diff --git a/tests/local_testing/test_health_check.py b/tests/local_testing/test_health_check.py index b0e8d1c3b0..3535a4fe94 100644 ---
a/tests/local_testing/test_health_check.py +++ b/tests/local_testing/test_health_check.py @@ -152,3 +152,27 @@ async def test_cohere_rerank_health_check(): assert "error" not in response print(response) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model", ["azure/gpt-4o-realtime-preview", "openai/gpt-4o-realtime-preview"] +) +async def test_realtime_health_check(model): + """ + Test that the realtime health check passes for valid realtime models. + + """ + model_params = { + "model": model, + } + if model == "azure/gpt-4o-realtime-preview": + model_params["api_base"] = os.getenv("AZURE_REALTIME_API_BASE") + model_params["api_key"] = os.getenv("AZURE_REALTIME_API_KEY") + model_params["api_version"] = os.getenv("AZURE_REALTIME_API_VERSION") + response = await litellm.ahealth_check( + model_params=model_params, + mode="realtime", + ) + print(response) + assert response == {}