feat(health_check.py): more detailed health check calls

Krrish Dholakia 2023-12-28 09:12:57 +05:30
parent 7ce7516621
commit 3b1685e7c6
6 changed files with 354 additions and 48 deletions
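At a high level, the new entry point is litellm.ahealth_check(model_params, mode=..., prompt=..., input=...), which the proxy health check now calls for every configured model. A minimal sketch of calling it directly (model name, message, and key are illustrative placeholders, not part of this commit):

import asyncio, os
import litellm

async def main():
    # model_params mirrors a litellm_params entry from the proxy config
    result = await litellm.ahealth_check(
        model_params={
            "model": "gpt-3.5-turbo",
            "messages": [{"role": "user", "content": "ping"}],
            "api_key": os.getenv("OPENAI_API_KEY"),
        },
        mode="completion",
    )
    # healthy: {} or rate-limit headers; unhealthy: {"error": "<exception text>"}
    print(result)

asyncio.run(main())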

litellm/llms/azure.py (View file)

@@ -700,3 +700,72 @@ class AzureChatCompletion(BaseLLM):
            import traceback
            raise AzureOpenAIError(status_code=500, message=traceback.format_exc())

    async def ahealth_check(
        self,
        model: Optional[str],
        api_key: str,
        api_base: str,
        api_version: str,
        timeout: float,
        mode: str,
        messages: Optional[list] = None,
        input: Optional[list] = None,
        prompt: Optional[str] = None,
    ):
        client_session = litellm.aclient_session or httpx.AsyncClient(
            transport=AsyncCustomHTTPTransport(),  # handle dall-e-2 calls
        )
        client = AsyncAzureOpenAI(
            api_version=api_version,
            azure_endpoint=api_base,
            api_key=api_key,
            timeout=timeout,
            http_client=client_session,
        )

        if model is None and mode != "image_generation":
            raise Exception("model is not set")

        completion = None

        if mode == "completion":
            if messages is None:
                raise Exception("messages is not set")
            completion = await client.chat.completions.with_raw_response.create(
                model=model,  # type: ignore
                messages=messages,  # type: ignore
            )
        elif mode == "embedding":
            if input is None:
                raise Exception("input is not set")
            completion = await client.embeddings.with_raw_response.create(
                model=model,  # type: ignore
                input=input,  # type: ignore
            )
        elif mode == "image_generation":
            if prompt is None:
                raise Exception("prompt is not set")
            completion = await client.images.with_raw_response.generate(
                model=model,  # type: ignore
                prompt=prompt,  # type: ignore
            )
        else:
            raise Exception("mode not set")

        response = {}

        if completion is None or not hasattr(completion, "headers"):
            raise Exception("invalid completion response")

        if (
            completion.headers.get("x-ratelimit-remaining-requests", None) is not None
        ):  # not provided for dall-e requests
            response["x-ratelimit-remaining-requests"] = completion.headers[
                "x-ratelimit-remaining-requests"
            ]

        if completion.headers.get("x-ratelimit-remaining-tokens", None) is not None:
            response["x-ratelimit-remaining-tokens"] = completion.headers[
                "x-ratelimit-remaining-tokens"
            ]

        return response
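On a healthy completion or embedding check, the dict returned above only carries whichever rate-limit headers the Azure endpoint sent back, so it looks roughly like the sketch below (values illustrative); dall-e endpoints omit these headers, so an image_generation check comes back as an empty dict:

{
    "x-ratelimit-remaining-requests": "99",
    "x-ratelimit-remaining-tokens": "239998",
}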

litellm/llms/openai.py (View file)

@@ -717,6 +717,63 @@ class OpenAIChatCompletion(BaseLLM):
            raise OpenAIError(status_code=500, message=traceback.format_exc())

    async def ahealth_check(
        self,
        model: Optional[str],
        api_key: str,
        timeout: float,
        mode: str,
        messages: Optional[list] = None,
        input: Optional[list] = None,
        prompt: Optional[str] = None,
    ):
        client = AsyncOpenAI(api_key=api_key, timeout=timeout)

        if model is None and mode != "image_generation":
            raise Exception("model is not set")

        completion = None

        if mode == "completion":
            if messages is None:
                raise Exception("messages is not set")
            completion = await client.chat.completions.with_raw_response.create(
                model=model,  # type: ignore
                messages=messages,  # type: ignore
            )
        elif mode == "embedding":
            if input is None:
                raise Exception("input is not set")
            completion = await client.embeddings.with_raw_response.create(
                model=model,  # type: ignore
                input=input,  # type: ignore
            )
        elif mode == "image_generation":
            if prompt is None:
                raise Exception("prompt is not set")
            completion = await client.images.with_raw_response.generate(
                model=model,  # type: ignore
                prompt=prompt,  # type: ignore
            )
        else:
            raise Exception("mode not set")

        response = {}

        if completion is None or not hasattr(completion, "headers"):
            raise Exception("invalid completion response")

        if (
            completion.headers.get("x-ratelimit-remaining-requests", None) is not None
        ):  # not provided for dall-e requests
            response["x-ratelimit-remaining-requests"] = completion.headers[
                "x-ratelimit-remaining-requests"
            ]

        if completion.headers.get("x-ratelimit-remaining-tokens", None) is not None:
            response["x-ratelimit-remaining-tokens"] = completion.headers[
                "x-ratelimit-remaining-tokens"
            ]

        return response


class OpenAITextCompletion(BaseLLM):
    _client_session: httpx.Client
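Both implementations rely on the with_raw_response wrappers from the openai v1 SDK, which return a raw-response object exposing the HTTP headers directly; .parse() recovers the usual typed body when it is needed. A minimal self-contained sketch (model name and timeout are illustrative, not from this commit):

import asyncio, os
from openai import AsyncOpenAI

async def check() -> None:
    client = AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY"), timeout=30)
    raw = await client.chat.completions.with_raw_response.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "ping"}],
    )
    print(raw.headers.get("x-ratelimit-remaining-requests"))  # raw HTTP header access
    print(raw.parse().choices[0].message.content)  # .parse() returns the usual ChatCompletion

asyncio.run(check())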

litellm/main.py (View file)

@@ -8,7 +8,7 @@
# Thank you ! We ❤️ you! - Krrish & Ishaan

import os, openai, sys, json, inspect, uuid, datetime, threading
-from typing import Any
+from typing import Any, Literal, Union
from functools import partial
import dotenv, traceback, random, asyncio, time, contextvars
from copy import deepcopy
@@ -2885,6 +2885,108 @@ def image_generation(
    return model_response


##### Health Endpoints #######################


async def ahealth_check(
    model_params: dict,
    mode: Optional[Literal["completion", "embedding", "image_generation"]] = None,
    prompt: Optional[str] = None,
    input: Optional[List] = None,
    default_timeout: float = 6000,
):
    """
    Support health checks for different providers. Return remaining rate limit, etc.

    For azure/openai -> completion.with_raw_response
    For rest -> litellm.acompletion()
    """
    try:
        model: Optional[str] = model_params.get("model", None)

        if model is None:
            raise Exception("model not set")

        model, custom_llm_provider, _, _ = get_llm_provider(model=model)
        mode = mode or "completion"  # default to completion calls

        if custom_llm_provider == "azure":
            api_key = (
                model_params.get("api_key")
                or get_secret("AZURE_API_KEY")
                or get_secret("AZURE_OPENAI_API_KEY")
            )

            api_base = (
                model_params.get("api_base")
                or get_secret("AZURE_API_BASE")
                or get_secret("AZURE_OPENAI_API_BASE")
            )

            api_version = (
                model_params.get("api_version")
                or get_secret("AZURE_API_VERSION")
                or get_secret("AZURE_OPENAI_API_VERSION")
            )

            timeout = (
                model_params.get("timeout")
                or litellm.request_timeout
                or default_timeout
            )

            response = await azure_chat_completions.ahealth_check(
                model=model,
                messages=model_params.get(
                    "messages", None
                ),  # Replace with your actual messages list
                api_key=api_key,
                api_base=api_base,
                api_version=api_version,
                timeout=timeout,
                mode=mode,
                prompt=prompt,
                input=input,
            )
        elif custom_llm_provider == "openai":
            api_key = model_params.get("api_key") or get_secret("OPENAI_API_KEY")

            timeout = (
                model_params.get("timeout")
                or litellm.request_timeout
                or default_timeout
            )

            response = await openai_chat_completions.ahealth_check(
                model=model,
                messages=model_params.get(
                    "messages", None
                ),  # Replace with your actual messages list
                api_key=api_key,
                timeout=timeout,
                mode=mode,
                prompt=prompt,
                input=input,
            )
        else:
            if mode == "embedding":
                model_params.pop("messages", None)
                model_params["input"] = input
                await litellm.aembedding(**model_params)
                response = {}
            elif mode == "image_generation":
                model_params.pop("messages", None)
                model_params["prompt"] = prompt
                await litellm.aimage_generation(**model_params)
                response = {}
            else:  # default to completion calls
                await acompletion(**model_params)
                response = {}  # args like remaining ratelimit etc.
        return response
    except Exception as e:
        return {"error": str(e)}


####### HELPER FUNCTIONS ################
## Set verbose to true -> ```litellm.set_verbose = True```
def print_verbose(print_statement):
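Note that ahealth_check() deliberately never raises: any provider failure is folded into the return value, so callers tell healthy from unhealthy by inspecting the dict rather than catching exceptions. A small sketch of that convention, run inside an async context (the model_params dict here is an illustrative placeholder):

model_params = {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "ping"}]}
result = await litellm.ahealth_check(model_params=model_params, mode="completion")
is_healthy = isinstance(result, dict) and "error" not in result  # mirrors the proxy-side check below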

litellm/proxy/health_check.py (View file)

@@ -12,7 +12,7 @@ from litellm._logging import print_verbose
logger = logging.getLogger(__name__)

-ILLEGAL_DISPLAY_PARAMS = ["messages", "api_key"]
+ILLEGAL_DISPLAY_PARAMS = ["messages", "api_key", "prompt", "input"]


def _get_random_llm_message():
@@ -35,55 +35,20 @@ async def _perform_health_check(model_list: list):
    """
    Perform a health check for each model in the list.
    """

-    async def _check_img_gen_model(model_params: dict):
-        model_params.pop("messages", None)
-        model_params["prompt"] = "test from litellm"
-        try:
-            await litellm.aimage_generation(**model_params)
-        except Exception as e:
-            print_verbose(
-                f"Health check failed for model {model_params['model']}. Error: {e}"
-            )
-            return False
-        return True

-    async def _check_embedding_model(model_params: dict):
-        model_params.pop("messages", None)
-        model_params["input"] = ["test from litellm"]
-        try:
-            await litellm.aembedding(**model_params)
-        except Exception as e:
-            print_verbose(
-                f"Health check failed for model {model_params['model']}. Error: {e}"
-            )
-            return False
-        return True

-    async def _check_model(model_params: dict):
-        try:
-            await litellm.acompletion(**model_params)
-        except Exception as e:
-            print_verbose(
-                f"Health check failed for model {model_params['model']}. Error: {e}"
-            )
-            return False
-        return True

    tasks = []
    for model in model_list:
        litellm_params = model["litellm_params"]
        model_info = model.get("model_info", {})
        litellm_params["messages"] = _get_random_llm_message()
-        if model_info.get("mode", None) == "embedding":
-            # this is an embedding model
-            tasks.append(_check_embedding_model(litellm_params))
-        elif model_info.get("mode", None) == "image_generation":
-            tasks.append(_check_img_gen_model(litellm_params))
-        else:
-            tasks.append(_check_model(litellm_params))
+        mode = model_info.get("mode", None)
+        tasks.append(
+            litellm.ahealth_check(
+                litellm_params,
+                mode=mode,
+                prompt="test from litellm",
+                input=["test from litellm"],
+            )
+        )

    results = await asyncio.gather(*tasks)
@@ -93,8 +58,10 @@ async def _perform_health_check(model_list: list):
    for is_healthy, model in zip(results, model_list):
        cleaned_litellm_params = _clean_litellm_params(model["litellm_params"])

-        if is_healthy:
-            healthy_endpoints.append(cleaned_litellm_params)
+        if isinstance(is_healthy, dict) and "error" not in is_healthy:
+            healthy_endpoints.append({**cleaned_litellm_params, **is_healthy})
+        elif isinstance(is_healthy, dict):
+            unhealthy_endpoints.append({**cleaned_litellm_params, **is_healthy})
        else:
            unhealthy_endpoints.append(cleaned_litellm_params)
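With this change a healthy endpoint entry is the cleaned litellm_params merged with whatever ahealth_check() returned, so a healthy Azure chat deployment might be reported as something like the following (values illustrative; messages, api_key, prompt and input are stripped by ILLEGAL_DISPLAY_PARAMS):

{
    "model": "azure/chatgpt-v-2",
    "api_base": "https://my-endpoint.openai.azure.com/",
    "x-ratelimit-remaining-requests": "99",
    "x-ratelimit-remaining-tokens": "239998",
}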

litellm/proxy/proxy_server.py (View file)

@@ -672,6 +672,10 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
        router_params["model_list"] = model_list
        print(f"\033[32mLiteLLM: Proxy initialized with Config, Set models:\033[0m")
        for model in model_list:
            ### LOAD FROM os.environ/ ###
            for k, v in model["litellm_params"].items():
                if isinstance(v, str) and v.startswith("os.environ/"):
                    model["litellm_params"][k] = litellm.get_secret(v)

            print(f"\033[32m {model.get('model_name', '')}\033[0m")
            litellm_model_name = model["litellm_params"]["model"]
            litellm_model_api_base = model["litellm_params"].get("api_base", None)
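This loop lets proxy config values reference environment variables instead of inlining secrets. After the config is parsed, a model_list entry shaped like the following (names illustrative) has its os.environ/ strings swapped for litellm.get_secret() lookups before the router is built:

# parsed form of one model_list entry from the proxy config (illustrative values)
model = {
    "model_name": "azure-gpt-35",
    "litellm_params": {
        "model": "azure/chatgpt-v-2",
        "api_base": "os.environ/AZURE_API_BASE",  # replaced by litellm.get_secret("os.environ/AZURE_API_BASE")
        "api_key": "os.environ/AZURE_API_KEY",
    },
}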

litellm/tests/test_health_check.py (View file)

@@ -0,0 +1,107 @@
#### What this tests ####
# This tests if ahealth_check() actually works

import sys, os
import traceback
import pytest

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import litellm, asyncio


@pytest.mark.asyncio
async def test_azure_health_check():
    response = await litellm.ahealth_check(
        model_params={
            "model": "azure/chatgpt-v-2",
            "messages": [{"role": "user", "content": "Hey, how's it going?"}],
            "api_key": os.getenv("AZURE_API_KEY"),
            "api_base": os.getenv("AZURE_API_BASE"),
            "api_version": os.getenv("AZURE_API_VERSION"),
        },
        mode="completion",
    )
    print(f"response: {response}")

    assert "x-ratelimit-remaining-tokens" in response
    return response


# asyncio.run(test_azure_health_check())


@pytest.mark.asyncio
async def test_azure_embedding_health_check():
    response = await litellm.ahealth_check(
        model_params={
            "model": "azure/azure-embedding-model",
            "messages": [{"role": "user", "content": "Hey, how's it going?"}],
            "api_key": os.getenv("AZURE_API_KEY"),
            "api_base": os.getenv("AZURE_API_BASE"),
            "api_version": os.getenv("AZURE_API_VERSION"),
        },
        mode="embedding",
    )
    print(f"response: {response}")

    assert "x-ratelimit-remaining-tokens" in response
    return response


@pytest.mark.asyncio
async def test_openai_img_gen_health_check():
    response = await litellm.ahealth_check(
        model_params={
            "model": "dall-e-3",
            "api_key": os.getenv("OPENAI_API_KEY"),
        },
        mode="image_generation",
        prompt="cute baby sea otter",
    )
    print(f"response: {response}")

    assert isinstance(response, dict) and "error" not in response
    return response


# asyncio.run(test_openai_img_gen_health_check())
@pytest.mark.asyncio
async def test_azure_img_gen_health_check():
    response = await litellm.ahealth_check(
        model_params={
            "model": "azure/",
            "api_base": os.getenv("AZURE_API_BASE"),
            "api_key": os.getenv("AZURE_API_KEY"),
            "api_version": "2023-06-01-preview",
        },
        mode="image_generation",
        prompt="cute baby sea otter",
    )

    assert isinstance(response, dict) and "error" not in response
    return response


# asyncio.run(test_azure_img_gen_health_check())
@pytest.mark.asyncio
async def test_sagemaker_embedding_health_check():
    response = await litellm.ahealth_check(
        model_params={
            "model": "sagemaker/berri-benchmarking-gpt-j-6b-fp16",
            "messages": [{"role": "user", "content": "Hey, how's it going?"}],
        },
        mode="embedding",
        input=["test from litellm"],
    )
    print(f"response: {response}")

    assert isinstance(response, dict)
    return response


# asyncio.run(test_sagemaker_embedding_health_check())