feat(health_check.py): more detailed health check calls

Krrish Dholakia 2023-12-28 09:12:57 +05:30
parent 0987fec75a
commit 2285282ef8
6 changed files with 354 additions and 48 deletions

@@ -8,7 +8,7 @@
# Thank you ! We ❤️ you! - Krrish & Ishaan
import os, openai, sys, json, inspect, uuid, datetime, threading
-from typing import Any
+from typing import Any, Literal, Union
from functools import partial
import dotenv, traceback, random, asyncio, time, contextvars
from copy import deepcopy
@@ -2885,6 +2885,108 @@ def image_generation(
return model_response
##### Health Endpoints #######################
async def ahealth_check(
model_params: dict,
mode: Optional[Literal["completion", "embedding", "image_generation"]] = None,
prompt: Optional[str] = None,
input: Optional[List] = None,
default_timeout: float = 6000,
):
"""
Support health checks for different providers. Return remaining rate limit, etc.
For azure/openai -> completion.with_raw_response
For rest -> litellm.acompletion()
"""
try:
model: Optional[str] = model_params.get("model", None)
if model is None:
raise Exception("model not set")
model, custom_llm_provider, _, _ = get_llm_provider(model=model)
mode = mode or "completion" # default to completion calls
if custom_llm_provider == "azure":
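            # resolve credentials from explicit model_params first, then env secrets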
api_key = (
model_params.get("api_key")
or get_secret("AZURE_API_KEY")
or get_secret("AZURE_OPENAI_API_KEY")
)
api_base = (
model_params.get("api_base")
or get_secret("AZURE_API_BASE")
or get_secret("AZURE_OPENAI_API_BASE")
)
api_version = (
model_params.get("api_version")
or get_secret("AZURE_API_VERSION")
or get_secret("AZURE_OPENAI_API_VERSION")
)
timeout = (
model_params.get("timeout")
or litellm.request_timeout
or default_timeout
)
response = await azure_chat_completions.ahealth_check(
model=model,
                messages=model_params.get(
                    "messages", None
                ),  # forward any messages supplied by the caller
api_key=api_key,
api_base=api_base,
api_version=api_version,
timeout=timeout,
mode=mode,
prompt=prompt,
input=input,
)
elif custom_llm_provider == "openai":
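            # OpenAI mirrors the Azure path, using the raw-response health check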
api_key = model_params.get("api_key") or get_secret("OPENAI_API_KEY")
timeout = (
model_params.get("timeout")
or litellm.request_timeout
or default_timeout
)
response = await openai_chat_completions.ahealth_check(
model=model,
                messages=model_params.get(
                    "messages", None
                ),  # forward any messages supplied by the caller
api_key=api_key,
timeout=timeout,
mode=mode,
prompt=prompt,
input=input,
)
else:
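            # all other providers: issue a real request via the public litellm async APIs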
if mode == "embedding":
model_params.pop("messages", None)
model_params["input"] = input
await litellm.aembedding(**model_params)
response = {}
elif mode == "image_generation":
model_params.pop("messages", None)
model_params["prompt"] = prompt
await litellm.aimage_generation(**model_params)
response = {}
else: # default to completion calls
await acompletion(**model_params)
                response = {}  # placeholder for details like remaining rate limits
return response
except Exception as e:
return {"error": str(e)}
####### HELPER FUNCTIONS ################
## Set verbose to true -> ```litellm.set_verbose = True```
def print_verbose(print_statement):
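
For reference, a minimal usage sketch of the new endpoint (not part of this commit). It assumes `ahealth_check` is exported at the package root like the other `main.py` entry points, and the Azure deployment name below is hypothetical:

```python
import asyncio

import litellm


async def main():
    # "azure/my-gpt-4-deployment" is a hypothetical deployment; api_key,
    # api_base, and api_version fall back to the AZURE_* env secrets when
    # they are not passed in model_params.
    result = await litellm.ahealth_check(
        model_params={"model": "azure/my-gpt-4-deployment"},
        mode="completion",
    )
    # On failure the call returns {"error": "<message>"} rather than raising.
    print(result)


asyncio.run(main())
```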