forked from phoenix/litellm-mirror

feat(health_check.py): more detailed health check calls

parent 7ce7516621
commit 3b1685e7c6

6 changed files with 354 additions and 48 deletions
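A minimal sketch of how the health-check entrypoint added in this commit can be called; the model name, message, and OPENAI_API_KEY environment variable are illustrative assumptions, not part of the diff:

    import asyncio, os
    import litellm

    async def main():
        response = await litellm.ahealth_check(
            model_params={
                "model": "gpt-3.5-turbo",  # illustrative model, not from this diff
                "messages": [{"role": "user", "content": "ping"}],
                "api_key": os.getenv("OPENAI_API_KEY"),
            },
            mode="completion",
        )
        # on success: remaining rate-limit headers, e.g.
        # {"x-ratelimit-remaining-requests": "...", "x-ratelimit-remaining-tokens": "..."}
        # on failure: {"error": "<exception text>"}
        print(response)

    asyncio.run(main())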
@@ -700,3 +700,72 @@ class AzureChatCompletion(BaseLLM):
             import traceback

             raise AzureOpenAIError(status_code=500, message=traceback.format_exc())
+
+    async def ahealth_check(
+        self,
+        model: Optional[str],
+        api_key: str,
+        api_base: str,
+        api_version: str,
+        timeout: float,
+        mode: str,
+        messages: Optional[list] = None,
+        input: Optional[list] = None,
+        prompt: Optional[str] = None,
+    ):
+        client_session = litellm.aclient_session or httpx.AsyncClient(
+            transport=AsyncCustomHTTPTransport(),  # handle dall-e-2 calls
+        )
+        client = AsyncAzureOpenAI(
+            api_version=api_version,
+            azure_endpoint=api_base,
+            api_key=api_key,
+            timeout=timeout,
+            http_client=client_session,
+        )
+
+        if model is None and mode != "image_generation":
+            raise Exception("model is not set")
+
+        completion = None
+
+        if mode == "completion":
+            if messages is None:
+                raise Exception("messages is not set")
+            completion = await client.chat.completions.with_raw_response.create(
+                model=model,  # type: ignore
+                messages=messages,  # type: ignore
+            )
+        elif mode == "embedding":
+            if input is None:
+                raise Exception("input is not set")
+            completion = await client.embeddings.with_raw_response.create(
+                model=model,  # type: ignore
+                input=input,  # type: ignore
+            )
+        elif mode == "image_generation":
+            if prompt is None:
+                raise Exception("prompt is not set")
+            completion = await client.images.with_raw_response.generate(
+                model=model,  # type: ignore
+                prompt=prompt,  # type: ignore
+            )
+        else:
+            raise Exception("mode not set")
+        response = {}
+
+        if completion is None or not hasattr(completion, "headers"):
+            raise Exception("invalid completion response")
+
+        if (
+            completion.headers.get("x-ratelimit-remaining-requests", None) is not None
+        ):  # not provided for dall-e requests
+            response["x-ratelimit-remaining-requests"] = completion.headers[
+                "x-ratelimit-remaining-requests"
+            ]
+
+        if completion.headers.get("x-ratelimit-remaining-tokens", None) is not None:
+            response["x-ratelimit-remaining-tokens"] = completion.headers[
+                "x-ratelimit-remaining-tokens"
+            ]
+        return response
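The Azure handler issues the probe through the with_raw_response API so the rate-limit headers stay readable. A hedged sketch of calling it directly, assuming you already hold an AzureChatCompletion instance and the usual AZURE_* environment variables (the deployment name is illustrative; the proxy normally reaches this method via litellm.ahealth_check instead):

    import asyncio, os

    async def azure_check(handler):  # handler: an AzureChatCompletion instance (assumption)
        headers = await handler.ahealth_check(
            model="chatgpt-v-2",  # illustrative Azure deployment name
            api_key=os.getenv("AZURE_API_KEY"),
            api_base=os.getenv("AZURE_API_BASE"),
            api_version=os.getenv("AZURE_API_VERSION"),
            timeout=10.0,
            mode="completion",
            messages=[{"role": "user", "content": "ping"}],
        )
        # on success: the rate-limit headers read off the raw response, e.g.
        # {"x-ratelimit-remaining-requests": "...", "x-ratelimit-remaining-tokens": "..."}
        return headers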
@@ -717,6 +717,63 @@ class OpenAIChatCompletion(BaseLLM):

             raise OpenAIError(status_code=500, message=traceback.format_exc())

+    async def ahealth_check(
+        self,
+        model: Optional[str],
+        api_key: str,
+        timeout: float,
+        mode: str,
+        messages: Optional[list] = None,
+        input: Optional[list] = None,
+        prompt: Optional[str] = None,
+    ):
+        client = AsyncOpenAI(api_key=api_key, timeout=timeout)
+        if model is None and mode != "image_generation":
+            raise Exception("model is not set")
+
+        completion = None
+
+        if mode == "completion":
+            if messages is None:
+                raise Exception("messages is not set")
+            completion = await client.chat.completions.with_raw_response.create(
+                model=model,  # type: ignore
+                messages=messages,  # type: ignore
+            )
+        elif mode == "embedding":
+            if input is None:
+                raise Exception("input is not set")
+            completion = await client.embeddings.with_raw_response.create(
+                model=model,  # type: ignore
+                input=input,  # type: ignore
+            )
+        elif mode == "image_generation":
+            if prompt is None:
+                raise Exception("prompt is not set")
+            completion = await client.images.with_raw_response.generate(
+                model=model,  # type: ignore
+                prompt=prompt,  # type: ignore
+            )
+        else:
+            raise Exception("mode not set")
+        response = {}
+
+        if completion is None or not hasattr(completion, "headers"):
+            raise Exception("invalid completion response")
+
+        if (
+            completion.headers.get("x-ratelimit-remaining-requests", None) is not None
+        ):  # not provided for dall-e requests
+            response["x-ratelimit-remaining-requests"] = completion.headers[
+                "x-ratelimit-remaining-requests"
+            ]
+
+        if completion.headers.get("x-ratelimit-remaining-tokens", None) is not None:
+            response["x-ratelimit-remaining-tokens"] = completion.headers[
+                "x-ratelimit-remaining-tokens"
+            ]
+        return response
+
+
 class OpenAITextCompletion(BaseLLM):
     _client_session: httpx.Client
104  litellm/main.py
@@ -8,7 +8,7 @@
 # Thank you ! We ❤️ you! - Krrish & Ishaan

 import os, openai, sys, json, inspect, uuid, datetime, threading
-from typing import Any
+from typing import Any, Literal, Union
 from functools import partial
 import dotenv, traceback, random, asyncio, time, contextvars
 from copy import deepcopy
@@ -2885,6 +2885,108 @@ def image_generation(
     return model_response


+##### Health Endpoints #######################
+
+
+async def ahealth_check(
+    model_params: dict,
+    mode: Optional[Literal["completion", "embedding", "image_generation"]] = None,
+    prompt: Optional[str] = None,
+    input: Optional[List] = None,
+    default_timeout: float = 6000,
+):
+    """
+    Support health checks for different providers. Return remaining rate limit, etc.
+
+    For azure/openai -> completion.with_raw_response
+    For rest -> litellm.acompletion()
+    """
+    try:
+        model: Optional[str] = model_params.get("model", None)
+
+        if model is None:
+            raise Exception("model not set")
+
+        model, custom_llm_provider, _, _ = get_llm_provider(model=model)
+        mode = mode or "completion"  # default to completion calls
+
+        if custom_llm_provider == "azure":
+            api_key = (
+                model_params.get("api_key")
+                or get_secret("AZURE_API_KEY")
+                or get_secret("AZURE_OPENAI_API_KEY")
+            )
+
+            api_base = (
+                model_params.get("api_base")
+                or get_secret("AZURE_API_BASE")
+                or get_secret("AZURE_OPENAI_API_BASE")
+            )
+
+            api_version = (
+                model_params.get("api_version")
+                or get_secret("AZURE_API_VERSION")
+                or get_secret("AZURE_OPENAI_API_VERSION")
+            )
+
+            timeout = (
+                model_params.get("timeout")
+                or litellm.request_timeout
+                or default_timeout
+            )
+
+            response = await azure_chat_completions.ahealth_check(
+                model=model,
+                messages=model_params.get(
+                    "messages", None
+                ),  # Replace with your actual messages list
+                api_key=api_key,
+                api_base=api_base,
+                api_version=api_version,
+                timeout=timeout,
+                mode=mode,
+                prompt=prompt,
+                input=input,
+            )
+        elif custom_llm_provider == "openai":
+            api_key = model_params.get("api_key") or get_secret("OPENAI_API_KEY")
+
+            timeout = (
+                model_params.get("timeout")
+                or litellm.request_timeout
+                or default_timeout
+            )
+
+            response = await openai_chat_completions.ahealth_check(
+                model=model,
+                messages=model_params.get(
+                    "messages", None
+                ),  # Replace with your actual messages list
+                api_key=api_key,
+                timeout=timeout,
+                mode=mode,
+                prompt=prompt,
+                input=input,
+            )
+        else:
+            if mode == "embedding":
+                model_params.pop("messages", None)
+                model_params["input"] = input
+                await litellm.aembedding(**model_params)
+                response = {}
+            elif mode == "image_generation":
+                model_params.pop("messages", None)
+                model_params["prompt"] = prompt
+                await litellm.aimage_generation(**model_params)
+                response = {}
+            else:  # default to completion calls
+                await acompletion(**model_params)
+                response = {}  # args like remaining ratelimit etc.
+        return response
+    except Exception as e:
+        return {"error": str(e)}
+
+
 ####### HELPER FUNCTIONS ################
 ## Set verbose to true -> ```litellm.set_verbose = True```
 def print_verbose(print_statement):
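For providers other than Azure and OpenAI the new function simply exercises the regular litellm call path (aembedding, aimage_generation, or acompletion) and reports success as an empty dict, with any exception folded into an error dict. A small sketch of that fallback path, using an illustrative non-OpenAI model name that is not part of this diff:

    import asyncio
    import litellm

    async def check_other_provider():
        response = await litellm.ahealth_check(
            model_params={
                "model": "command-nightly",  # illustrative non-azure/openai model
                "messages": [{"role": "user", "content": "ping"}],
            },
            mode="completion",
        )
        # {} when the underlying acompletion() call succeeds,
        # {"error": "<exception text>"} when it raises
        print(response)

    asyncio.run(check_other_provider())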
@@ -12,7 +12,7 @@ from litellm._logging import print_verbose

 logger = logging.getLogger(__name__)


-ILLEGAL_DISPLAY_PARAMS = ["messages", "api_key"]
+ILLEGAL_DISPLAY_PARAMS = ["messages", "api_key", "prompt", "input"]


 def _get_random_llm_message():
@@ -35,55 +35,20 @@ async def _perform_health_check(model_list: list):
     """
     Perform a health check for each model in the list.
     """

-    async def _check_img_gen_model(model_params: dict):
-        model_params.pop("messages", None)
-        model_params["prompt"] = "test from litellm"
-        try:
-            await litellm.aimage_generation(**model_params)
-        except Exception as e:
-            print_verbose(
-                f"Health check failed for model {model_params['model']}. Error: {e}"
-            )
-            return False
-        return True
-
-    async def _check_embedding_model(model_params: dict):
-        model_params.pop("messages", None)
-        model_params["input"] = ["test from litellm"]
-        try:
-            await litellm.aembedding(**model_params)
-        except Exception as e:
-            print_verbose(
-                f"Health check failed for model {model_params['model']}. Error: {e}"
-            )
-            return False
-        return True
-
-    async def _check_model(model_params: dict):
-        try:
-            await litellm.acompletion(**model_params)
-        except Exception as e:
-            print_verbose(
-                f"Health check failed for model {model_params['model']}. Error: {e}"
-            )
-            return False
-
-        return True
-
     tasks = []
     for model in model_list:
         litellm_params = model["litellm_params"]
         model_info = model.get("model_info", {})
         litellm_params["messages"] = _get_random_llm_message()
-        if model_info.get("mode", None) == "embedding":
-            # this is an embedding model
-            tasks.append(_check_embedding_model(litellm_params))
-        elif model_info.get("mode", None) == "image_generation":
-            tasks.append(_check_img_gen_model(litellm_params))
-        else:
-            tasks.append(_check_model(litellm_params))
+        mode = model_info.get("mode", None)
+        tasks.append(
+            litellm.ahealth_check(
+                litellm_params,
+                mode=mode,
+                prompt="test from litellm",
+                input=["test from litellm"],
+            )
+        )

     results = await asyncio.gather(*tasks)
@@ -93,8 +58,10 @@ async def _perform_health_check(model_list: list):
     for is_healthy, model in zip(results, model_list):
         cleaned_litellm_params = _clean_litellm_params(model["litellm_params"])

-        if is_healthy:
-            healthy_endpoints.append(cleaned_litellm_params)
+        if isinstance(is_healthy, dict) and "error" not in is_healthy:
+            healthy_endpoints.append({**cleaned_litellm_params, **is_healthy})
+        elif isinstance(is_healthy, dict):
+            unhealthy_endpoints.append({**cleaned_litellm_params, **is_healthy})
         else:
             unhealthy_endpoints.append(cleaned_litellm_params)
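Because each ahealth_check() call now resolves to a dict instead of a bare bool, the gathered results can carry rate-limit headers or the error text through to the healthy/unhealthy lists. A short illustration of that branch with made-up values, not taken from this diff:

    results = [
        {"x-ratelimit-remaining-tokens": "239990"},  # healthy: dict without "error"
        {"error": "AuthenticationError: bad key"},   # unhealthy: dict with "error"
    ]
    for is_healthy in results:
        if isinstance(is_healthy, dict) and "error" not in is_healthy:
            print("healthy:", is_healthy)
        elif isinstance(is_healthy, dict):
            print("unhealthy:", is_healthy)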
@@ -672,6 +672,10 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
         router_params["model_list"] = model_list
         print(f"\033[32mLiteLLM: Proxy initialized with Config, Set models:\033[0m")
         for model in model_list:
+            ### LOAD FROM os.environ/ ###
+            for k, v in model["litellm_params"].items():
+                if isinstance(v, str) and v.startswith("os.environ/"):
+                    model["litellm_params"][k] = litellm.get_secret(v)
             print(f"\033[32m {model.get('model_name', '')}\033[0m")
             litellm_model_name = model["litellm_params"]["model"]
             litellm_model_api_base = model["litellm_params"].get("api_base", None)
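This lets a proxy config reference secrets indirectly instead of inlining them. A minimal sketch of the convention, assuming AZURE_API_KEY is set in the environment (the model entry itself is illustrative, not from this diff):

    import litellm

    model = {
        "model_name": "azure-gpt",  # illustrative config entry
        "litellm_params": {
            "model": "azure/chatgpt-v-2",
            "api_key": "os.environ/AZURE_API_KEY",  # indirect reference to an env var
        },
    }
    for k, v in model["litellm_params"].items():
        if isinstance(v, str) and v.startswith("os.environ/"):
            # resolved to the value of AZURE_API_KEY at config-load time
            model["litellm_params"][k] = litellm.get_secret(v)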
107  litellm/tests/test_health_check.py  (new file)
@@ -0,0 +1,107 @@
+#### What this tests ####
+# This tests if ahealth_check() actually works
+
+import sys, os
+import traceback
+import pytest
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import litellm, asyncio
+
+
+@pytest.mark.asyncio
+async def test_azure_health_check():
+    response = await litellm.ahealth_check(
+        model_params={
+            "model": "azure/chatgpt-v-2",
+            "messages": [{"role": "user", "content": "Hey, how's it going?"}],
+            "api_key": os.getenv("AZURE_API_KEY"),
+            "api_base": os.getenv("AZURE_API_BASE"),
+            "api_version": os.getenv("AZURE_API_VERSION"),
+        },
+        mode="completion",
+    )
+    print(f"response: {response}")
+
+    assert "x-ratelimit-remaining-tokens" in response
+    return response
+
+
+# asyncio.run(test_azure_health_check())
+
+
+@pytest.mark.asyncio
+async def test_azure_embedding_health_check():
+    response = await litellm.ahealth_check(
+        model_params={
+            "model": "azure/azure-embedding-model",
+            "messages": [{"role": "user", "content": "Hey, how's it going?"}],
+            "api_key": os.getenv("AZURE_API_KEY"),
+            "api_base": os.getenv("AZURE_API_BASE"),
+            "api_version": os.getenv("AZURE_API_VERSION"),
+        },
+        mode="embedding",
+    )
+    print(f"response: {response}")
+
+    assert "x-ratelimit-remaining-tokens" in response
+    return response
+
+
+@pytest.mark.asyncio
+async def test_openai_img_gen_health_check():
+    response = await litellm.ahealth_check(
+        model_params={
+            "model": "dall-e-3",
+            "api_key": os.getenv("OPENAI_API_KEY"),
+        },
+        mode="image_generation",
+        prompt="cute baby sea otter",
+    )
+    print(f"response: {response}")
+
+    assert isinstance(response, dict) and "error" not in response
+    return response
+
+
+# asyncio.run(test_openai_img_gen_health_check())
+
+
+async def test_azure_img_gen_health_check():
+    response = await litellm.ahealth_check(
+        model_params={
+            "model": "azure/",
+            "api_base": os.getenv("AZURE_API_BASE"),
+            "api_key": os.getenv("AZURE_API_KEY"),
+            "api_version": "2023-06-01-preview",
+        },
+        mode="image_generation",
+        prompt="cute baby sea otter",
+    )
+
+    assert isinstance(response, dict) and "error" not in response
+    return response
+
+
+# asyncio.run(test_azure_img_gen_health_check())
+
+
+@pytest.mark.asyncio
+async def test_sagemaker_embedding_health_check():
+    response = await litellm.ahealth_check(
+        model_params={
+            "model": "sagemaker/berri-benchmarking-gpt-j-6b-fp16",
+            "messages": [{"role": "user", "content": "Hey, how's it going?"}],
+        },
+        mode="embedding",
+        input=["test from litellm"],
+    )
+    print(f"response: {response}")
+
+    assert isinstance(response, dict)
+    return response
+
+
+# asyncio.run(test_sagemaker_embedding_health_check())