forked from phoenix/litellm-mirror
feat(health_check.py): more detailed health check calls
parent 7ce7516621
commit 3b1685e7c6
6 changed files with 354 additions and 48 deletions
@@ -700,3 +700,72 @@ class AzureChatCompletion(BaseLLM):
                 import traceback

                 raise AzureOpenAIError(status_code=500, message=traceback.format_exc())
+
+    async def ahealth_check(
+        self,
+        model: Optional[str],
+        api_key: str,
+        api_base: str,
+        api_version: str,
+        timeout: float,
+        mode: str,
+        messages: Optional[list] = None,
+        input: Optional[list] = None,
+        prompt: Optional[str] = None,
+    ):
+        client_session = litellm.aclient_session or httpx.AsyncClient(
+            transport=AsyncCustomHTTPTransport(),  # handle dall-e-2 calls
+        )
+        client = AsyncAzureOpenAI(
+            api_version=api_version,
+            azure_endpoint=api_base,
+            api_key=api_key,
+            timeout=timeout,
+            http_client=client_session,
+        )
+
+        if model is None and mode != "image_generation":
+            raise Exception("model is not set")
+
+        completion = None
+
+        if mode == "completion":
+            if messages is None:
+                raise Exception("messages is not set")
+            completion = await client.chat.completions.with_raw_response.create(
+                model=model,  # type: ignore
+                messages=messages,  # type: ignore
+            )
+        elif mode == "embedding":
+            if input is None:
+                raise Exception("input is not set")
+            completion = await client.embeddings.with_raw_response.create(
+                model=model,  # type: ignore
+                input=input,  # type: ignore
+            )
+        elif mode == "image_generation":
+            if prompt is None:
+                raise Exception("prompt is not set")
+            completion = await client.images.with_raw_response.generate(
+                model=model,  # type: ignore
+                prompt=prompt,  # type: ignore
+            )
+        else:
+            raise Exception("mode not set")
+        response = {}
+
+        if completion is None or not hasattr(completion, "headers"):
+            raise Exception("invalid completion response")
+
+        if (
+            completion.headers.get("x-ratelimit-remaining-requests", None) is not None
+        ):  # not provided for dall-e requests
+            response["x-ratelimit-remaining-requests"] = completion.headers[
+                "x-ratelimit-remaining-requests"
+            ]
+
+        if completion.headers.get("x-ratelimit-remaining-tokens", None) is not None:
+            response["x-ratelimit-remaining-tokens"] = completion.headers[
+                "x-ratelimit-remaining-tokens"
+            ]
+        return response
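Both provider health checks lean on the openai-python v1 with_raw_response wrapper, which exposes the raw HTTP headers alongside the parsed body. A minimal standalone sketch of that pattern for Azure, assuming the usual AZURE_API_KEY / AZURE_API_BASE / AZURE_API_VERSION environment variables are set; "my-gpt-deployment" is a hypothetical deployment name:

import os, asyncio
from openai import AsyncAzureOpenAI

async def main():
    client = AsyncAzureOpenAI(
        api_key=os.environ["AZURE_API_KEY"],
        azure_endpoint=os.environ["AZURE_API_BASE"],
        api_version=os.environ["AZURE_API_VERSION"],
    )
    raw = await client.chat.completions.with_raw_response.create(
        model="my-gpt-deployment",  # hypothetical deployment name
        messages=[{"role": "user", "content": "ping"}],
    )
    # .headers carries the rate-limit info the health check reports;
    # .parse() recovers the usual ChatCompletion object.
    print(raw.headers.get("x-ratelimit-remaining-requests"))
    print(raw.parse().choices[0].message.content)

asyncio.run(main())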
@@ -717,6 +717,63 @@ class OpenAIChatCompletion(BaseLLM):

                 raise OpenAIError(status_code=500, message=traceback.format_exc())
+
+    async def ahealth_check(
+        self,
+        model: Optional[str],
+        api_key: str,
+        timeout: float,
+        mode: str,
+        messages: Optional[list] = None,
+        input: Optional[list] = None,
+        prompt: Optional[str] = None,
+    ):
+        client = AsyncOpenAI(api_key=api_key, timeout=timeout)
+        if model is None and mode != "image_generation":
+            raise Exception("model is not set")
+
+        completion = None
+
+        if mode == "completion":
+            if messages is None:
+                raise Exception("messages is not set")
+            completion = await client.chat.completions.with_raw_response.create(
+                model=model,  # type: ignore
+                messages=messages,  # type: ignore
+            )
+        elif mode == "embedding":
+            if input is None:
+                raise Exception("input is not set")
+            completion = await client.embeddings.with_raw_response.create(
+                model=model,  # type: ignore
+                input=input,  # type: ignore
+            )
+        elif mode == "image_generation":
+            if prompt is None:
+                raise Exception("prompt is not set")
+            completion = await client.images.with_raw_response.generate(
+                model=model,  # type: ignore
+                prompt=prompt,  # type: ignore
+            )
+        else:
+            raise Exception("mode not set")
+        response = {}
+
+        if completion is None or not hasattr(completion, "headers"):
+            raise Exception("invalid completion response")
+
+        if (
+            completion.headers.get("x-ratelimit-remaining-requests", None) is not None
+        ):  # not provided for dall-e requests
+            response["x-ratelimit-remaining-requests"] = completion.headers[
+                "x-ratelimit-remaining-requests"
+            ]
+
+        if completion.headers.get("x-ratelimit-remaining-tokens", None) is not None:
+            response["x-ratelimit-remaining-tokens"] = completion.headers[
+                "x-ratelimit-remaining-tokens"
+            ]
+        return response


 class OpenAITextCompletion(BaseLLM):
     _client_session: httpx.Client
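The OpenAI variant is the same flow minus the Azure-specific client wiring. A usage sketch calling the method directly, assuming OPENAI_API_KEY is set; in practice callers go through litellm.ahealth_check() rather than instantiating the internal class, and the import path below reflects this commit's layout:

import os, asyncio
from litellm.llms.openai import OpenAIChatCompletion  # internal class, shown for illustration

async def main():
    health = await OpenAIChatCompletion().ahealth_check(
        model="gpt-3.5-turbo",
        api_key=os.environ["OPENAI_API_KEY"],
        timeout=10.0,
        mode="completion",
        messages=[{"role": "user", "content": "ping"}],
    )
    # e.g. {'x-ratelimit-remaining-requests': '...', 'x-ratelimit-remaining-tokens': '...'}
    print(health)

asyncio.run(main())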
104  litellm/main.py
@@ -8,7 +8,7 @@
 # Thank you ! We ❤️ you! - Krrish & Ishaan

 import os, openai, sys, json, inspect, uuid, datetime, threading
-from typing import Any
+from typing import Any, Literal, Union
 from functools import partial
 import dotenv, traceback, random, asyncio, time, contextvars
 from copy import deepcopy
@@ -2885,6 +2885,108 @@ def image_generation(
     return model_response


+##### Health Endpoints #######################
+
+
+async def ahealth_check(
+    model_params: dict,
+    mode: Optional[Literal["completion", "embedding", "image_generation"]] = None,
+    prompt: Optional[str] = None,
+    input: Optional[List] = None,
+    default_timeout: float = 6000,
+):
+    """
+    Support health checks for different providers. Return remaining rate limit, etc.
+
+    For azure/openai -> completion.with_raw_response
+    For rest -> litellm.acompletion()
+    """
+    try:
+        model: Optional[str] = model_params.get("model", None)
+
+        if model is None:
+            raise Exception("model not set")
+
+        model, custom_llm_provider, _, _ = get_llm_provider(model=model)
+        mode = mode or "completion"  # default to completion calls
+
+        if custom_llm_provider == "azure":
+            api_key = (
+                model_params.get("api_key")
+                or get_secret("AZURE_API_KEY")
+                or get_secret("AZURE_OPENAI_API_KEY")
+            )
+
+            api_base = (
+                model_params.get("api_base")
+                or get_secret("AZURE_API_BASE")
+                or get_secret("AZURE_OPENAI_API_BASE")
+            )
+
+            api_version = (
+                model_params.get("api_version")
+                or get_secret("AZURE_API_VERSION")
+                or get_secret("AZURE_OPENAI_API_VERSION")
+            )
+
+            timeout = (
+                model_params.get("timeout")
+                or litellm.request_timeout
+                or default_timeout
+            )
+
+            response = await azure_chat_completions.ahealth_check(
+                model=model,
+                messages=model_params.get("messages", None),
+                api_key=api_key,
+                api_base=api_base,
+                api_version=api_version,
+                timeout=timeout,
+                mode=mode,
+                prompt=prompt,
+                input=input,
+            )
+        elif custom_llm_provider == "openai":
+            api_key = model_params.get("api_key") or get_secret("OPENAI_API_KEY")
+
+            timeout = (
+                model_params.get("timeout")
+                or litellm.request_timeout
+                or default_timeout
+            )
+
+            response = await openai_chat_completions.ahealth_check(
+                model=model,
+                messages=model_params.get("messages", None),
+                api_key=api_key,
+                timeout=timeout,
+                mode=mode,
+                prompt=prompt,
+                input=input,
+            )
+        else:
+            if mode == "embedding":
+                model_params.pop("messages", None)
+                model_params["input"] = input
+                await litellm.aembedding(**model_params)
+                response = {}
+            elif mode == "image_generation":
+                model_params.pop("messages", None)
+                model_params["prompt"] = prompt
+                await litellm.aimage_generation(**model_params)
+                response = {}
+            else:  # default to completion calls
+                await acompletion(**model_params)
+                response = {}  # args like remaining ratelimit etc.
+        return response
+    except Exception as e:
+        return {"error": str(e)}
+
+
 ####### HELPER FUNCTIONS ################
 ## Set verbose to true -> ```litellm.set_verbose = True```
 def print_verbose(print_statement):
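Note the contract: ahealth_check() never raises; failures come back as {"error": "..."}. A minimal end-to-end sketch against OpenAI, assuming OPENAI_API_KEY is set in the environment:

import asyncio
import litellm

async def main():
    status = await litellm.ahealth_check(
        model_params={
            "model": "gpt-3.5-turbo",
            "messages": [{"role": "user", "content": "ping"}],
        },
        mode="completion",
    )
    if "error" in status:
        print(f"unhealthy: {status['error']}")
    else:
        # for azure/openai this includes the remaining rate-limit headers
        print(f"healthy: {status}")

asyncio.run(main())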
@@ -12,7 +12,7 @@ from litellm._logging import print_verbose
 logger = logging.getLogger(__name__)


-ILLEGAL_DISPLAY_PARAMS = ["messages", "api_key"]
+ILLEGAL_DISPLAY_PARAMS = ["messages", "api_key", "prompt", "input"]


 def _get_random_llm_message():
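prompt and input join the denylist because the health check now threads them through litellm_params, and they should not be echoed back in health responses. A hypothetical sketch of how such a denylist is typically applied (_scrub below is illustrative; the module's real helper is _clean_litellm_params):

ILLEGAL_DISPLAY_PARAMS = ["messages", "api_key", "prompt", "input"]

def _scrub(litellm_params: dict) -> dict:
    # Drop any key whose value should never be shown to callers.
    return {k: v for k, v in litellm_params.items() if k not in ILLEGAL_DISPLAY_PARAMS}

print(_scrub({"model": "gpt-3.5-turbo", "api_key": "sk-..."}))
# -> {'model': 'gpt-3.5-turbo'}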
@@ -35,55 +35,20 @@ async def _perform_health_check(model_list: list):
     """
     Perform a health check for each model in the list.
     """

-    async def _check_img_gen_model(model_params: dict):
-        model_params.pop("messages", None)
-        model_params["prompt"] = "test from litellm"
-        try:
-            await litellm.aimage_generation(**model_params)
-        except Exception as e:
-            print_verbose(
-                f"Health check failed for model {model_params['model']}. Error: {e}"
-            )
-            return False
-        return True
-
-    async def _check_embedding_model(model_params: dict):
-        model_params.pop("messages", None)
-        model_params["input"] = ["test from litellm"]
-        try:
-            await litellm.aembedding(**model_params)
-        except Exception as e:
-            print_verbose(
-                f"Health check failed for model {model_params['model']}. Error: {e}"
-            )
-            return False
-        return True
-
-    async def _check_model(model_params: dict):
-        try:
-            await litellm.acompletion(**model_params)
-        except Exception as e:
-            print_verbose(
-                f"Health check failed for model {model_params['model']}. Error: {e}"
-            )
-            return False
-
-        return True
-
     tasks = []
     for model in model_list:
         litellm_params = model["litellm_params"]
         model_info = model.get("model_info", {})
         litellm_params["messages"] = _get_random_llm_message()

-        if model_info.get("mode", None) == "embedding":
-            # this is an embedding model
-            tasks.append(_check_embedding_model(litellm_params))
-        elif model_info.get("mode", None) == "image_generation":
-            tasks.append(_check_img_gen_model(litellm_params))
-        else:
-            tasks.append(_check_model(litellm_params))
+        mode = model_info.get("mode", None)
+        tasks.append(
+            litellm.ahealth_check(
+                litellm_params,
+                mode=mode,
+                prompt="test from litellm",
+                input=["test from litellm"],
+            )
+        )

     results = await asyncio.gather(*tasks)
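The rewrite relies on asyncio.gather() returning results in the order the awaitables were passed, which is what lets the function zip the results back onto model_list below. A tiny self-contained demonstration of that ordering guarantee:

import asyncio

async def check(name: str, delay: float) -> str:
    await asyncio.sleep(delay)
    return name

async def main():
    # "b" finishes first, but gather preserves submission order.
    print(await asyncio.gather(check("a", 0.2), check("b", 0.1)))  # ['a', 'b']

asyncio.run(main())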
@@ -93,8 +58,10 @@ async def _perform_health_check(model_list: list):
     for is_healthy, model in zip(results, model_list):
         cleaned_litellm_params = _clean_litellm_params(model["litellm_params"])

-        if is_healthy:
-            healthy_endpoints.append(cleaned_litellm_params)
+        if isinstance(is_healthy, dict) and "error" not in is_healthy:
+            healthy_endpoints.append({**cleaned_litellm_params, **is_healthy})
+        elif isinstance(is_healthy, dict):
+            unhealthy_endpoints.append({**cleaned_litellm_params, **is_healthy})
         else:
             unhealthy_endpoints.append(cleaned_litellm_params)
@@ -672,6 +672,10 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
     router_params["model_list"] = model_list
     print(f"\033[32mLiteLLM: Proxy initialized with Config, Set models:\033[0m")
     for model in model_list:
+        ### LOAD FROM os.environ/ ###
+        for k, v in model["litellm_params"].items():
+            if isinstance(v, str) and v.startswith("os.environ/"):
+                model["litellm_params"][k] = litellm.get_secret(v)
         print(f"\033[32m  {model.get('model_name', '')}\033[0m")
         litellm_model_name = model["litellm_params"]["model"]
         litellm_model_api_base = model["litellm_params"].get("api_base", None)
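The os.environ/ prefix lets a proxy config reference environment variables instead of inlining secrets. A small illustrative sketch of the resolution step in isolation (resolve is hypothetical; the proxy itself calls litellm.get_secret(v)):

import os

def resolve(v: str) -> str:
    # "os.environ/AZURE_API_KEY" -> contents of $AZURE_API_KEY
    if v.startswith("os.environ/"):
        return os.environ[v.removeprefix("os.environ/")]  # assumes the variable is set
    return v

params = {"model": "azure/chatgpt-v-2", "api_key": "os.environ/AZURE_API_KEY"}
params = {k: resolve(v) if isinstance(v, str) else v for k, v in params.items()}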
107  litellm/tests/test_health_check.py (Normal file)
@@ -0,0 +1,107 @@
+#### What this tests ####
+# This tests if ahealth_check() actually works
+
+import sys, os
+import traceback
+import pytest
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import litellm, asyncio
+
+
+@pytest.mark.asyncio
+async def test_azure_health_check():
+    response = await litellm.ahealth_check(
+        model_params={
+            "model": "azure/chatgpt-v-2",
+            "messages": [{"role": "user", "content": "Hey, how's it going?"}],
+            "api_key": os.getenv("AZURE_API_KEY"),
+            "api_base": os.getenv("AZURE_API_BASE"),
+            "api_version": os.getenv("AZURE_API_VERSION"),
+        },
+        mode="completion",
+    )
+    print(f"response: {response}")
+
+    assert "x-ratelimit-remaining-tokens" in response
+    return response
+
+
+# asyncio.run(test_azure_health_check())
+
+
+@pytest.mark.asyncio
+async def test_azure_embedding_health_check():
+    response = await litellm.ahealth_check(
+        model_params={
+            "model": "azure/azure-embedding-model",
+            "messages": [{"role": "user", "content": "Hey, how's it going?"}],
+            "api_key": os.getenv("AZURE_API_KEY"),
+            "api_base": os.getenv("AZURE_API_BASE"),
+            "api_version": os.getenv("AZURE_API_VERSION"),
+        },
+        mode="embedding",
+    )
+    print(f"response: {response}")
+
+    assert "x-ratelimit-remaining-tokens" in response
+    return response
+
+
+@pytest.mark.asyncio
+async def test_openai_img_gen_health_check():
+    response = await litellm.ahealth_check(
+        model_params={
+            "model": "dall-e-3",
+            "api_key": os.getenv("OPENAI_API_KEY"),
+        },
+        mode="image_generation",
+        prompt="cute baby sea otter",
+    )
+    print(f"response: {response}")
+
+    assert isinstance(response, dict) and "error" not in response
+    return response
+
+
+# asyncio.run(test_openai_img_gen_health_check())
+
+
+async def test_azure_img_gen_health_check():
+    response = await litellm.ahealth_check(
+        model_params={
+            "model": "azure/",
+            "api_base": os.getenv("AZURE_API_BASE"),
+            "api_key": os.getenv("AZURE_API_KEY"),
+            "api_version": "2023-06-01-preview",
+        },
+        mode="image_generation",
+        prompt="cute baby sea otter",
+    )
+
+    assert isinstance(response, dict) and "error" not in response
+    return response
+
+
+# asyncio.run(test_azure_img_gen_health_check())
+
+
+@pytest.mark.asyncio
+async def test_sagemaker_embedding_health_check():
+    response = await litellm.ahealth_check(
+        model_params={
+            "model": "sagemaker/berri-benchmarking-gpt-j-6b-fp16",
+            "messages": [{"role": "user", "content": "Hey, how's it going?"}],
+        },
+        mode="embedding",
+        input=["test from litellm"],
+    )
+    print(f"response: {response}")
+
+    assert isinstance(response, dict)
+    return response
+
+
+# asyncio.run(test_sagemaker_embedding_health_check())