diff --git a/litellm/main.py b/litellm/main.py
index 4e53f99258..64049c31d1 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -5487,7 +5487,9 @@ async def ahealth_check(
         model_params["litellm_logging_obj"] = litellm_logging_obj
 
         mode_handlers = {
-            "chat": lambda: litellm.acompletion(**model_params),
+            "chat": lambda: litellm.acompletion(
+                **model_params,
+            ),
             "completion": lambda: litellm.atext_completion(
                 **_filter_model_params(model_params),
                 prompt=prompt or "test",
@@ -5544,12 +5546,7 @@ async def ahealth_check(
                 "error": f"error:{str(e)}. Missing `mode`. Set the `mode` for the model - https://docs.litellm.ai/docs/proxy/health#embedding-models \nstacktrace: {stack_trace}"
             }
 
-        error_to_return = (
-            str(e)
-            + "\nHave you set 'mode' - https://docs.litellm.ai/docs/proxy/health#embedding-models"
-            + "\nstack trace: "
-            + stack_trace
-        )
+        error_to_return = str(e) + "\nstack trace: " + stack_trace
 
         raw_request_typed_dict = litellm_logging_obj.model_call_details.get(
             "raw_request_typed_dict"
diff --git a/litellm/proxy/health_endpoints/_health_endpoints.py b/litellm/proxy/health_endpoints/_health_endpoints.py
index d269bcf84a..378c94aacb 100644
--- a/litellm/proxy/health_endpoints/_health_endpoints.py
+++ b/litellm/proxy/health_endpoints/_health_endpoints.py
@@ -3,13 +3,14 @@ import copy
 import os
 import traceback
 from datetime import datetime, timedelta
-from typing import Literal, Optional, Union
+from typing import Any, Dict, Literal, Optional, Union
 
 import fastapi
 from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
 
 import litellm
 from litellm._logging import verbose_proxy_logger
+from litellm.constants import HEALTH_CHECK_TIMEOUT_SECONDS
 from litellm.proxy._types import (
     AlertType,
     CallInfo,
@@ -21,6 +22,7 @@ from litellm.proxy._types import (
 from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
 from litellm.proxy.health_check import (
     _clean_endpoint_data,
+    _update_litellm_params_for_health_check,
     perform_health_check,
     run_with_timeout,
 )
@@ -626,9 +628,9 @@ async def test_model_connection(
             "realtime",
         ]
     ] = fastapi.Body("chat", description="The mode to test the model with"),
-    prompt: Optional[str] = fastapi.Body(None, description="Test prompt for the model"),
-    timeout: Optional[int] = fastapi.Body(
-        30, description="Timeout in seconds for the health check"
+    litellm_params: Dict = fastapi.Body(
+        None,
+        description="Parameters for litellm.completion, litellm.embedding for the health check",
     ),
     user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
 ):
@@ -644,10 +646,14 @@ async def test_model_connection(
     -H 'Authorization: Bearer sk-1234' \\
     -H 'Content-Type: application/json' \\
     -d '{
-        "model": "openai/gpt-3.5-turbo",
-        "mode": "chat",
-        "prompt": "Hello, world!",
-        "timeout": 30
+        "litellm_params": {
+            "model": "gpt-4",
+            "custom_llm_provider": "azure_ai",
+            "litellm_credential_name": null,
+            "api_key": "6xxxxxxx",
+            "api_base": "https://litellm8397336933.openai.azure.com/openai/deployments/gpt-4o/chat/completions?api-version=2024-10-21"
+        },
+        "mode": "chat"
     }'
     ```
 
@@ -655,23 +661,26 @@ async def test_model_connection(
         dict: A dictionary containing the health check result with either success information or error details.
""" try: - # Create basic params for the model - model_params = await request.json() - model_params.pop("mode") - - # Run the health check with timeout + # Include health_check_params if provided + litellm_params = _update_litellm_params_for_health_check( + model_info={}, + litellm_params=litellm_params, + ) + mode = mode or litellm_params.pop("mode", None) result = await run_with_timeout( litellm.ahealth_check( - model_params, + model_params=litellm_params, mode=mode, - prompt=prompt, - input=[prompt] if prompt else ["test from litellm"], + prompt="test from litellm", + input=["test from litellm"], ), - timeout, + HEALTH_CHECK_TIMEOUT_SECONDS, ) # Clean the result for display - cleaned_result = _clean_endpoint_data({**model_params, **result}, details=True) + cleaned_result = _clean_endpoint_data( + {**litellm_params, **result}, details=True + ) return { "status": "error" if "error" in result else "success",