Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 10:44:24 +00:00)
(Refactor) - Reuse litellm.completion / litellm.embedding etc. for health checks (#7455)
* add mode: realtime
* add _realtime_health_check
* test_realtime_health_check
* azure _realtime_health_check
* _realtime_health_check
* Realtime Models
* fix code quality
* delete OAI / Azure custom health check code
* simplest version of ahealth check
* update tests
* working health check post refactor
* working aspeech health check
* fix realtime health checks
* test_audio_transcription_health_check
* use get_audio_file_for_health_check
* test_text_completion_health_check
* ahealth_check
* simplify health check code
* update ahealth_check
* fix import
* fix unused imports
* fix ahealth_check
* fix local testing
* test_async_realtime_health_check
parent 4e65722a00
commit 1e06ee3162
9 changed files with 188 additions and 373 deletions
litellm/main.py: 226 lines changed
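Before reading the diff, here is a minimal usage sketch of what the refactored health check looks like from the caller's side. The model name, message payload, and environment setup below are illustrative assumptions and not part of this commit:

# Hedged usage sketch (not from the diff): calling the refactored ahealth_check.
# Assumes litellm is installed, OPENAI_API_KEY is set, and "gpt-4o-mini" is a
# model your account can reach.
import asyncio

from litellm.main import ahealth_check


async def main() -> None:
    result = await ahealth_check(
        model_params={
            "model": "gpt-4o-mini",
            "messages": [{"role": "user", "content": "ping"}],
        },
        mode="chat",     # any mode from the new Literal, e.g. "embedding" or "realtime"
        prompt="test",   # used by completion / audio modes
        input=["test"],  # used by embedding mode
    )
    # On success this dict is built from response headers, e.g. rate-limit info.
    print(result)


if __name__ == "__main__":
    asyncio.run(main())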
@@ -51,6 +51,11 @@ from litellm import (  # type: ignore
     get_optional_params,
 )
 from litellm.integrations.custom_logger import CustomLogger
+from litellm.litellm_core_utils.audio_utils.utils import get_audio_file_for_health_check
+from litellm.litellm_core_utils.health_check_utils import (
+    _create_health_check_response,
+    _filter_model_params,
+)
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
 from litellm.litellm_core_utils.mock_functions import (
     mock_embedding,
@@ -60,6 +65,7 @@ from litellm.litellm_core_utils.prompt_templates.common_utils import (
     get_content_from_model_response,
 )
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.realtime_api.main import _realtime_health_check
 from litellm.secret_managers.main import get_secret_str
 from litellm.utils import (
     CustomStreamWrapper,
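The two helpers imported from health_check_utils above are used later in this file but are not defined in this diff. A rough, assumed sketch of the contract ahealth_check appears to rely on (inferred from the call sites below, not the actual library code):

# Assumed behavior only: the real implementations live in
# litellm/litellm_core_utils/health_check_utils.py and are not shown here.
from typing import Dict


def _filter_model_params(model_params: dict) -> dict:
    """Drop keys (such as "messages") that would clash with the extra kwargs
    each mode handler passes explicitly."""
    return {k: v for k, v in model_params.items() if k != "messages"}


def _create_health_check_response(response_headers: Dict[str, str]) -> dict:
    """Surface the rate-limit / region headers the ahealth_check docstring promises."""
    response: dict = {}
    for header in (
        "x-ratelimit-remaining-requests",
        "x-ratelimit-remaining-tokens",
        "x-ms-region",
    ):
        if header in response_headers:
            response[header] = response_headers[header]
    return response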
@@ -5117,65 +5123,60 @@ def speech(
 ##### Health Endpoints #######################


-async def ahealth_check_chat_models(
+async def ahealth_check_wildcard_models(
     model: str, custom_llm_provider: str, model_params: dict
 ) -> dict:
-    if "*" in model:
-        from litellm.litellm_core_utils.llm_request_utils import (
-            pick_cheapest_chat_model_from_llm_provider,
-        )
-
-        # this is a wildcard model, we need to pick a random model from the provider
-        cheapest_model = pick_cheapest_chat_model_from_llm_provider(
-            custom_llm_provider=custom_llm_provider
-        )
-        fallback_models: Optional[List] = None
-        if custom_llm_provider in litellm.models_by_provider:
-            models = litellm.models_by_provider[custom_llm_provider]
-            random.shuffle(models)  # Shuffle the models list in place
-            fallback_models = models[
-                :2
-            ]  # Pick the first 2 models from the shuffled list
-        model_params["model"] = cheapest_model
-        model_params["fallbacks"] = fallback_models
-        model_params["max_tokens"] = 1
-        await acompletion(**model_params)
-        response: dict = {}  # args like remaining ratelimit etc.
-    else:  # default to completion calls
-        model_params["max_tokens"] = 1
-        await acompletion(**model_params)
-        response = {}  # args like remaining ratelimit etc.
+    from litellm.litellm_core_utils.llm_request_utils import (
+        pick_cheapest_chat_model_from_llm_provider,
+    )
+
+    # this is a wildcard model, we need to pick a random model from the provider
+    cheapest_model = pick_cheapest_chat_model_from_llm_provider(
+        custom_llm_provider=custom_llm_provider
+    )
+    fallback_models: Optional[List] = None
+    if custom_llm_provider in litellm.models_by_provider:
+        models = litellm.models_by_provider[custom_llm_provider]
+        random.shuffle(models)  # Shuffle the models list in place
+        fallback_models = models[:2]  # Pick the first 2 models from the shuffled list
+    model_params["model"] = cheapest_model
+    model_params["fallbacks"] = fallback_models
+    model_params["max_tokens"] = 1
+    await acompletion(**model_params)
+    response: dict = {}  # args like remaining ratelimit etc.
     return response


-async def ahealth_check(  # noqa: PLR0915
+async def ahealth_check(
     model_params: dict,
     mode: Optional[
         Literal[
+            "chat",
             "completion",
             "embedding",
             "audio_speech",
             "audio_transcription",
             "image_generation",
-            "chat",
             "batch",
             "rerank",
+            "realtime",
         ]
-    ] = None,
+    ] = "chat",
     prompt: Optional[str] = None,
     input: Optional[List] = None,
     default_timeout: float = 6000,
 ):
     """
     Support health checks for different providers. Return remaining rate limit, etc.

     For azure/openai -> completion.with_raw_response
     For rest -> litellm.acompletion()
     Returns:
         {
             "x-ratelimit-remaining-requests": int,
             "x-ratelimit-remaining-tokens": int,
             "x-ms-region": str,
         }
     """
+    passed_in_mode: Optional[str] = None
     try:
         model: Optional[str] = model_params.get("model", None)

         if model is None:
             raise Exception("model not set")

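For the wildcard path (`provider/*` model names), the renamed ahealth_check_wildcard_models checks the provider's cheapest chat model and attaches two shuffled fallbacks. A self-contained sketch of that selection logic, with a hypothetical price table standing in for litellm.model_cost and pick_cheapest_chat_model_from_llm_provider:

# Illustrative sketch only: made-up provider/price tables, not litellm data.
import random
from typing import Dict, List, Tuple


def pick_cheapest_with_fallbacks(
    models_by_provider: Dict[str, List[str]],
    cost_per_token: Dict[str, float],
    provider: str,
) -> Tuple[str, List[str]]:
    """Return (cheapest model, two random fallbacks) for a provider."""
    models = list(models_by_provider.get(provider, []))
    if not models:
        raise ValueError(f"no models registered for provider {provider}")
    cheapest = min(models, key=lambda m: cost_per_token.get(m, float("inf")))
    random.shuffle(models)  # mirror the shuffle in ahealth_check_wildcard_models
    fallbacks = models[:2]  # pick the first 2 models from the shuffled list
    return cheapest, fallbacks


if __name__ == "__main__":
    providers = {"openai": ["gpt-4o", "gpt-4o-mini", "gpt-3.5-turbo"]}
    costs = {"gpt-4o": 5.0, "gpt-4o-mini": 0.15, "gpt-3.5-turbo": 0.5}
    print(pick_cheapest_with_fallbacks(providers, costs, "openai"))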
@@ -5183,122 +5184,73 @@ async def ahealth_check(  # noqa: PLR0915
             mode = litellm.model_cost[model].get("mode")

         model, custom_llm_provider, _, _ = get_llm_provider(model=model)

         if model in litellm.model_cost and mode is None:
             mode = litellm.model_cost[model].get("mode")

-        mode = mode
+        passed_in_mode = mode
         if mode is None:
             mode = "chat"  # default to chat completion calls

-        if custom_llm_provider == "azure":
-            api_key = (
-                model_params.get("api_key")
-                or get_secret_str("AZURE_API_KEY")
-                or get_secret_str("AZURE_OPENAI_API_KEY")
-            )
-
-            api_base: Optional[str] = (
-                model_params.get("api_base")
-                or get_secret_str("AZURE_API_BASE")
-                or get_secret_str("AZURE_OPENAI_API_BASE")
-            )
-
-            if api_base is None:
-                raise ValueError(
-                    "Azure API Base cannot be None. Set via 'AZURE_API_BASE' in env var or `.completion(..., api_base=..)`"
-                )
-
-            api_version = (
-                model_params.get("api_version")
-                or get_secret_str("AZURE_API_VERSION")
-                or get_secret_str("AZURE_OPENAI_API_VERSION")
-            )
-
-            timeout = (
-                model_params.get("timeout")
-                or litellm.request_timeout
-                or default_timeout
-            )
-
-            response = await azure_chat_completions.ahealth_check(
-                model=model,
-                messages=model_params.get(
-                    "messages", None
-                ),  # Replace with your actual messages list
-                api_key=api_key,
-                api_base=api_base,
-                api_version=api_version,
-                timeout=timeout,
-                mode=mode,
-                prompt=prompt,
-                input=input,
-            )
-        elif (
-            custom_llm_provider == "openai"
-            or custom_llm_provider == "text-completion-openai"
-        ):
-            api_key = model_params.get("api_key") or get_secret_str("OPENAI_API_KEY")
-            organization = model_params.get("organization")
-
-            timeout = (
-                model_params.get("timeout")
-                or litellm.request_timeout
-                or default_timeout
-            )
-
-            api_base = model_params.get("api_base") or get_secret_str("OPENAI_API_BASE")
-
-            if custom_llm_provider == "text-completion-openai":
-                mode = "completion"
-
-            response = await openai_chat_completions.ahealth_check(
-                model=model,
-                messages=model_params.get(
-                    "messages", None
-                ),  # Replace with your actual messages list
-                api_key=api_key,
-                api_base=api_base,
-                timeout=timeout,
-                mode=mode,
-                prompt=prompt,
-                input=input,
-                organization=organization,
-            )
-        else:
-            model_params["cache"] = {
-                "no-cache": True
-            }  # don't used cached responses for making health check calls
-            if mode == "embedding":
-                model_params.pop("messages", None)
-                model_params["input"] = input
-                await litellm.aembedding(**model_params)
-                response = {}
-            elif mode == "image_generation":
-                model_params.pop("messages", None)
-                model_params["prompt"] = prompt
-                await litellm.aimage_generation(**model_params)
-                response = {}
-            elif mode == "rerank":
-                model_params.pop("messages", None)
-                model_params["query"] = prompt
-                model_params["documents"] = ["my sample text"]
-                await litellm.arerank(**model_params)
-                response = {}
-            else:
-                response = await ahealth_check_chat_models(
-                    model=model,
-                    custom_llm_provider=custom_llm_provider,
-                    model_params=model_params,
-                )
-        return response
+        model_params["cache"] = {
+            "no-cache": True
+        }  # don't used cached responses for making health check calls
+        if "*" in model:
+            return await ahealth_check_wildcard_models(
+                model=model,
+                custom_llm_provider=custom_llm_provider,
+                model_params=model_params,
+            )
+        # Map modes to their corresponding health check calls
+        mode_handlers = {
+            "chat": lambda: litellm.acompletion(**model_params),
+            "completion": lambda: litellm.atext_completion(
+                **_filter_model_params(model_params),
+                prompt=prompt or "test",
+            ),
+            "embedding": lambda: litellm.aembedding(
+                **_filter_model_params(model_params),
+                input=input or ["test"],
+            ),
+            "audio_speech": lambda: litellm.aspeech(
+                **_filter_model_params(model_params),
+                input=prompt or "test",
+                voice="alloy",
+            ),
+            "audio_transcription": lambda: litellm.atranscription(
+                **_filter_model_params(model_params),
+                file=get_audio_file_for_health_check(),
+            ),
+            "image_generation": lambda: litellm.aimage_generation(
+                **_filter_model_params(model_params),
+                prompt=prompt,
+            ),
+            "rerank": lambda: litellm.arerank(
+                **_filter_model_params(model_params),
+                query=prompt or "",
+                documents=["my sample text"],
+            ),
+            "realtime": lambda: _realtime_health_check(
+                model=model,
+                custom_llm_provider=custom_llm_provider,
+                api_base=model_params.get("api_base", None),
+                api_key=model_params.get("api_key", None),
+                api_version=model_params.get("api_version", None),
+            ),
+        }
+
+        if mode in mode_handlers:
+            _response = await mode_handlers[mode]()
+            # Only process headers for chat mode
+            _response_headers: dict = (
+                getattr(_response, "_hidden_params", {}).get("headers", {}) or {}
+            )
+            return _create_health_check_response(_response_headers)
+        else:
+            raise Exception(
+                f"Mode {mode} not supported. See modes here: https://docs.litellm.ai/docs/proxy/health"
+            )
     except Exception as e:
         stack_trace = traceback.format_exc()
         if isinstance(stack_trace, str):
             stack_trace = stack_trace[:1000]

-        if mode is None:
+        if passed_in_mode is None:
             return {
                 "error": f"error:{str(e)}. Missing `mode`. Set the `mode` for the model - https://docs.litellm.ai/docs/proxy/health#embedding-models \nstacktrace: {stack_trace}"
             }
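The mode_handlers table above is the core of the refactor: every health check mode is routed through the corresponding public litellm.* call instead of provider-specific code, and each entry is a zero-argument lambda so only the requested coroutine is ever built and awaited. A stripped-down sketch of that dispatch pattern, with stub handlers in place of the real litellm calls:

# Minimal illustration of the mode -> handler dispatch pattern; the handlers
# here are stubs, not litellm APIs.
import asyncio
from typing import Any, Callable, Coroutine, Dict


async def fake_chat_check() -> dict:
    return {"x-ratelimit-remaining-requests": "99"}


async def fake_embedding_check() -> dict:
    return {}


async def run_health_check(mode: str) -> dict:
    handlers: Dict[str, Callable[[], Coroutine[Any, Any, dict]]] = {
        "chat": lambda: fake_chat_check(),
        "embedding": lambda: fake_embedding_check(),
    }
    if mode not in handlers:
        raise Exception(f"Mode {mode} not supported.")
    return await handlers[mode]()  # the lambda builds the coroutine lazily


if __name__ == "__main__":
    print(asyncio.run(run_health_check("chat")))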