diff --git a/.circleci/config.yml b/.circleci/config.yml
index 5d132dd51d..032f697c78 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -69,6 +69,7 @@ jobs:
             pip install "Pillow==10.3.0"
             pip install "jsonschema==4.22.0"
             pip install "pytest-xdist==3.6.1"
+            pip install "websockets==10.4"
       - save_cache:
           paths:
             - ./venv
diff --git a/litellm/litellm_core_utils/audio_utils/audio_health_check.wav b/litellm/litellm_core_utils/audio_utils/audio_health_check.wav
new file mode 100644
index 0000000000..f766d23e94
Binary files /dev/null and b/litellm/litellm_core_utils/audio_utils/audio_health_check.wav differ
diff --git a/litellm/litellm_core_utils/audio_utils/utils.py b/litellm/litellm_core_utils/audio_utils/utils.py
index ab19dac9cc..8018fe1153 100644
--- a/litellm/litellm_core_utils/audio_utils/utils.py
+++ b/litellm/litellm_core_utils/audio_utils/utils.py
@@ -2,6 +2,8 @@
 Utils used for litellm.transcription() and litellm.atranscription()
 """
 
+import os
+
 from litellm.types.utils import FileTypes
 
 
@@ -21,3 +23,14 @@ def get_audio_file_name(file_obj: FileTypes) -> str:
         return str(file_obj)
     else:
         return repr(file_obj)
+
+
+def get_audio_file_for_health_check() -> FileTypes:
+    """
+    Get an audio file for health check
+
+    Returns the content of `audio_health_check.wav` in the same directory as this file
+    """
+    pwd = os.path.dirname(os.path.realpath(__file__))
+    file_path = os.path.join(pwd, "audio_health_check.wav")
+    return open(file_path, "rb")
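Editor's note: `get_audio_file_for_health_check()` returns an open binary handle to the bundled `audio_health_check.wav`, so health checks no longer depend on the test fixture path (`tests/gettysburg.wav`). A minimal usage sketch follows; the call to `litellm.atranscription` assumes a configured `OPENAI_API_KEY`, and the model name is illustrative:

```python
import asyncio

import litellm
from litellm.litellm_core_utils.audio_utils.utils import get_audio_file_for_health_check


async def main():
    # The helper opens the packaged WAV file in binary mode.
    audio_file = get_audio_file_for_health_check()
    try:
        transcript = await litellm.atranscription(model="whisper-1", file=audio_file)
        print(transcript.text)
    finally:
        audio_file.close()  # the helper leaves closing the handle to the caller


asyncio.run(main())
```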
diff --git a/litellm/litellm_core_utils/health_check_utils.py b/litellm/litellm_core_utils/health_check_utils.py
new file mode 100644
index 0000000000..ff252855f0
--- /dev/null
+++ b/litellm/litellm_core_utils/health_check_utils.py
@@ -0,0 +1,28 @@
+"""
+Utils used for litellm.ahealth_check()
+"""
+
+
+def _filter_model_params(model_params: dict) -> dict:
+    """Remove 'messages' param from model params."""
+    return {k: v for k, v in model_params.items() if k != "messages"}
+
+
+def _create_health_check_response(response_headers: dict) -> dict:
+    response = {}
+
+    if (
+        response_headers.get("x-ratelimit-remaining-requests", None) is not None
+    ):  # not provided for dall-e requests
+        response["x-ratelimit-remaining-requests"] = response_headers[
+            "x-ratelimit-remaining-requests"
+        ]
+
+    if response_headers.get("x-ratelimit-remaining-tokens", None) is not None:
+        response["x-ratelimit-remaining-tokens"] = response_headers[
+            "x-ratelimit-remaining-tokens"
+        ]
+
+    if response_headers.get("x-ms-region", None) is not None:
+        response["x-ms-region"] = response_headers["x-ms-region"]
+    return response
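Editor's note: a quick offline illustration of the two helpers above (plain Python, no network calls; the header values are made up for the example):

```python
from litellm.litellm_core_utils.health_check_utils import (
    _create_health_check_response,
    _filter_model_params,
)

# 'messages' is stripped so chat-only params don't leak into e.g. embedding or rerank calls
params = {
    "model": "gpt-4o-mini",
    "messages": [{"role": "user", "content": "hi"}],
    "api_key": "sk-...",
}
assert "messages" not in _filter_model_params(params)

# Only the three whitelisted headers are copied into the health check response
headers = {
    "x-ratelimit-remaining-requests": "99",
    "x-ratelimit-remaining-tokens": "149000",
    "x-request-id": "abc",  # ignored
}
print(_create_health_check_response(headers))
# {'x-ratelimit-remaining-requests': '99', 'x-ratelimit-remaining-tokens': '149000'}
```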
diff --git a/litellm/llms/azure/azure.py b/litellm/llms/azure/azure.py
index 72dcd59abf..f771532133 100644
--- a/litellm/llms/azure/azure.py
+++ b/litellm/llms/azure/azure.py
@@ -1491,132 +1491,3 @@ class AzureChatCompletion(BaseLLM):
             response["x-ms-region"] = completion.headers["x-ms-region"]
 
         return response
-
-    async def ahealth_check(
-        self,
-        model: Optional[str],
-        api_key: Optional[str],
-        api_base: str,
-        api_version: Optional[str],
-        timeout: float,
-        mode: str,
-        messages: Optional[list] = None,
-        input: Optional[list] = None,
-        prompt: Optional[str] = None,
-    ) -> dict:
-        client_session = (
-            litellm.aclient_session
-            or get_async_httpx_client(llm_provider=LlmProviders.AZURE).client
-        )  # handle dall-e-2 calls
-
-        if "gateway.ai.cloudflare.com" in api_base:
-            ## build base url - assume api base includes resource name
-            if not api_base.endswith("/"):
-                api_base += "/"
-            api_base += f"{model}"
-            client = AsyncAzureOpenAI(
-                base_url=api_base,
-                api_version=api_version,
-                api_key=api_key,
-                timeout=timeout,
-                http_client=client_session,
-            )
-            model = None
-            # cloudflare ai gateway, needs model=None
-        else:
-            client = AsyncAzureOpenAI(
-                api_version=api_version,
-                azure_endpoint=api_base,
-                api_key=api_key,
-                timeout=timeout,
-                http_client=client_session,
-            )
-
-            # only run this check if it's not cloudflare ai gateway
-            if model is None and mode != "image_generation":
-                raise Exception("model is not set")
-
-        completion = None
-
-        if mode == "completion":
-            completion = await client.completions.with_raw_response.create(
-                model=model,  # type: ignore
-                prompt=prompt,  # type: ignore
-            )
-        elif mode == "chat":
-            if messages is None:
-                raise Exception("messages is not set")
-            completion = await client.chat.completions.with_raw_response.create(
-                model=model,  # type: ignore
-                messages=messages,  # type: ignore
-            )
-        elif mode == "embedding":
-            if input is None:
-                raise Exception("input is not set")
-            completion = await client.embeddings.with_raw_response.create(
-                model=model,  # type: ignore
-                input=input,  # type: ignore
-            )
-        elif mode == "image_generation":
-            if prompt is None:
-                raise Exception("prompt is not set")
-            completion = await client.images.with_raw_response.generate(
-                model=model,  # type: ignore
-                prompt=prompt,  # type: ignore
-            )
-        elif mode == "audio_transcription":
-            # Get the current directory of the file being run
-            pwd = os.path.dirname(os.path.realpath(__file__))
-            file_path = os.path.join(
-                pwd, "../../../tests/gettysburg.wav"
-            )  # proxy address
-            audio_file = open(file_path, "rb")
-            completion = await client.audio.transcriptions.with_raw_response.create(
-                file=audio_file,
-                model=model,  # type: ignore
-                prompt=prompt,  # type: ignore
-            )
-        elif mode == "audio_speech":
-            # Get the current directory of the file being run
-            completion = await client.audio.speech.with_raw_response.create(
-                model=model,  # type: ignore
-                input=prompt,  # type: ignore
-                voice="alloy",
-            )
-        elif mode == "batch":
-            completion = await client.batches.with_raw_response.list(limit=1)  # type: ignore
-        elif mode == "realtime":
-            from litellm.realtime_api.main import _realtime_health_check
-
-            # create a websocket connection
-            await _realtime_health_check(
-                model=model or "",
-                api_key=api_key,
-                api_base=api_base,
-                api_version=api_version,
-                custom_llm_provider="azure",
-            )
-            return {}
-        else:
-            raise Exception("mode not set")
-        response = {}
-
-        if completion is None or not hasattr(completion, "headers"):
-            raise Exception("invalid completion response")
-
-        if (
-            completion.headers.get("x-ratelimit-remaining-requests", None) is not None
-        ):  # not provided for dall-e requests
-            response["x-ratelimit-remaining-requests"] = completion.headers[
-                "x-ratelimit-remaining-requests"
-            ]
-
-        if completion.headers.get("x-ratelimit-remaining-tokens", None) is not None:
-            response["x-ratelimit-remaining-tokens"] = completion.headers[
-                "x-ratelimit-remaining-tokens"
-            ]
-
-        if completion.headers.get("x-ms-region", None) is not None:
-            response["x-ms-region"] = completion.headers["x-ms-region"]
-
-        return response
diff --git a/litellm/llms/openai/openai.py b/litellm/llms/openai/openai.py
index f0045d9aa4..0ee8e3dadd 100644
--- a/litellm/llms/openai/openai.py
+++ b/litellm/llms/openai/openai.py
@@ -1,5 +1,4 @@
 import hashlib
-import os
 import types
 from typing import (
     Any,
@@ -1306,105 +1305,6 @@ class OpenAIChatCompletion(BaseLLM):
 
         return HttpxBinaryResponseContent(response=response.response)
 
-    async def ahealth_check(
-        self,
-        model: Optional[str],
-        api_key: Optional[str],
-        timeout: float,
-        mode: str,
-        messages: Optional[list] = None,
-        input: Optional[list] = None,
-        prompt: Optional[str] = None,
-        organization: Optional[str] = None,
-        api_base: Optional[str] = None,
-    ):
-        client = AsyncOpenAI(
-            api_key=api_key,
-            timeout=timeout,
-            organization=organization,
-            base_url=api_base,
-        )
-        if model is None and mode != "image_generation":
-            raise Exception("model is not set")
-
-        completion = None
-
-        if mode == "completion":
-            completion = await client.completions.with_raw_response.create(
-                model=model,  # type: ignore
-                prompt=prompt,  # type: ignore
-            )
-        elif mode == "chat":
-            if messages is None:
-                raise Exception("messages is not set")
-            completion = await client.chat.completions.with_raw_response.create(
-                model=model,  # type: ignore
-                messages=messages,  # type: ignore
-            )
-        elif mode == "embedding":
-            if input is None:
-                raise Exception("input is not set")
-            completion = await client.embeddings.with_raw_response.create(
-                model=model,  # type: ignore
-                input=input,  # type: ignore
-            )
-        elif mode == "image_generation":
-            if prompt is None:
-                raise Exception("prompt is not set")
-            completion = await client.images.with_raw_response.generate(
-                model=model,  # type: ignore
-                prompt=prompt,  # type: ignore
-            )
-        elif mode == "audio_transcription":
-            # Get the current directory of the file being run
-            pwd = os.path.dirname(os.path.realpath(__file__))
-            file_path = os.path.join(
-                pwd, "../../../tests/gettysburg.wav"
-            )  # proxy address
-            audio_file = open(file_path, "rb")
-            completion = await client.audio.transcriptions.with_raw_response.create(
-                file=audio_file,
-                model=model,  # type: ignore
-                prompt=prompt,  # type: ignore
-            )
-        elif mode == "audio_speech":
-            # Get the current directory of the file being run
-            completion = await client.audio.speech.with_raw_response.create(
-                model=model,  # type: ignore
-                input=prompt,  # type: ignore
-                voice="alloy",
-            )
-        elif mode == "realtime":
-            from litellm.realtime_api.main import _realtime_health_check
-
-            # create a websocket connection
-            await _realtime_health_check(
-                model=model or "",
-                api_key=api_key,
-                api_base=api_base or "https://api.openai.com/",
-                custom_llm_provider="openai",
-            )
-            return {}
-        else:
-            raise ValueError("mode not set, passed in mode: " + mode)
-        response = {}
-
-        if completion is None or not hasattr(completion, "headers"):
-            raise Exception("invalid completion response")
-
-        if (
-            completion.headers.get("x-ratelimit-remaining-requests", None) is not None
-        ):  # not provided for dall-e requests
-            response["x-ratelimit-remaining-requests"] = completion.headers[
-                "x-ratelimit-remaining-requests"
-            ]
-
-        if completion.headers.get("x-ratelimit-remaining-tokens", None) is not None:
-            response["x-ratelimit-remaining-tokens"] = completion.headers[
-                "x-ratelimit-remaining-tokens"
-            ]
-        return response
-
-
 class OpenAIFilesAPI(BaseLLM):
     """
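Editor's note: with both provider-specific `ahealth_check` methods removed, Azure and OpenAI health checks go through the shared `litellm.ahealth_check()` path rewritten in `litellm/main.py` below. A hedged sketch of an Azure check after this refactor (deployment name and env-var wiring are illustrative, not part of the PR):

```python
import asyncio
import os

import litellm


async def check_azure():
    # model_params is passed straight through to litellm.acompletion for mode="chat"
    response = await litellm.ahealth_check(
        model_params={
            "model": "azure/my-gpt-4o-deployment",  # hypothetical deployment name
            "api_base": os.getenv("AZURE_API_BASE"),
            "api_key": os.getenv("AZURE_API_KEY"),
            "api_version": os.getenv("AZURE_API_VERSION"),
            "messages": [{"role": "user", "content": "ping"}],
        },
        mode="chat",
    )
    # On success: rate-limit headers (and x-ms-region for Azure); on failure: {"error": ...}
    print(response)


asyncio.run(check_azure())
```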
diff --git a/litellm/main.py b/litellm/main.py
index 6b0f7c026f..52a886b69c 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -51,6 +51,11 @@ from litellm import (  # type: ignore
     get_optional_params,
 )
 from litellm.integrations.custom_logger import CustomLogger
+from litellm.litellm_core_utils.audio_utils.utils import get_audio_file_for_health_check
+from litellm.litellm_core_utils.health_check_utils import (
+    _create_health_check_response,
+    _filter_model_params,
+)
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
 from litellm.litellm_core_utils.mock_functions import (
     mock_embedding,
@@ -60,6 +65,7 @@ from litellm.litellm_core_utils.prompt_templates.common_utils import (
     get_content_from_model_response,
 )
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.realtime_api.main import _realtime_health_check
 from litellm.secret_managers.main import get_secret_str
 from litellm.utils import (
     CustomStreamWrapper,
@@ -5117,65 +5123,60 @@ def speech(
 ##### Health Endpoints #######################
 
 
-async def ahealth_check_chat_models(
+async def ahealth_check_wildcard_models(
     model: str, custom_llm_provider: str, model_params: dict
 ) -> dict:
-    if "*" in model:
-        from litellm.litellm_core_utils.llm_request_utils import (
-            pick_cheapest_chat_model_from_llm_provider,
-        )
-
-        # this is a wildcard model, we need to pick a random model from the provider
-        cheapest_model = pick_cheapest_chat_model_from_llm_provider(
-            custom_llm_provider=custom_llm_provider
-        )
-        fallback_models: Optional[List] = None
-        if custom_llm_provider in litellm.models_by_provider:
-            models = litellm.models_by_provider[custom_llm_provider]
-            random.shuffle(models)  # Shuffle the models list in place
-            fallback_models = models[
-                :2
-            ]  # Pick the first 2 models from the shuffled list
-        model_params["model"] = cheapest_model
-        model_params["fallbacks"] = fallback_models
-        model_params["max_tokens"] = 1
-        await acompletion(**model_params)
-        response: dict = {}  # args like remaining ratelimit etc.
-    else:  # default to completion calls
-        model_params["max_tokens"] = 1
-        await acompletion(**model_params)
-        response = {}  # args like remaining ratelimit etc.
+    from litellm.litellm_core_utils.llm_request_utils import (
+        pick_cheapest_chat_model_from_llm_provider,
+    )
+    # this is a wildcard model, we need to pick a random model from the provider
+    cheapest_model = pick_cheapest_chat_model_from_llm_provider(
+        custom_llm_provider=custom_llm_provider
+    )
+    fallback_models: Optional[List] = None
+    if custom_llm_provider in litellm.models_by_provider:
+        models = litellm.models_by_provider[custom_llm_provider]
+        random.shuffle(models)  # Shuffle the models list in place
+        fallback_models = models[:2]  # Pick the first 2 models from the shuffled list
+    model_params["model"] = cheapest_model
+    model_params["fallbacks"] = fallback_models
+    model_params["max_tokens"] = 1
+    await acompletion(**model_params)
+    response: dict = {}  # args like remaining ratelimit etc.
     return response
 
 
-async def ahealth_check(  # noqa: PLR0915
+async def ahealth_check(
     model_params: dict,
     mode: Optional[
         Literal[
+            "chat",
             "completion",
             "embedding",
+            "audio_speech",
+            "audio_transcription",
             "image_generation",
-            "chat",
             "batch",
             "rerank",
             "realtime",
         ]
-    ] = None,
+    ] = "chat",
     prompt: Optional[str] = None,
    input: Optional[List] = None,
-    default_timeout: float = 6000,
 ):
     """
     Support health checks for different providers. Return remaining rate limit, etc.
 
-    For azure/openai -> completion.with_raw_response
-    For rest -> litellm.acompletion()
+    Returns:
+        {
+            "x-ratelimit-remaining-requests": int,
+            "x-ratelimit-remaining-tokens": int,
+            "x-ms-region": str,
+        }
     """
-    passed_in_mode: Optional[str] = None
     try:
         model: Optional[str] = model_params.get("model", None)
-
         if model is None:
             raise Exception("model not set")
@@ -5183,122 +5184,73 @@ async def ahealth_check(  # noqa: PLR0915
             mode = litellm.model_cost[model].get("mode")
 
         model, custom_llm_provider, _, _ = get_llm_provider(model=model)
-
         if model in litellm.model_cost and mode is None:
             mode = litellm.model_cost[model].get("mode")
 
-        mode = mode
-        passed_in_mode = mode
-        if mode is None:
-            mode = "chat"  # default to chat completion calls
-
-        if custom_llm_provider == "azure":
-            api_key = (
-                model_params.get("api_key")
-                or get_secret_str("AZURE_API_KEY")
-                or get_secret_str("AZURE_OPENAI_API_KEY")
-            )
-
-            api_base: Optional[str] = (
-                model_params.get("api_base")
-                or get_secret_str("AZURE_API_BASE")
-                or get_secret_str("AZURE_OPENAI_API_BASE")
-            )
-
-            if api_base is None:
-                raise ValueError(
-                    "Azure API Base cannot be None. Set via 'AZURE_API_BASE' in env var or `.completion(..., api_base=..)`"
-                )
-
-            api_version = (
-                model_params.get("api_version")
-                or get_secret_str("AZURE_API_VERSION")
-                or get_secret_str("AZURE_OPENAI_API_VERSION")
-            )
-
-            timeout = (
-                model_params.get("timeout")
-                or litellm.request_timeout
-                or default_timeout
-            )
-
-            response = await azure_chat_completions.ahealth_check(
+        model_params["cache"] = {
+            "no-cache": True
+        }  # don't used cached responses for making health check calls
+        if "*" in model:
+            return await ahealth_check_wildcard_models(
                 model=model,
-                messages=model_params.get(
-                    "messages", None
-                ),  # Replace with your actual messages list
-                api_key=api_key,
-                api_base=api_base,
-                api_version=api_version,
-                timeout=timeout,
-                mode=mode,
+                custom_llm_provider=custom_llm_provider,
+                model_params=model_params,
+            )
+        # Map modes to their corresponding health check calls
+        mode_handlers = {
+            "chat": lambda: litellm.acompletion(**model_params),
+            "completion": lambda: litellm.atext_completion(
+                **_filter_model_params(model_params),
+                prompt=prompt or "test",
+            ),
+            "embedding": lambda: litellm.aembedding(
+                **_filter_model_params(model_params),
+                input=input or ["test"],
+            ),
+            "audio_speech": lambda: litellm.aspeech(
+                **_filter_model_params(model_params),
+                input=prompt or "test",
+                voice="alloy",
+            ),
+            "audio_transcription": lambda: litellm.atranscription(
+                **_filter_model_params(model_params),
+                file=get_audio_file_for_health_check(),
+            ),
+            "image_generation": lambda: litellm.aimage_generation(
+                **_filter_model_params(model_params),
                 prompt=prompt,
-                input=input,
-            )
-        elif (
-            custom_llm_provider == "openai"
-            or custom_llm_provider == "text-completion-openai"
-        ):
-            api_key = model_params.get("api_key") or get_secret_str("OPENAI_API_KEY")
-            organization = model_params.get("organization")
-
-            timeout = (
-                model_params.get("timeout")
-                or litellm.request_timeout
-                or default_timeout
-            )
-
-            api_base = model_params.get("api_base") or get_secret_str("OPENAI_API_BASE")
-
-            if custom_llm_provider == "text-completion-openai":
-                mode = "completion"
-
-            response = await openai_chat_completions.ahealth_check(
+            ),
+            "rerank": lambda: litellm.arerank(
+                **_filter_model_params(model_params),
+                query=prompt or "",
+                documents=["my sample text"],
+            ),
+            "realtime": lambda: _realtime_health_check(
                 model=model,
-                messages=model_params.get(
-                    "messages", None
-                ),  # Replace with your actual messages list
-                api_key=api_key,
-                api_base=api_base,
-                timeout=timeout,
-                mode=mode,
-                prompt=prompt,
-                input=input,
-                organization=organization,
+                custom_llm_provider=custom_llm_provider,
+                api_base=model_params.get("api_base", None),
+                api_key=model_params.get("api_key", None),
+                api_version=model_params.get("api_version", None),
+            ),
+        }
+
+        if mode in mode_handlers:
+            _response = await mode_handlers[mode]()
+            # Only process headers for chat mode
+            _response_headers: dict = (
+                getattr(_response, "_hidden_params", {}).get("headers", {}) or {}
             )
+            return _create_health_check_response(_response_headers)
         else:
-            model_params["cache"] = {
-                "no-cache": True
-            }  # don't used cached responses for making health check calls
-            if mode == "embedding":
-                model_params.pop("messages", None)
-                model_params["input"] = input
-                await litellm.aembedding(**model_params)
-                response = {}
-            elif mode == "image_generation":
-                model_params.pop("messages", None)
-                model_params["prompt"] = prompt
-                await litellm.aimage_generation(**model_params)
-                response = {}
-            elif mode == "rerank":
-                model_params.pop("messages", None)
-                model_params["query"] = prompt
-                model_params["documents"] = ["my sample text"]
-                await litellm.arerank(**model_params)
-                response = {}
-            else:
-                response = await ahealth_check_chat_models(
-                    model=model,
-                    custom_llm_provider=custom_llm_provider,
-                    model_params=model_params,
-                )
-        return response
+            raise Exception(
+                f"Mode {mode} not supported. See modes here: https://docs.litellm.ai/docs/proxy/health"
+            )
     except Exception as e:
         stack_trace = traceback.format_exc()
         if isinstance(stack_trace, str):
            stack_trace = stack_trace[:1000]
-        if passed_in_mode is None:
+        if mode is None:
            return {
                 "error": f"error:{str(e)}. Missing `mode`. Set the `mode` for the model - https://docs.litellm.ai/docs/proxy/health#embedding-models \nstacktrace: {stack_trace}"
             }
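Editor's note: the `mode_handlers` table above replaces the per-provider branching; each supported mode maps to the corresponding public litellm async API, and whatever headers the SDK exposes via `_hidden_params` are reduced to the rate-limit fields by `_create_health_check_response`. A sketch of exercising a few modes (model names and keys are placeholders, assuming the relevant provider credentials are set in the environment):

```python
import asyncio

import litellm


async def run_checks():
    checks = {
        "chat": {
            "model_params": {
                "model": "gpt-4o-mini",
                "messages": [{"role": "user", "content": "ping"}],
            },
            "mode": "chat",
        },
        "embedding": {
            "model_params": {"model": "text-embedding-3-small"},
            "mode": "embedding",
            "input": ["ping"],
        },
        "image_generation": {
            "model_params": {"model": "dall-e-3"},
            "mode": "image_generation",
            "prompt": "a red square",
        },
    }
    for name, kwargs in checks.items():
        result = await litellm.ahealth_check(**kwargs)
        # rate-limit headers on success, {"error": ...} on failure
        print(name, "->", result)


asyncio.run(run_checks())
```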
diff --git a/litellm/realtime_api/main.py b/litellm/realtime_api/main.py
index 7afd526c45..ac39a68c60 100644
--- a/litellm/realtime_api/main.py
+++ b/litellm/realtime_api/main.py
@@ -118,9 +118,9 @@ async def _arealtime(
 
 async def _realtime_health_check(
     model: str,
-    api_base: str,
     custom_llm_provider: str,
     api_key: Optional[str],
+    api_base: Optional[str] = None,
     api_version: Optional[str] = None,
 ):
     """
@@ -143,12 +143,14 @@ async def _realtime_health_check(
     url: Optional[str] = None
     if custom_llm_provider == "azure":
         url = azure_realtime._construct_url(
-            api_base=api_base,
+            api_base=api_base or "",
             model=model,
             api_version=api_version or "2024-10-01-preview",
         )
     elif custom_llm_provider == "openai":
-        url = openai_realtime._construct_url(api_base=api_base, model=model)
+        url = openai_realtime._construct_url(
+            api_base=api_base or "https://api.openai.com/", model=model
+        )
     async with websockets.connect(  # type: ignore
         url,
         extra_headers={
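Editor's note: `_realtime_health_check` now treats `api_base` as optional, falling back to `https://api.openai.com/` for OpenAI (Azure callers are expected to supply their own base). A direct-call sketch, assuming a valid `OPENAI_API_KEY` and that the model is enabled for the realtime API:

```python
import asyncio
import os

from litellm.realtime_api.main import _realtime_health_check


async def main():
    # Opens a websocket to the provider's realtime endpoint and closes it immediately.
    await _realtime_health_check(
        model="gpt-4o-realtime-preview",
        custom_llm_provider="openai",
        api_key=os.getenv("OPENAI_API_KEY"),
        # api_base omitted -> falls back to https://api.openai.com/
    )


asyncio.run(main())
```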
diff --git a/tests/local_testing/test_health_check.py b/tests/local_testing/test_health_check.py
index 3535a4fe94..0d43c4cc05 100644
--- a/tests/local_testing/test_health_check.py
+++ b/tests/local_testing/test_health_check.py
@@ -6,6 +6,7 @@ import sys
 import traceback
 
 import pytest
+from unittest.mock import AsyncMock, patch
 
 sys.path.insert(
     0, os.path.abspath("../..")
@@ -35,6 +36,19 @@ async def test_azure_health_check():
 # asyncio.run(test_azure_health_check())
 
 
+@pytest.mark.asyncio
+async def test_text_completion_health_check():
+    response = await litellm.ahealth_check(
+        model_params={"model": "gpt-3.5-turbo-instruct"},
+        mode="completion",
+        prompt="What's the weather in SF?",
+    )
+    print(f"response: {response}")
+
+    assert "x-ratelimit-remaining-tokens" in response
+    return response
+
+
 @pytest.mark.asyncio
 async def test_azure_embedding_health_check():
     response = await litellm.ahealth_check(
@@ -128,7 +142,6 @@ async def test_groq_health_check():
         mode=None,
         prompt="What's 1 + 1?",
         input=["test from litellm"],
-        default_timeout=6000,
     )
     print(f"response: {response}")
     assert response == {}
@@ -141,8 +154,6 @@ async def test_cohere_rerank_health_check():
     response = await litellm.ahealth_check(
         model_params={
             "model": "cohere/rerank-english-v3.0",
-            "query": "Hey, how's it going",
-            "documents": ["my sample text"],
             "api_key": os.getenv("COHERE_API_KEY"),
         },
         mode="rerank",
@@ -154,15 +165,52 @@ async def test_cohere_rerank_health_check():
     print(response)
 
 
+@pytest.mark.asyncio
+async def test_audio_speech_health_check():
+    response = await litellm.ahealth_check(
+        model_params={
+            "model": "openai/tts-1",
+            "api_key": os.getenv("OPENAI_API_KEY"),
+        },
+        mode="audio_speech",
+        prompt="Hey",
+    )
+
+    assert "error" not in response
+
+    print(response)
+
+
+@pytest.mark.asyncio
+async def test_audio_transcription_health_check():
+    litellm.set_verbose = True
+    response = await litellm.ahealth_check(
+        model_params={
+            "model": "openai/whisper-1",
+            "api_key": os.getenv("OPENAI_API_KEY"),
+        },
+        mode="audio_transcription",
+    )
+
+    assert "error" not in response
+
+    print(response)
+
+
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
     "model", ["azure/gpt-4o-realtime-preview", "openai/gpt-4o-realtime-preview"]
 )
-async def test_realtime_health_check(model):
+async def test_async_realtime_health_check(model, mocker):
     """
     Test Health Check with Valid models passes
 
     """
+    mock_websocket = AsyncMock()
+    mock_connect = AsyncMock().__aenter__.return_value = mock_websocket
+    mocker.patch("websockets.connect", return_value=mock_connect)
+
+    litellm.set_verbose = True
     model_params = {
         "model": model,
     }
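Editor's note: the renamed realtime test is truncated above by the section boundary. For reference, the websocket mocking pattern it sets up can be exercised on its own roughly as follows; this is a standalone sketch using `pytest-mock` and `pytest-asyncio`, not the PR's actual assertions, and it assumes the mocked handshake means no real key is needed:

```python
import pytest
from unittest.mock import AsyncMock

import litellm


@pytest.mark.asyncio
async def test_realtime_health_check_mocked(mocker):
    # Stub out the websocket handshake so no real realtime connection is made.
    mock_connect = AsyncMock()
    mock_connect.__aenter__.return_value = AsyncMock()
    mocker.patch("websockets.connect", return_value=mock_connect)

    response = await litellm.ahealth_check(
        model_params={"model": "openai/gpt-4o-realtime-preview", "api_key": "sk-test"},
        mode="realtime",
    )
    # The realtime handler exposes no headers, so an empty dict signals success here.
    assert response == {}
```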