Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 10:44:24 +00:00)
(Refactor) - Re use litellm.completion/litellm.embedding etc for health checks (#7455)
* add mode: realtime
* add _realtime_health_check
* test_realtime_health_check
* azure _realtime_health_check
* _realtime_health_check
* Realtime Models
* fix code quality
* delete OAI / Azure custom health check code
* simplest version of ahealth check
* update tests
* working health check post refactor
* working aspeech health check
* fix realtime health checks
* test_audio_transcription_health_check
* use get_audio_file_for_health_check
* test_text_completion_health_check
* ahealth_check
* simplify health check code
* update ahealth_check
* fix import
* fix unused imports
* fix ahealth_check
* fix local testing
* test_async_realtime_health_check
This commit is contained in: parent 4e65722a00, commit 1e06ee3162

9 changed files with 188 additions and 373 deletions
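The net effect of the refactor is that every health-check mode now goes through the corresponding public litellm call (litellm.acompletion, litellm.aembedding, litellm.atranscription, and so on) instead of provider-specific ahealth_check methods. A minimal sketch of invoking the refactored entry point, modeled on the tests added in this commit (the model name and env var are illustrative, not prescribed by the change):

import asyncio
import os

import litellm


async def main():
    # `mode` selects which litellm call is reused for the check,
    # e.g. "chat" -> litellm.acompletion, "audio_transcription" -> litellm.atranscription
    response = await litellm.ahealth_check(
        model_params={
            "model": "openai/whisper-1",             # illustrative model
            "api_key": os.getenv("OPENAI_API_KEY"),  # assumes this env var is set
        },
        mode="audio_transcription",
    )
    # on success, any rate-limit headers returned by the provider are surfaced here
    print(response)


asyncio.run(main())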
@@ -69,6 +69,7 @@ jobs:
 pip install "Pillow==10.3.0"
 pip install "jsonschema==4.22.0"
 pip install "pytest-xdist==3.6.1"
+pip install "websockets==10.4"
 - save_cache:
     paths:
       - ./venv
BIN  litellm/litellm_core_utils/audio_utils/audio_health_check.wav  (new file)
Binary file not shown.
@@ -2,6 +2,8 @@
 Utils used for litellm.transcription() and litellm.atranscription()
 """
 
+import os
+
 from litellm.types.utils import FileTypes
 
 
@@ -21,3 +23,14 @@ def get_audio_file_name(file_obj: FileTypes) -> str:
         return str(file_obj)
     else:
         return repr(file_obj)
+
+
+def get_audio_file_for_health_check() -> FileTypes:
+    """
+    Get an audio file for health check
+
+    Returns the content of `audio_health_check.wav` in the same directory as this file
+    """
+    pwd = os.path.dirname(os.path.realpath(__file__))
+    file_path = os.path.join(pwd, "audio_health_check.wav")
+    return open(file_path, "rb")
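This helper is what the refactored transcription health check feeds into litellm.atranscription() (see the litellm/main.py hunks below). A hedged usage sketch, assuming OPENAI_API_KEY is configured and using an illustrative model name:

import asyncio
import os

import litellm
from litellm.litellm_core_utils.audio_utils.utils import get_audio_file_for_health_check


async def main():
    # the helper returns an open handle to the bundled audio_health_check.wav
    response = await litellm.atranscription(
        model="whisper-1",                      # illustrative model name
        api_key=os.getenv("OPENAI_API_KEY"),
        file=get_audio_file_for_health_check(),
    )
    print(response)


asyncio.run(main())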
28  litellm/litellm_core_utils/health_check_utils.py  (new file)

@@ -0,0 +1,28 @@
+"""
+Utils used for litellm.ahealth_check()
+"""
+
+
+def _filter_model_params(model_params: dict) -> dict:
+    """Remove 'messages' param from model params."""
+    return {k: v for k, v in model_params.items() if k != "messages"}
+
+
+def _create_health_check_response(response_headers: dict) -> dict:
+    response = {}
+
+    if (
+        response_headers.get("x-ratelimit-remaining-requests", None) is not None
+    ):  # not provided for dall-e requests
+        response["x-ratelimit-remaining-requests"] = response_headers[
+            "x-ratelimit-remaining-requests"
+        ]
+
+    if response_headers.get("x-ratelimit-remaining-tokens", None) is not None:
+        response["x-ratelimit-remaining-tokens"] = response_headers[
+            "x-ratelimit-remaining-tokens"
+        ]
+
+    if response_headers.get("x-ms-region", None) is not None:
+        response["x-ms-region"] = response_headers["x-ms-region"]
+    return response
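Both helpers are pure functions, so their behavior is easy to show in isolation. A small sketch with made-up parameter values:

from litellm.litellm_core_utils.health_check_utils import (
    _create_health_check_response,
    _filter_model_params,
)

# drop the chat-only "messages" key before calling a non-chat endpoint
params = {"model": "gpt-4o-mini", "messages": [{"role": "user", "content": "ping"}]}
print(_filter_model_params(params))  # {'model': 'gpt-4o-mini'}

# keep only the rate-limit / region headers that the health endpoint reports
headers = {
    "x-ratelimit-remaining-requests": "9999",
    "x-ratelimit-remaining-tokens": "149000",
    "content-type": "application/json",
}
print(_create_health_check_response(headers))
# {'x-ratelimit-remaining-requests': '9999', 'x-ratelimit-remaining-tokens': '149000'}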
@@ -1491,132 +1491,3 @@ class AzureChatCompletion(BaseLLM):
             response["x-ms-region"] = completion.headers["x-ms-region"]
 
         return response
-
-    async def ahealth_check(
-        self,
-        model: Optional[str],
-        api_key: Optional[str],
-        api_base: str,
-        api_version: Optional[str],
-        timeout: float,
-        mode: str,
-        messages: Optional[list] = None,
-        input: Optional[list] = None,
-        prompt: Optional[str] = None,
-    ) -> dict:
-        client_session = (
-            litellm.aclient_session
-            or get_async_httpx_client(llm_provider=LlmProviders.AZURE).client
-        )  # handle dall-e-2 calls
-
-        if "gateway.ai.cloudflare.com" in api_base:
-            ## build base url - assume api base includes resource name
-            if not api_base.endswith("/"):
-                api_base += "/"
-            api_base += f"{model}"
-            client = AsyncAzureOpenAI(
-                base_url=api_base,
-                api_version=api_version,
-                api_key=api_key,
-                timeout=timeout,
-                http_client=client_session,
-            )
-            model = None
-            # cloudflare ai gateway, needs model=None
-        else:
-            client = AsyncAzureOpenAI(
-                api_version=api_version,
-                azure_endpoint=api_base,
-                api_key=api_key,
-                timeout=timeout,
-                http_client=client_session,
-            )
-
-        # only run this check if it's not cloudflare ai gateway
-        if model is None and mode != "image_generation":
-            raise Exception("model is not set")
-
-        completion = None
-
-        if mode == "completion":
-            completion = await client.completions.with_raw_response.create(
-                model=model,  # type: ignore
-                prompt=prompt,  # type: ignore
-            )
-        elif mode == "chat":
-            if messages is None:
-                raise Exception("messages is not set")
-            completion = await client.chat.completions.with_raw_response.create(
-                model=model,  # type: ignore
-                messages=messages,  # type: ignore
-            )
-        elif mode == "embedding":
-            if input is None:
-                raise Exception("input is not set")
-            completion = await client.embeddings.with_raw_response.create(
-                model=model,  # type: ignore
-                input=input,  # type: ignore
-            )
-        elif mode == "image_generation":
-            if prompt is None:
-                raise Exception("prompt is not set")
-            completion = await client.images.with_raw_response.generate(
-                model=model,  # type: ignore
-                prompt=prompt,  # type: ignore
-            )
-        elif mode == "audio_transcription":
-            # Get the current directory of the file being run
-            pwd = os.path.dirname(os.path.realpath(__file__))
-            file_path = os.path.join(
-                pwd, "../../../tests/gettysburg.wav"
-            )  # proxy address
-            audio_file = open(file_path, "rb")
-            completion = await client.audio.transcriptions.with_raw_response.create(
-                file=audio_file,
-                model=model,  # type: ignore
-                prompt=prompt,  # type: ignore
-            )
-        elif mode == "audio_speech":
-            # Get the current directory of the file being run
-            completion = await client.audio.speech.with_raw_response.create(
-                model=model,  # type: ignore
-                input=prompt,  # type: ignore
-                voice="alloy",
-            )
-        elif mode == "batch":
-            completion = await client.batches.with_raw_response.list(limit=1)  # type: ignore
-        elif mode == "realtime":
-            from litellm.realtime_api.main import _realtime_health_check
-
-            # create a websocket connection
-            await _realtime_health_check(
-                model=model or "",
-                api_key=api_key,
-                api_base=api_base,
-                api_version=api_version,
-                custom_llm_provider="azure",
-            )
-            return {}
-        else:
-            raise Exception("mode not set")
-        response = {}
-
-        if completion is None or not hasattr(completion, "headers"):
-            raise Exception("invalid completion response")
-
-        if (
-            completion.headers.get("x-ratelimit-remaining-requests", None) is not None
-        ):  # not provided for dall-e requests
-            response["x-ratelimit-remaining-requests"] = completion.headers[
-                "x-ratelimit-remaining-requests"
-            ]
-
-        if completion.headers.get("x-ratelimit-remaining-tokens", None) is not None:
-            response["x-ratelimit-remaining-tokens"] = completion.headers[
-                "x-ratelimit-remaining-tokens"
-            ]
-
-        if completion.headers.get("x-ms-region", None) is not None:
-            response["x-ms-region"] = completion.headers["x-ms-region"]
-
-        return response
@@ -1,5 +1,4 @@
 import hashlib
-import os
 import types
 from typing import (
     Any,
@@ -1306,105 +1305,6 @@ class OpenAIChatCompletion(BaseLLM):
 
         return HttpxBinaryResponseContent(response=response.response)
 
-    async def ahealth_check(
-        self,
-        model: Optional[str],
-        api_key: Optional[str],
-        timeout: float,
-        mode: str,
-        messages: Optional[list] = None,
-        input: Optional[list] = None,
-        prompt: Optional[str] = None,
-        organization: Optional[str] = None,
-        api_base: Optional[str] = None,
-    ):
-        client = AsyncOpenAI(
-            api_key=api_key,
-            timeout=timeout,
-            organization=organization,
-            base_url=api_base,
-        )
-        if model is None and mode != "image_generation":
-            raise Exception("model is not set")
-
-        completion = None
-
-        if mode == "completion":
-            completion = await client.completions.with_raw_response.create(
-                model=model,  # type: ignore
-                prompt=prompt,  # type: ignore
-            )
-        elif mode == "chat":
-            if messages is None:
-                raise Exception("messages is not set")
-            completion = await client.chat.completions.with_raw_response.create(
-                model=model,  # type: ignore
-                messages=messages,  # type: ignore
-            )
-        elif mode == "embedding":
-            if input is None:
-                raise Exception("input is not set")
-            completion = await client.embeddings.with_raw_response.create(
-                model=model,  # type: ignore
-                input=input,  # type: ignore
-            )
-        elif mode == "image_generation":
-            if prompt is None:
-                raise Exception("prompt is not set")
-            completion = await client.images.with_raw_response.generate(
-                model=model,  # type: ignore
-                prompt=prompt,  # type: ignore
-            )
-        elif mode == "audio_transcription":
-            # Get the current directory of the file being run
-            pwd = os.path.dirname(os.path.realpath(__file__))
-            file_path = os.path.join(
-                pwd, "../../../tests/gettysburg.wav"
-            )  # proxy address
-            audio_file = open(file_path, "rb")
-            completion = await client.audio.transcriptions.with_raw_response.create(
-                file=audio_file,
-                model=model,  # type: ignore
-                prompt=prompt,  # type: ignore
-            )
-        elif mode == "audio_speech":
-            # Get the current directory of the file being run
-            completion = await client.audio.speech.with_raw_response.create(
-                model=model,  # type: ignore
-                input=prompt,  # type: ignore
-                voice="alloy",
-            )
-        elif mode == "realtime":
-            from litellm.realtime_api.main import _realtime_health_check
-
-            # create a websocket connection
-            await _realtime_health_check(
-                model=model or "",
-                api_key=api_key,
-                api_base=api_base or "https://api.openai.com/",
-                custom_llm_provider="openai",
-            )
-            return {}
-        else:
-            raise ValueError("mode not set, passed in mode: " + mode)
-        response = {}
-
-        if completion is None or not hasattr(completion, "headers"):
-            raise Exception("invalid completion response")
-
-        if (
-            completion.headers.get("x-ratelimit-remaining-requests", None) is not None
-        ):  # not provided for dall-e requests
-            response["x-ratelimit-remaining-requests"] = completion.headers[
-                "x-ratelimit-remaining-requests"
-            ]
-
-        if completion.headers.get("x-ratelimit-remaining-tokens", None) is not None:
-            response["x-ratelimit-remaining-tokens"] = completion.headers[
-                "x-ratelimit-remaining-tokens"
-            ]
-        return response
-
 
 class OpenAIFilesAPI(BaseLLM):
     """
226  litellm/main.py

@@ -51,6 +51,11 @@ from litellm import (  # type: ignore
     get_optional_params,
 )
 from litellm.integrations.custom_logger import CustomLogger
+from litellm.litellm_core_utils.audio_utils.utils import get_audio_file_for_health_check
+from litellm.litellm_core_utils.health_check_utils import (
+    _create_health_check_response,
+    _filter_model_params,
+)
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
 from litellm.litellm_core_utils.mock_functions import (
     mock_embedding,
@@ -60,6 +65,7 @@ from litellm.litellm_core_utils.prompt_templates.common_utils import (
     get_content_from_model_response,
 )
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.realtime_api.main import _realtime_health_check
 from litellm.secret_managers.main import get_secret_str
 from litellm.utils import (
     CustomStreamWrapper,
@@ -5117,65 +5123,60 @@ def speech(
 ##### Health Endpoints #######################
 
 
-async def ahealth_check_chat_models(
+async def ahealth_check_wildcard_models(
     model: str, custom_llm_provider: str, model_params: dict
 ) -> dict:
-    if "*" in model:
-        from litellm.litellm_core_utils.llm_request_utils import (
-            pick_cheapest_chat_model_from_llm_provider,
-        )
-
-        # this is a wildcard model, we need to pick a random model from the provider
-        cheapest_model = pick_cheapest_chat_model_from_llm_provider(
-            custom_llm_provider=custom_llm_provider
-        )
-        fallback_models: Optional[List] = None
-        if custom_llm_provider in litellm.models_by_provider:
-            models = litellm.models_by_provider[custom_llm_provider]
-            random.shuffle(models)  # Shuffle the models list in place
-            fallback_models = models[
-                :2
-            ]  # Pick the first 2 models from the shuffled list
-        model_params["model"] = cheapest_model
-        model_params["fallbacks"] = fallback_models
-        model_params["max_tokens"] = 1
-        await acompletion(**model_params)
-        response: dict = {}  # args like remaining ratelimit etc.
-    else:  # default to completion calls
-        model_params["max_tokens"] = 1
-        await acompletion(**model_params)
-        response = {}  # args like remaining ratelimit etc.
-
+    from litellm.litellm_core_utils.llm_request_utils import (
+        pick_cheapest_chat_model_from_llm_provider,
+    )
+
+    # this is a wildcard model, we need to pick a random model from the provider
+    cheapest_model = pick_cheapest_chat_model_from_llm_provider(
+        custom_llm_provider=custom_llm_provider
+    )
+    fallback_models: Optional[List] = None
+    if custom_llm_provider in litellm.models_by_provider:
+        models = litellm.models_by_provider[custom_llm_provider]
+        random.shuffle(models)  # Shuffle the models list in place
+        fallback_models = models[:2]  # Pick the first 2 models from the shuffled list
+    model_params["model"] = cheapest_model
+    model_params["fallbacks"] = fallback_models
+    model_params["max_tokens"] = 1
+    await acompletion(**model_params)
+    response: dict = {}  # args like remaining ratelimit etc.
     return response
 
 
-async def ahealth_check(  # noqa: PLR0915
+async def ahealth_check(
    model_params: dict,
    mode: Optional[
        Literal[
+            "chat",
            "completion",
            "embedding",
+            "audio_speech",
+            "audio_transcription",
            "image_generation",
-            "chat",
            "batch",
            "rerank",
            "realtime",
        ]
-    ] = None,
+    ] = "chat",
    prompt: Optional[str] = None,
    input: Optional[List] = None,
-    default_timeout: float = 6000,
 ):
     """
     Support health checks for different providers. Return remaining rate limit, etc.
 
-    For azure/openai -> completion.with_raw_response
-    For rest -> litellm.acompletion()
+    Returns:
+        {
+            "x-ratelimit-remaining-requests": int,
+            "x-ratelimit-remaining-tokens": int,
+            "x-ms-region": str,
+        }
     """
-    passed_in_mode: Optional[str] = None
     try:
         model: Optional[str] = model_params.get("model", None)
 
         if model is None:
             raise Exception("model not set")
 
@@ -5183,122 +5184,73 @@ async def ahealth_check(  # noqa: PLR0915
             mode = litellm.model_cost[model].get("mode")
 
         model, custom_llm_provider, _, _ = get_llm_provider(model=model)
 
         if model in litellm.model_cost and mode is None:
             mode = litellm.model_cost[model].get("mode")
 
-        mode = mode
-        passed_in_mode = mode
-        if mode is None:
-            mode = "chat"  # default to chat completion calls
-
-        if custom_llm_provider == "azure":
-            api_key = (
-                model_params.get("api_key")
-                or get_secret_str("AZURE_API_KEY")
-                or get_secret_str("AZURE_OPENAI_API_KEY")
-            )
-
-            api_base: Optional[str] = (
-                model_params.get("api_base")
-                or get_secret_str("AZURE_API_BASE")
-                or get_secret_str("AZURE_OPENAI_API_BASE")
-            )
-
-            if api_base is None:
-                raise ValueError(
-                    "Azure API Base cannot be None. Set via 'AZURE_API_BASE' in env var or `.completion(..., api_base=..)`"
-                )
-
-            api_version = (
-                model_params.get("api_version")
-                or get_secret_str("AZURE_API_VERSION")
-                or get_secret_str("AZURE_OPENAI_API_VERSION")
-            )
-
-            timeout = (
-                model_params.get("timeout")
-                or litellm.request_timeout
-                or default_timeout
-            )
-
-            response = await azure_chat_completions.ahealth_check(
-                model=model,
-                messages=model_params.get(
-                    "messages", None
-                ),  # Replace with your actual messages list
-                api_key=api_key,
-                api_base=api_base,
-                api_version=api_version,
-                timeout=timeout,
-                mode=mode,
-                prompt=prompt,
-                input=input,
-            )
-        elif (
-            custom_llm_provider == "openai"
-            or custom_llm_provider == "text-completion-openai"
-        ):
-            api_key = model_params.get("api_key") or get_secret_str("OPENAI_API_KEY")
-            organization = model_params.get("organization")
-
-            timeout = (
-                model_params.get("timeout")
-                or litellm.request_timeout
-                or default_timeout
-            )
-
-            api_base = model_params.get("api_base") or get_secret_str("OPENAI_API_BASE")
-
-            if custom_llm_provider == "text-completion-openai":
-                mode = "completion"
-
-            response = await openai_chat_completions.ahealth_check(
-                model=model,
-                messages=model_params.get(
-                    "messages", None
-                ),  # Replace with your actual messages list
-                api_key=api_key,
-                api_base=api_base,
-                timeout=timeout,
-                mode=mode,
-                prompt=prompt,
-                input=input,
-                organization=organization,
-            )
-        else:
-            model_params["cache"] = {
-                "no-cache": True
-            }  # don't used cached responses for making health check calls
-            if mode == "embedding":
-                model_params.pop("messages", None)
-                model_params["input"] = input
-                await litellm.aembedding(**model_params)
-                response = {}
-            elif mode == "image_generation":
-                model_params.pop("messages", None)
-                model_params["prompt"] = prompt
-                await litellm.aimage_generation(**model_params)
-                response = {}
-            elif mode == "rerank":
-                model_params.pop("messages", None)
-                model_params["query"] = prompt
-                model_params["documents"] = ["my sample text"]
-                await litellm.arerank(**model_params)
-                response = {}
-            else:
-                response = await ahealth_check_chat_models(
-                    model=model,
-                    custom_llm_provider=custom_llm_provider,
-                    model_params=model_params,
-                )
-        return response
+        model_params["cache"] = {
+            "no-cache": True
+        }  # don't used cached responses for making health check calls
+        if "*" in model:
+            return await ahealth_check_wildcard_models(
+                model=model,
+                custom_llm_provider=custom_llm_provider,
+                model_params=model_params,
+            )
+        # Map modes to their corresponding health check calls
+        mode_handlers = {
+            "chat": lambda: litellm.acompletion(**model_params),
+            "completion": lambda: litellm.atext_completion(
+                **_filter_model_params(model_params),
+                prompt=prompt or "test",
+            ),
+            "embedding": lambda: litellm.aembedding(
+                **_filter_model_params(model_params),
+                input=input or ["test"],
+            ),
+            "audio_speech": lambda: litellm.aspeech(
+                **_filter_model_params(model_params),
+                input=prompt or "test",
+                voice="alloy",
+            ),
+            "audio_transcription": lambda: litellm.atranscription(
+                **_filter_model_params(model_params),
+                file=get_audio_file_for_health_check(),
+            ),
+            "image_generation": lambda: litellm.aimage_generation(
+                **_filter_model_params(model_params),
+                prompt=prompt,
+            ),
+            "rerank": lambda: litellm.arerank(
+                **_filter_model_params(model_params),
+                query=prompt or "",
+                documents=["my sample text"],
+            ),
+            "realtime": lambda: _realtime_health_check(
+                model=model,
+                custom_llm_provider=custom_llm_provider,
+                api_base=model_params.get("api_base", None),
+                api_key=model_params.get("api_key", None),
+                api_version=model_params.get("api_version", None),
+            ),
+        }
+
+        if mode in mode_handlers:
+            _response = await mode_handlers[mode]()
+            # Only process headers for chat mode
+            _response_headers: dict = (
+                getattr(_response, "_hidden_params", {}).get("headers", {}) or {}
+            )
+            return _create_health_check_response(_response_headers)
+        else:
+            raise Exception(
+                f"Mode {mode} not supported. See modes here: https://docs.litellm.ai/docs/proxy/health"
+            )
     except Exception as e:
         stack_trace = traceback.format_exc()
         if isinstance(stack_trace, str):
             stack_trace = stack_trace[:1000]
 
-        if passed_in_mode is None:
+        if mode is None:
             return {
                 "error": f"error:{str(e)}. Missing `mode`. Set the `mode` for the model - https://docs.litellm.ai/docs/proxy/health#embedding-models \nstacktrace: {stack_trace}"
             }
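The dictionary-of-lambdas dispatch above is the core of the refactor: building the dict creates no coroutines, and only the selected mode's litellm call is awaited. A stripped-down sketch of the same pattern, with made-up handler names (not the litellm API) to show why each handler is wrapped in a lambda:

import asyncio


async def check_chat() -> dict:
    return {"mode": "chat", "ok": True}


async def check_embedding() -> dict:
    return {"mode": "embedding", "ok": True}


async def health_check(mode: str) -> dict:
    # wrapping each coroutine in a lambda defers creation and execution
    # until the chosen entry is looked up and awaited
    handlers = {
        "chat": lambda: check_chat(),
        "embedding": lambda: check_embedding(),
    }
    if mode not in handlers:
        raise ValueError(f"mode {mode} not supported")
    return await handlers[mode]()


print(asyncio.run(health_check("embedding")))  # {'mode': 'embedding', 'ok': True}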
@@ -118,9 +118,9 @@ async def _arealtime(
 
 async def _realtime_health_check(
     model: str,
-    api_base: str,
     custom_llm_provider: str,
     api_key: Optional[str],
+    api_base: Optional[str] = None,
     api_version: Optional[str] = None,
 ):
     """
@@ -143,12 +143,14 @@ async def _realtime_health_check(
     url: Optional[str] = None
     if custom_llm_provider == "azure":
         url = azure_realtime._construct_url(
-            api_base=api_base,
+            api_base=api_base or "",
             model=model,
             api_version=api_version or "2024-10-01-preview",
         )
     elif custom_llm_provider == "openai":
-        url = openai_realtime._construct_url(api_base=api_base, model=model)
+        url = openai_realtime._construct_url(
+            api_base=api_base or "https://api.openai.com/", model=model
+        )
     async with websockets.connect(  # type: ignore
         url,
         extra_headers={
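With api_base now optional, the realtime check can be invoked with just a model, provider, and key; the URL falls back to a provider default inside the function. A hedged sketch of calling it directly (assumes OPENAI_API_KEY is set and the account has realtime access; in normal use this path is reached through litellm.ahealth_check with mode="realtime"):

import asyncio
import os

from litellm.realtime_api.main import _realtime_health_check


async def main():
    # api_base is omitted on purpose: for "openai" it now defaults to
    # https://api.openai.com/ inside _realtime_health_check
    await _realtime_health_check(
        model="gpt-4o-realtime-preview",
        custom_llm_provider="openai",
        api_key=os.getenv("OPENAI_API_KEY"),
    )
    print("realtime websocket handshake succeeded")


asyncio.run(main())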
@@ -6,6 +6,7 @@ import sys
 import traceback
 
 import pytest
+from unittest.mock import AsyncMock, patch
 
 sys.path.insert(
     0, os.path.abspath("../..")
@@ -35,6 +36,19 @@ async def test_azure_health_check():
 # asyncio.run(test_azure_health_check())
 
 
+@pytest.mark.asyncio
+async def test_text_completion_health_check():
+    response = await litellm.ahealth_check(
+        model_params={"model": "gpt-3.5-turbo-instruct"},
+        mode="completion",
+        prompt="What's the weather in SF?",
+    )
+    print(f"response: {response}")
+
+    assert "x-ratelimit-remaining-tokens" in response
+    return response
+
+
 @pytest.mark.asyncio
 async def test_azure_embedding_health_check():
     response = await litellm.ahealth_check(
@@ -128,7 +142,6 @@ async def test_groq_health_check():
         mode=None,
         prompt="What's 1 + 1?",
         input=["test from litellm"],
-        default_timeout=6000,
     )
     print(f"response: {response}")
     assert response == {}
@@ -141,8 +154,6 @@ async def test_cohere_rerank_health_check():
     response = await litellm.ahealth_check(
         model_params={
             "model": "cohere/rerank-english-v3.0",
-            "query": "Hey, how's it going",
-            "documents": ["my sample text"],
             "api_key": os.getenv("COHERE_API_KEY"),
         },
         mode="rerank",
@@ -154,15 +165,52 @@ async def test_cohere_rerank_health_check():
     print(response)
 
 
+@pytest.mark.asyncio
+async def test_audio_speech_health_check():
+    response = await litellm.ahealth_check(
+        model_params={
+            "model": "openai/tts-1",
+            "api_key": os.getenv("OPENAI_API_KEY"),
+        },
+        mode="audio_speech",
+        prompt="Hey",
+    )
+
+    assert "error" not in response
+
+    print(response)
+
+
+@pytest.mark.asyncio
+async def test_audio_transcription_health_check():
+    litellm.set_verbose = True
+    response = await litellm.ahealth_check(
+        model_params={
+            "model": "openai/whisper-1",
+            "api_key": os.getenv("OPENAI_API_KEY"),
+        },
+        mode="audio_transcription",
+    )
+
+    assert "error" not in response
+
+    print(response)
+
+
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
     "model", ["azure/gpt-4o-realtime-preview", "openai/gpt-4o-realtime-preview"]
 )
-async def test_realtime_health_check(model):
+async def test_async_realtime_health_check(model, mocker):
     """
     Test Health Check with Valid models passes
 
     """
+    mock_websocket = AsyncMock()
+    mock_connect = AsyncMock().__aenter__.return_value = mock_websocket
+    mocker.patch("websockets.connect", return_value=mock_connect)
+
+    litellm.set_verbose = True
     model_params = {
         "model": model,
     }