(Refactor) - Reuse litellm.completion/litellm.embedding etc. for health checks (#7455)

* add mode: realtime

* add _realtime_health_check

* test_realtime_health_check

* azure _realtime_health_check

* _realtime_health_check

* Realtime Models

* fix code quality

* delete OAI / Azure custom health check code

* simplest version of ahealth check

* update tests

* working health check post refactor

* working aspeech health check

* fix realtime health checks

* test_audio_transcription_health_check

* use get_audio_file_for_health_check

* test_text_completion_health_check

* ahealth_check

* simplify health check code

* update ahealth_check

* fix import

* fix unused imports

* fix ahealth_check

* fix local testing

* test_async_realtime_health_check
Ishaan Jaff 2024-12-28 18:38:54 -08:00 committed by GitHub
parent 4e65722a00
commit 1e06ee3162
9 changed files with 188 additions and 373 deletions
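
In short: instead of provider-specific `ahealth_check` implementations on the Azure/OpenAI handlers, `litellm.ahealth_check()` now dispatches every mode to the matching public litellm call (`acompletion`, `aembedding`, `atranscription`, and so on), as the diffs below show. A minimal usage sketch of the post-refactor entry point, assuming an `OPENAI_API_KEY` in the environment and an illustrative model name:

```python
import asyncio

import litellm


async def main() -> None:
    # Chat-mode probe: after this refactor it is litellm.acompletion(**model_params)
    response = await litellm.ahealth_check(
        model_params={
            "model": "gpt-4o-mini",  # illustrative model name
            "messages": [{"role": "user", "content": "ping"}],
        },
        mode="chat",
    )
    # On success: rate-limit headers pulled from the response's hidden params
    print(response)


asyncio.run(main())
```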

View file

@@ -69,6 +69,7 @@ jobs:
pip install "Pillow==10.3.0"
pip install "jsonschema==4.22.0"
pip install "pytest-xdist==3.6.1"
pip install "websockets==10.4"
- save_cache:
paths:
- ./venv

View file

@@ -2,6 +2,8 @@
Utils used for litellm.transcription() and litellm.atranscription()
"""
import os
from litellm.types.utils import FileTypes
@@ -21,3 +23,14 @@ def get_audio_file_name(file_obj: FileTypes) -> str:
return str(file_obj)
else:
return repr(file_obj)
def get_audio_file_for_health_check() -> FileTypes:
"""
Get an audio file for health check
Returns the content of `audio_health_check.wav` in the same directory as this file
"""
pwd = os.path.dirname(os.path.realpath(__file__))
file_path = os.path.join(pwd, "audio_health_check.wav")
return open(file_path, "rb")

View file

@@ -0,0 +1,28 @@
"""
Utils used for litellm.ahealth_check()
"""
def _filter_model_params(model_params: dict) -> dict:
"""Remove 'messages' param from model params."""
return {k: v for k, v in model_params.items() if k != "messages"}
def _create_health_check_response(response_headers: dict) -> dict:
response = {}
if (
response_headers.get("x-ratelimit-remaining-requests", None) is not None
): # not provided for dall-e requests
response["x-ratelimit-remaining-requests"] = response_headers[
"x-ratelimit-remaining-requests"
]
if response_headers.get("x-ratelimit-remaining-tokens", None) is not None:
response["x-ratelimit-remaining-tokens"] = response_headers[
"x-ratelimit-remaining-tokens"
]
if response_headers.get("x-ms-region", None) is not None:
response["x-ms-region"] = response_headers["x-ms-region"]
return response
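
In isolation, the two helpers behave like this (header values are illustrative):

```python
from litellm.litellm_core_utils.health_check_utils import (
    _create_health_check_response,
    _filter_model_params,
)

# Illustrative headers; "x-ms-region" only appears on Azure responses
headers = {
    "x-ratelimit-remaining-requests": "99",
    "x-ratelimit-remaining-tokens": "3999",
    "content-type": "application/json",  # dropped: not an allow-listed header
}
assert _create_health_check_response(headers) == {
    "x-ratelimit-remaining-requests": "99",
    "x-ratelimit-remaining-tokens": "3999",
}

params = {"model": "gpt-4o", "messages": [{"role": "user", "content": "hi"}]}
assert _filter_model_params(params) == {"model": "gpt-4o"}
```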

View file

@@ -1491,132 +1491,3 @@ class AzureChatCompletion(BaseLLM):
response["x-ms-region"] = completion.headers["x-ms-region"]
return response
async def ahealth_check(
self,
model: Optional[str],
api_key: Optional[str],
api_base: str,
api_version: Optional[str],
timeout: float,
mode: str,
messages: Optional[list] = None,
input: Optional[list] = None,
prompt: Optional[str] = None,
) -> dict:
client_session = (
litellm.aclient_session
or get_async_httpx_client(llm_provider=LlmProviders.AZURE).client
) # handle dall-e-2 calls
if "gateway.ai.cloudflare.com" in api_base:
## build base url - assume api base includes resource name
if not api_base.endswith("/"):
api_base += "/"
api_base += f"{model}"
client = AsyncAzureOpenAI(
base_url=api_base,
api_version=api_version,
api_key=api_key,
timeout=timeout,
http_client=client_session,
)
model = None
# cloudflare ai gateway, needs model=None
else:
client = AsyncAzureOpenAI(
api_version=api_version,
azure_endpoint=api_base,
api_key=api_key,
timeout=timeout,
http_client=client_session,
)
# only run this check if it's not cloudflare ai gateway
if model is None and mode != "image_generation":
raise Exception("model is not set")
completion = None
if mode == "completion":
completion = await client.completions.with_raw_response.create(
model=model, # type: ignore
prompt=prompt, # type: ignore
)
elif mode == "chat":
if messages is None:
raise Exception("messages is not set")
completion = await client.chat.completions.with_raw_response.create(
model=model, # type: ignore
messages=messages, # type: ignore
)
elif mode == "embedding":
if input is None:
raise Exception("input is not set")
completion = await client.embeddings.with_raw_response.create(
model=model, # type: ignore
input=input, # type: ignore
)
elif mode == "image_generation":
if prompt is None:
raise Exception("prompt is not set")
completion = await client.images.with_raw_response.generate(
model=model, # type: ignore
prompt=prompt, # type: ignore
)
elif mode == "audio_transcription":
# Get the current directory of the file being run
pwd = os.path.dirname(os.path.realpath(__file__))
file_path = os.path.join(
pwd, "../../../tests/gettysburg.wav"
) # proxy address
audio_file = open(file_path, "rb")
completion = await client.audio.transcriptions.with_raw_response.create(
file=audio_file,
model=model, # type: ignore
prompt=prompt, # type: ignore
)
elif mode == "audio_speech":
# Get the current directory of the file being run
completion = await client.audio.speech.with_raw_response.create(
model=model, # type: ignore
input=prompt, # type: ignore
voice="alloy",
)
elif mode == "batch":
completion = await client.batches.with_raw_response.list(limit=1) # type: ignore
elif mode == "realtime":
from litellm.realtime_api.main import _realtime_health_check
# create a websocket connection
await _realtime_health_check(
model=model or "",
api_key=api_key,
api_base=api_base,
api_version=api_version,
custom_llm_provider="azure",
)
return {}
else:
raise Exception("mode not set")
response = {}
if completion is None or not hasattr(completion, "headers"):
raise Exception("invalid completion response")
if (
completion.headers.get("x-ratelimit-remaining-requests", None) is not None
): # not provided for dall-e requests
response["x-ratelimit-remaining-requests"] = completion.headers[
"x-ratelimit-remaining-requests"
]
if completion.headers.get("x-ratelimit-remaining-tokens", None) is not None:
response["x-ratelimit-remaining-tokens"] = completion.headers[
"x-ratelimit-remaining-tokens"
]
if completion.headers.get("x-ms-region", None) is not None:
response["x-ms-region"] = completion.headers["x-ms-region"]
return response
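
With this provider-specific implementation deleted, an Azure probe now takes the same generic path as every other provider. A hedged sketch of the replacement call, using the standard Azure env vars and an illustrative deployment name:

```python
import asyncio
import os

import litellm


async def azure_embedding_probe() -> dict:
    return await litellm.ahealth_check(
        model_params={
            "model": "azure/text-embedding-ada-002",  # illustrative deployment
            "api_key": os.getenv("AZURE_API_KEY"),
            "api_base": os.getenv("AZURE_API_BASE"),
            "api_version": os.getenv("AZURE_API_VERSION"),
        },
        mode="embedding",
    )


print(asyncio.run(azure_embedding_probe()))
```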

View file

@@ -1,5 +1,4 @@
import hashlib
import os
import types
from typing import (
Any,
@@ -1306,105 +1305,6 @@ class OpenAIChatCompletion(BaseLLM):
return HttpxBinaryResponseContent(response=response.response)
async def ahealth_check(
self,
model: Optional[str],
api_key: Optional[str],
timeout: float,
mode: str,
messages: Optional[list] = None,
input: Optional[list] = None,
prompt: Optional[str] = None,
organization: Optional[str] = None,
api_base: Optional[str] = None,
):
client = AsyncOpenAI(
api_key=api_key,
timeout=timeout,
organization=organization,
base_url=api_base,
)
if model is None and mode != "image_generation":
raise Exception("model is not set")
completion = None
if mode == "completion":
completion = await client.completions.with_raw_response.create(
model=model, # type: ignore
prompt=prompt, # type: ignore
)
elif mode == "chat":
if messages is None:
raise Exception("messages is not set")
completion = await client.chat.completions.with_raw_response.create(
model=model, # type: ignore
messages=messages, # type: ignore
)
elif mode == "embedding":
if input is None:
raise Exception("input is not set")
completion = await client.embeddings.with_raw_response.create(
model=model, # type: ignore
input=input, # type: ignore
)
elif mode == "image_generation":
if prompt is None:
raise Exception("prompt is not set")
completion = await client.images.with_raw_response.generate(
model=model, # type: ignore
prompt=prompt, # type: ignore
)
elif mode == "audio_transcription":
# Get the current directory of the file being run
pwd = os.path.dirname(os.path.realpath(__file__))
file_path = os.path.join(
pwd, "../../../tests/gettysburg.wav"
) # proxy address
audio_file = open(file_path, "rb")
completion = await client.audio.transcriptions.with_raw_response.create(
file=audio_file,
model=model, # type: ignore
prompt=prompt, # type: ignore
)
elif mode == "audio_speech":
# Get the current directory of the file being run
completion = await client.audio.speech.with_raw_response.create(
model=model, # type: ignore
input=prompt, # type: ignore
voice="alloy",
)
elif mode == "realtime":
from litellm.realtime_api.main import _realtime_health_check
# create a websocket connection
await _realtime_health_check(
model=model or "",
api_key=api_key,
api_base=api_base or "https://api.openai.com/",
custom_llm_provider="openai",
)
return {}
else:
raise ValueError("mode not set, passed in mode: " + mode)
response = {}
if completion is None or not hasattr(completion, "headers"):
raise Exception("invalid completion response")
if (
completion.headers.get("x-ratelimit-remaining-requests", None) is not None
): # not provided for dall-e requests
response["x-ratelimit-remaining-requests"] = completion.headers[
"x-ratelimit-remaining-requests"
]
if completion.headers.get("x-ratelimit-remaining-tokens", None) is not None:
response["x-ratelimit-remaining-tokens"] = completion.headers[
"x-ratelimit-remaining-tokens"
]
return response
class OpenAIFilesAPI(BaseLLM):
"""

View file

@@ -51,6 +51,11 @@ from litellm import ( # type: ignore
get_optional_params,
)
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.audio_utils.utils import get_audio_file_for_health_check
from litellm.litellm_core_utils.health_check_utils import (
_create_health_check_response,
_filter_model_params,
)
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.litellm_core_utils.mock_functions import (
mock_embedding,
@@ -60,6 +65,7 @@ from litellm.litellm_core_utils.prompt_templates.common_utils import (
get_content_from_model_response,
)
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.realtime_api.main import _realtime_health_check
from litellm.secret_managers.main import get_secret_str
from litellm.utils import (
CustomStreamWrapper,
@@ -5117,10 +5123,9 @@ def speech(
##### Health Endpoints #######################
async def ahealth_check_chat_models(
async def ahealth_check_wildcard_models(
model: str, custom_llm_provider: str, model_params: dict
) -> dict:
if "*" in model:
from litellm.litellm_core_utils.llm_request_utils import (
pick_cheapest_chat_model_from_llm_provider,
)
@@ -5133,49 +5138,45 @@ async def ahealth_check_chat_models(
if custom_llm_provider in litellm.models_by_provider:
models = litellm.models_by_provider[custom_llm_provider]
random.shuffle(models) # Shuffle the models list in place
fallback_models = models[
:2
] # Pick the first 2 models from the shuffled list
fallback_models = models[:2] # Pick the first 2 models from the shuffled list
model_params["model"] = cheapest_model
model_params["fallbacks"] = fallback_models
model_params["max_tokens"] = 1
await acompletion(**model_params)
response: dict = {} # args like remaining ratelimit etc.
else: # default to completion calls
model_params["max_tokens"] = 1
await acompletion(**model_params)
response = {} # args like remaining ratelimit etc.
return response
async def ahealth_check( # noqa: PLR0915
async def ahealth_check(
model_params: dict,
mode: Optional[
Literal[
"chat",
"completion",
"embedding",
"audio_speech",
"audio_transcription",
"image_generation",
"chat",
"batch",
"rerank",
"realtime",
]
] = None,
] = "chat",
prompt: Optional[str] = None,
input: Optional[List] = None,
default_timeout: float = 6000,
):
"""
Support health checks for different providers. Return remaining rate limit, etc.
For azure/openai -> completion.with_raw_response
For rest -> litellm.acompletion()
Returns:
{
"x-ratelimit-remaining-requests": int,
"x-ratelimit-remaining-tokens": int,
"x-ms-region": str,
}
"""
passed_in_mode: Optional[str] = None
try:
model: Optional[str] = model_params.get("model", None)
if model is None:
raise Exception("model not set")
@@ -5183,122 +5184,73 @@ async def ahealth_check( # noqa: PLR0915
mode = litellm.model_cost[model].get("mode")
model, custom_llm_provider, _, _ = get_llm_provider(model=model)
if model in litellm.model_cost and mode is None:
mode = litellm.model_cost[model].get("mode")
mode = mode
passed_in_mode = mode
if mode is None:
mode = "chat" # default to chat completion calls
if custom_llm_provider == "azure":
api_key = (
model_params.get("api_key")
or get_secret_str("AZURE_API_KEY")
or get_secret_str("AZURE_OPENAI_API_KEY")
)
api_base: Optional[str] = (
model_params.get("api_base")
or get_secret_str("AZURE_API_BASE")
or get_secret_str("AZURE_OPENAI_API_BASE")
)
if api_base is None:
raise ValueError(
"Azure API Base cannot be None. Set via 'AZURE_API_BASE' in env var or `.completion(..., api_base=..)`"
)
api_version = (
model_params.get("api_version")
or get_secret_str("AZURE_API_VERSION")
or get_secret_str("AZURE_OPENAI_API_VERSION")
)
timeout = (
model_params.get("timeout")
or litellm.request_timeout
or default_timeout
)
response = await azure_chat_completions.ahealth_check(
model=model,
messages=model_params.get(
"messages", None
), # Replace with your actual messages list
api_key=api_key,
api_base=api_base,
api_version=api_version,
timeout=timeout,
mode=mode,
prompt=prompt,
input=input,
)
elif (
custom_llm_provider == "openai"
or custom_llm_provider == "text-completion-openai"
):
api_key = model_params.get("api_key") or get_secret_str("OPENAI_API_KEY")
organization = model_params.get("organization")
timeout = (
model_params.get("timeout")
or litellm.request_timeout
or default_timeout
)
api_base = model_params.get("api_base") or get_secret_str("OPENAI_API_BASE")
if custom_llm_provider == "text-completion-openai":
mode = "completion"
response = await openai_chat_completions.ahealth_check(
model=model,
messages=model_params.get(
"messages", None
), # Replace with your actual messages list
api_key=api_key,
api_base=api_base,
timeout=timeout,
mode=mode,
prompt=prompt,
input=input,
organization=organization,
)
else:
model_params["cache"] = {
"no-cache": True
} # don't use cached responses for making health check calls
if mode == "embedding":
model_params.pop("messages", None)
model_params["input"] = input
await litellm.aembedding(**model_params)
response = {}
elif mode == "image_generation":
model_params.pop("messages", None)
model_params["prompt"] = prompt
await litellm.aimage_generation(**model_params)
response = {}
elif mode == "rerank":
model_params.pop("messages", None)
model_params["query"] = prompt
model_params["documents"] = ["my sample text"]
await litellm.arerank(**model_params)
response = {}
else:
response = await ahealth_check_chat_models(
if "*" in model:
return await ahealth_check_wildcard_models(
model=model,
custom_llm_provider=custom_llm_provider,
model_params=model_params,
)
return response
# Map modes to their corresponding health check calls
mode_handlers = {
"chat": lambda: litellm.acompletion(**model_params),
"completion": lambda: litellm.atext_completion(
**_filter_model_params(model_params),
prompt=prompt or "test",
),
"embedding": lambda: litellm.aembedding(
**_filter_model_params(model_params),
input=input or ["test"],
),
"audio_speech": lambda: litellm.aspeech(
**_filter_model_params(model_params),
input=prompt or "test",
voice="alloy",
),
"audio_transcription": lambda: litellm.atranscription(
**_filter_model_params(model_params),
file=get_audio_file_for_health_check(),
),
"image_generation": lambda: litellm.aimage_generation(
**_filter_model_params(model_params),
prompt=prompt,
),
"rerank": lambda: litellm.arerank(
**_filter_model_params(model_params),
query=prompt or "",
documents=["my sample text"],
),
"realtime": lambda: _realtime_health_check(
model=model,
custom_llm_provider=custom_llm_provider,
api_base=model_params.get("api_base", None),
api_key=model_params.get("api_key", None),
api_version=model_params.get("api_version", None),
),
}
if mode in mode_handlers:
_response = await mode_handlers[mode]()
# Only process headers for chat mode
_response_headers: dict = (
getattr(_response, "_hidden_params", {}).get("headers", {}) or {}
)
return _create_health_check_response(_response_headers)
else:
raise Exception(
f"Mode {mode} not supported. See modes here: https://docs.litellm.ai/docs/proxy/health"
)
except Exception as e:
stack_trace = traceback.format_exc()
if isinstance(stack_trace, str):
stack_trace = stack_trace[:1000]
if passed_in_mode is None:
if mode is None:
return {
"error": f"error:{str(e)}. Missing `mode`. Set the `mode` for the model - https://docs.litellm.ai/docs/proxy/health#embedding-models \nstacktrace: {stack_trace}"
}
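
The dict-of-lambdas dispatch above keeps every probe lazy: nothing is awaited until the mode has been resolved and validated. Distilled into a standalone skeleton, with hypothetical probe functions standing in for the litellm calls:

```python
import asyncio
from typing import Awaitable, Callable, Dict


# Hypothetical probes standing in for litellm.acompletion / litellm.aembedding
async def probe_chat() -> dict:
    return {"mode": "chat"}


async def probe_embedding() -> dict:
    return {"mode": "embedding"}


# Zero-argument callables keep each probe lazy; only the selected mode runs
handlers: Dict[str, Callable[[], Awaitable[dict]]] = {
    "chat": probe_chat,
    "embedding": probe_embedding,
}


async def run_health_check(mode: str) -> dict:
    if mode not in handlers:
        raise ValueError(f"unsupported mode: {mode}")
    return await handlers[mode]()


print(asyncio.run(run_health_check("chat")))  # {'mode': 'chat'}
```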

View file

@@ -118,9 +118,9 @@ async def _arealtime(
async def _realtime_health_check(
model: str,
api_base: str,
custom_llm_provider: str,
api_key: Optional[str],
api_base: Optional[str] = None,
api_version: Optional[str] = None,
):
"""
@@ -143,12 +143,14 @@ async def _realtime_health_check(
url: Optional[str] = None
if custom_llm_provider == "azure":
url = azure_realtime._construct_url(
api_base=api_base,
api_base=api_base or "",
model=model,
api_version=api_version or "2024-10-01-preview",
)
elif custom_llm_provider == "openai":
url = openai_realtime._construct_url(api_base=api_base, model=model)
url = openai_realtime._construct_url(
api_base=api_base or "https://api.openai.com/", model=model
)
async with websockets.connect( # type: ignore
url,
extra_headers={
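
Given the new optional `api_base`/`api_version` parameters above, a hedged caller-side sketch (illustrative model name; this opens a real websocket, so a valid `OPENAI_API_KEY` is required):

```python
import asyncio
import os

from litellm.realtime_api.main import _realtime_health_check


async def realtime_probe() -> None:
    # api_base may now be omitted; the helper falls back to
    # https://api.openai.com/ for the openai provider
    await _realtime_health_check(
        model="gpt-4o-realtime-preview",  # illustrative model
        custom_llm_provider="openai",
        api_key=os.getenv("OPENAI_API_KEY"),
    )


asyncio.run(realtime_probe())
```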

View file

@@ -6,6 +6,7 @@ import sys
import traceback
import pytest
from unittest.mock import AsyncMock, patch
sys.path.insert(
0, os.path.abspath("../..")
@@ -35,6 +36,19 @@ async def test_azure_health_check():
# asyncio.run(test_azure_health_check())
@pytest.mark.asyncio
async def test_text_completion_health_check():
response = await litellm.ahealth_check(
model_params={"model": "gpt-3.5-turbo-instruct"},
mode="completion",
prompt="What's the weather in SF?",
)
print(f"response: {response}")
assert "x-ratelimit-remaining-tokens" in response
return response
@pytest.mark.asyncio
async def test_azure_embedding_health_check():
response = await litellm.ahealth_check(
@@ -128,7 +142,6 @@ async def test_groq_health_check():
mode=None,
prompt="What's 1 + 1?",
input=["test from litellm"],
default_timeout=6000,
)
print(f"response: {response}")
assert response == {}
@@ -141,8 +154,6 @@ async def test_cohere_rerank_health_check():
response = await litellm.ahealth_check(
model_params={
"model": "cohere/rerank-english-v3.0",
"query": "Hey, how's it going",
"documents": ["my sample text"],
"api_key": os.getenv("COHERE_API_KEY"),
},
mode="rerank",
@@ -154,15 +165,52 @@ async def test_cohere_rerank_health_check():
print(response)
@pytest.mark.asyncio
async def test_audio_speech_health_check():
response = await litellm.ahealth_check(
model_params={
"model": "openai/tts-1",
"api_key": os.getenv("OPENAI_API_KEY"),
},
mode="audio_speech",
prompt="Hey",
)
assert "error" not in response
print(response)
@pytest.mark.asyncio
async def test_audio_transcription_health_check():
litellm.set_verbose = True
response = await litellm.ahealth_check(
model_params={
"model": "openai/whisper-1",
"api_key": os.getenv("OPENAI_API_KEY"),
},
mode="audio_transcription",
)
assert "error" not in response
print(response)
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model", ["azure/gpt-4o-realtime-preview", "openai/gpt-4o-realtime-preview"]
)
async def test_realtime_health_check(model):
async def test_async_realtime_health_check(model, mocker):
"""
Test Health Check with Valid models passes
"""
mock_websocket = AsyncMock()
mock_connect = AsyncMock().__aenter__.return_value = mock_websocket
mocker.patch("websockets.connect", return_value=mock_connect)
litellm.set_verbose = True
model_params = {
"model": model,
}
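
The AsyncMock wiring is the subtle part of this test. A self-contained sketch of the same async-context-manager mocking idea, with no network and no litellm involved (all names are local to the example):

```python
import asyncio
from unittest.mock import AsyncMock


async def probe(connect) -> None:
    # `connect(...)` is used as an async context manager, like websockets.connect
    async with connect("wss://example.invalid") as ws:
        await ws.send('{"type": "session.update"}')  # illustrative payload


mock_ws = AsyncMock()
mock_connect = AsyncMock()
mock_connect.__aenter__.return_value = mock_ws

asyncio.run(probe(lambda url, **kwargs: mock_connect))
mock_ws.send.assert_awaited_once()
```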