(Refactor) - Reuse litellm.completion / litellm.embedding etc. for health checks (#7455)

* add mode: realtime

* add _realtime_health_check

* test_realtime_health_check

* azure _realtime_health_check

* _realtime_health_check

* Realtime Models

* fix code quality

* delete OAI / Azure custom health check code

* simplest version of ahealth check

* update tests

* working health check post refactor

* working aspeech health check

* fix realtime health checks

* test_audio_transcription_health_check

* use get_audio_file_for_health_check

* test_text_completion_health_check

* ahealth_check

* simplify health check code

* update ahealth_check

* fix import

* fix unused imports

* fix ahealth_check

* fix local testing

* test_async_realtime_health_check
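
Taken together, these commits route every health-check mode through the corresponding public litellm async API (acompletion, aembedding, atranscription, aspeech, arerank, and the new _realtime_health_check) instead of the deleted Azure/OpenAI-specific health-check code. Below is a minimal usage sketch of the refactored entry point, assuming `ahealth_check` is re-exported at the package top level as in the diff further down; the model names are placeholders.

import asyncio

import litellm


async def run_health_checks():
    # "chat" mode is routed through litellm.acompletion under the hood.
    chat_health = await litellm.ahealth_check(
        model_params={
            "model": "gpt-4o-mini",  # placeholder model name
            "messages": [{"role": "user", "content": "ping"}],
        },
        mode="chat",
    )

    # "embedding" mode is routed through litellm.aembedding; chat-only params
    # are stripped via _filter_model_params before the call.
    embedding_health = await litellm.ahealth_check(
        model_params={"model": "text-embedding-3-small"},  # placeholder model name
        mode="embedding",
        input=["health check"],
    )
    print(chat_health, embedding_health)


asyncio.run(run_health_checks())

Both calls return the dict built by _create_health_check_response, i.e. the remaining rate-limit headers when the provider supplies them.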
Author: Ishaan Jaff · 2024-12-28 18:38:54 -08:00 · committed by GitHub
parent 4e65722a00
commit 1e06ee3162
9 changed files with 188 additions and 373 deletions

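The hunks below add the new health-check imports and replace the Azure/OpenAI-specific ahealth_check branches with a single mode-to-handler table. As a plain-Python illustration of that dispatch pattern (hypothetical handler names, not the literal litellm implementation):

import asyncio
from typing import Any, Awaitable, Callable, Dict


async def fake_chat_check() -> Any:
    # Hypothetical stand-in for litellm.acompletion(**model_params).
    return {"headers": {"x-ratelimit-remaining-requests": "99"}}


async def fake_embedding_check() -> Any:
    # Hypothetical stand-in for litellm.aembedding(...).
    return {"headers": {}}


async def health_check(mode: str) -> dict:
    # One entry per supported mode, mirroring the mode_handlers dict in the diff.
    handlers: Dict[str, Callable[[], Awaitable[Any]]] = {
        "chat": fake_chat_check,
        "embedding": fake_embedding_check,
    }
    if mode not in handlers:
        raise Exception(f"Mode {mode} not supported.")
    response = await handlers[mode]()
    # Mirror _create_health_check_response: surface only the response headers.
    return response.get("headers", {})


print(asyncio.run(health_check("chat")))

Adding a new mode (e.g. realtime) then means adding one table entry rather than another provider-specific branch.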

@@ -51,6 +51,11 @@ from litellm import ( # type: ignore
     get_optional_params,
 )
 from litellm.integrations.custom_logger import CustomLogger
+from litellm.litellm_core_utils.audio_utils.utils import get_audio_file_for_health_check
+from litellm.litellm_core_utils.health_check_utils import (
+    _create_health_check_response,
+    _filter_model_params,
+)
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
 from litellm.litellm_core_utils.mock_functions import (
     mock_embedding,
@@ -60,6 +65,7 @@ from litellm.litellm_core_utils.prompt_templates.common_utils import (
     get_content_from_model_response,
 )
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.realtime_api.main import _realtime_health_check
 from litellm.secret_managers.main import get_secret_str
 from litellm.utils import (
     CustomStreamWrapper,
@@ -5117,65 +5123,60 @@ def speech(
##### Health Endpoints #######################
async def ahealth_check_chat_models(
async def ahealth_check_wildcard_models(
model: str, custom_llm_provider: str, model_params: dict
) -> dict:
if "*" in model:
from litellm.litellm_core_utils.llm_request_utils import (
pick_cheapest_chat_model_from_llm_provider,
)
# this is a wildcard model, we need to pick a random model from the provider
cheapest_model = pick_cheapest_chat_model_from_llm_provider(
custom_llm_provider=custom_llm_provider
)
fallback_models: Optional[List] = None
if custom_llm_provider in litellm.models_by_provider:
models = litellm.models_by_provider[custom_llm_provider]
random.shuffle(models) # Shuffle the models list in place
fallback_models = models[
:2
] # Pick the first 2 models from the shuffled list
model_params["model"] = cheapest_model
model_params["fallbacks"] = fallback_models
model_params["max_tokens"] = 1
await acompletion(**model_params)
response: dict = {} # args like remaining ratelimit etc.
else: # default to completion calls
model_params["max_tokens"] = 1
await acompletion(**model_params)
response = {} # args like remaining ratelimit etc.
from litellm.litellm_core_utils.llm_request_utils import (
pick_cheapest_chat_model_from_llm_provider,
)
# this is a wildcard model, we need to pick a random model from the provider
cheapest_model = pick_cheapest_chat_model_from_llm_provider(
custom_llm_provider=custom_llm_provider
)
fallback_models: Optional[List] = None
if custom_llm_provider in litellm.models_by_provider:
models = litellm.models_by_provider[custom_llm_provider]
random.shuffle(models) # Shuffle the models list in place
fallback_models = models[:2] # Pick the first 2 models from the shuffled list
model_params["model"] = cheapest_model
model_params["fallbacks"] = fallback_models
model_params["max_tokens"] = 1
await acompletion(**model_params)
response: dict = {} # args like remaining ratelimit etc.
return response
async def ahealth_check( # noqa: PLR0915
async def ahealth_check(
model_params: dict,
mode: Optional[
Literal[
"chat",
"completion",
"embedding",
"audio_speech",
"audio_transcription",
"image_generation",
"chat",
"batch",
"rerank",
"realtime",
]
] = None,
] = "chat",
prompt: Optional[str] = None,
input: Optional[List] = None,
default_timeout: float = 6000,
):
"""
Support health checks for different providers. Return remaining rate limit, etc.
For azure/openai -> completion.with_raw_response
For rest -> litellm.acompletion()
Returns:
{
"x-ratelimit-remaining-requests": int,
"x-ratelimit-remaining-tokens": int,
"x-ms-region": str,
}
"""
passed_in_mode: Optional[str] = None
try:
model: Optional[str] = model_params.get("model", None)
if model is None:
raise Exception("model not set")
@@ -5183,122 +5184,73 @@ async def ahealth_check( # noqa: PLR0915
mode = litellm.model_cost[model].get("mode")
model, custom_llm_provider, _, _ = get_llm_provider(model=model)
if model in litellm.model_cost and mode is None:
mode = litellm.model_cost[model].get("mode")
mode = mode
passed_in_mode = mode
if mode is None:
mode = "chat" # default to chat completion calls
if custom_llm_provider == "azure":
api_key = (
model_params.get("api_key")
or get_secret_str("AZURE_API_KEY")
or get_secret_str("AZURE_OPENAI_API_KEY")
)
api_base: Optional[str] = (
model_params.get("api_base")
or get_secret_str("AZURE_API_BASE")
or get_secret_str("AZURE_OPENAI_API_BASE")
)
if api_base is None:
raise ValueError(
"Azure API Base cannot be None. Set via 'AZURE_API_BASE' in env var or `.completion(..., api_base=..)`"
)
api_version = (
model_params.get("api_version")
or get_secret_str("AZURE_API_VERSION")
or get_secret_str("AZURE_OPENAI_API_VERSION")
)
timeout = (
model_params.get("timeout")
or litellm.request_timeout
or default_timeout
)
response = await azure_chat_completions.ahealth_check(
model_params["cache"] = {
"no-cache": True
} # don't used cached responses for making health check calls
if "*" in model:
return await ahealth_check_wildcard_models(
model=model,
messages=model_params.get(
"messages", None
), # Replace with your actual messages list
api_key=api_key,
api_base=api_base,
api_version=api_version,
timeout=timeout,
mode=mode,
custom_llm_provider=custom_llm_provider,
model_params=model_params,
)
# Map modes to their corresponding health check calls
mode_handlers = {
"chat": lambda: litellm.acompletion(**model_params),
"completion": lambda: litellm.atext_completion(
**_filter_model_params(model_params),
prompt=prompt or "test",
),
"embedding": lambda: litellm.aembedding(
**_filter_model_params(model_params),
input=input or ["test"],
),
"audio_speech": lambda: litellm.aspeech(
**_filter_model_params(model_params),
input=prompt or "test",
voice="alloy",
),
"audio_transcription": lambda: litellm.atranscription(
**_filter_model_params(model_params),
file=get_audio_file_for_health_check(),
),
"image_generation": lambda: litellm.aimage_generation(
**_filter_model_params(model_params),
prompt=prompt,
input=input,
)
elif (
custom_llm_provider == "openai"
or custom_llm_provider == "text-completion-openai"
):
api_key = model_params.get("api_key") or get_secret_str("OPENAI_API_KEY")
organization = model_params.get("organization")
timeout = (
model_params.get("timeout")
or litellm.request_timeout
or default_timeout
)
api_base = model_params.get("api_base") or get_secret_str("OPENAI_API_BASE")
if custom_llm_provider == "text-completion-openai":
mode = "completion"
response = await openai_chat_completions.ahealth_check(
),
"rerank": lambda: litellm.arerank(
**_filter_model_params(model_params),
query=prompt or "",
documents=["my sample text"],
),
"realtime": lambda: _realtime_health_check(
model=model,
messages=model_params.get(
"messages", None
), # Replace with your actual messages list
api_key=api_key,
api_base=api_base,
timeout=timeout,
mode=mode,
prompt=prompt,
input=input,
organization=organization,
custom_llm_provider=custom_llm_provider,
api_base=model_params.get("api_base", None),
api_key=model_params.get("api_key", None),
api_version=model_params.get("api_version", None),
),
}
if mode in mode_handlers:
_response = await mode_handlers[mode]()
# Only process headers for chat mode
_response_headers: dict = (
getattr(_response, "_hidden_params", {}).get("headers", {}) or {}
)
return _create_health_check_response(_response_headers)
else:
model_params["cache"] = {
"no-cache": True
} # don't used cached responses for making health check calls
if mode == "embedding":
model_params.pop("messages", None)
model_params["input"] = input
await litellm.aembedding(**model_params)
response = {}
elif mode == "image_generation":
model_params.pop("messages", None)
model_params["prompt"] = prompt
await litellm.aimage_generation(**model_params)
response = {}
elif mode == "rerank":
model_params.pop("messages", None)
model_params["query"] = prompt
model_params["documents"] = ["my sample text"]
await litellm.arerank(**model_params)
response = {}
else:
response = await ahealth_check_chat_models(
model=model,
custom_llm_provider=custom_llm_provider,
model_params=model_params,
)
return response
raise Exception(
f"Mode {mode} not supported. See modes here: https://docs.litellm.ai/docs/proxy/health"
)
except Exception as e:
stack_trace = traceback.format_exc()
if isinstance(stack_trace, str):
stack_trace = stack_trace[:1000]
if passed_in_mode is None:
if mode is None:
return {
"error": f"error:{str(e)}. Missing `mode`. Set the `mode` for the model - https://docs.litellm.ai/docs/proxy/health#embedding-models \nstacktrace: {stack_trace}"
}
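
The realtime mode is new in this PR: its handler awaits _realtime_health_check from litellm.realtime_api.main with only the connection parameters. A hedged sketch of calling it directly, using only the keyword arguments visible in the diff; the model name and key are placeholders and the return value is treated as opaque:

import asyncio

from litellm.realtime_api.main import _realtime_health_check


async def check_realtime() -> None:
    # Keyword arguments mirror the "realtime" entry in mode_handlers above;
    # credentials and model name are placeholders.
    await _realtime_health_check(
        model="gpt-4o-realtime-preview",  # placeholder model name
        custom_llm_provider="openai",
        api_base=None,
        api_key="sk-...",  # placeholder API key
        api_version=None,
    )


asyncio.run(check_realtime())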