diff --git a/litellm/llms/azure.py b/litellm/llms/azure.py
index 8877c043f..9379f5042 100644
--- a/litellm/llms/azure.py
+++ b/litellm/llms/azure.py
@@ -700,3 +700,72 @@ class AzureChatCompletion(BaseLLM):
             import traceback
 
             raise AzureOpenAIError(status_code=500, message=traceback.format_exc())
+
+    async def ahealth_check(
+        self,
+        model: Optional[str],
+        api_key: str,
+        api_base: str,
+        api_version: str,
+        timeout: float,
+        mode: str,
+        messages: Optional[list] = None,
+        input: Optional[list] = None,
+        prompt: Optional[str] = None,
+    ):
+        client_session = litellm.aclient_session or httpx.AsyncClient(
+            transport=AsyncCustomHTTPTransport(),  # handle dall-e-2 calls
+        )
+        client = AsyncAzureOpenAI(
+            api_version=api_version,
+            azure_endpoint=api_base,
+            api_key=api_key,
+            timeout=timeout,
+            http_client=client_session,
+        )
+
+        if model is None and mode != "image_generation":
+            raise Exception("model is not set")
+
+        completion = None
+
+        if mode == "completion":
+            if messages is None:
+                raise Exception("messages is not set")
+            completion = await client.chat.completions.with_raw_response.create(
+                model=model,  # type: ignore
+                messages=messages,  # type: ignore
+            )
+        elif mode == "embedding":
+            if input is None:
+                raise Exception("input is not set")
+            completion = await client.embeddings.with_raw_response.create(
+                model=model,  # type: ignore
+                input=input,  # type: ignore
+            )
+        elif mode == "image_generation":
+            if prompt is None:
+                raise Exception("prompt is not set")
+            completion = await client.images.with_raw_response.generate(
+                model=model,  # type: ignore
+                prompt=prompt,  # type: ignore
+            )
+        else:
+            raise Exception("mode not set")
+        response = {}
+
+        if completion is None or not hasattr(completion, "headers"):
+            raise Exception("invalid completion response")
+
+        if (
+            completion.headers.get("x-ratelimit-remaining-requests", None) is not None
+        ):  # not provided for dall-e requests
+            response["x-ratelimit-remaining-requests"] = completion.headers[
+                "x-ratelimit-remaining-requests"
+            ]
+
+        if completion.headers.get("x-ratelimit-remaining-tokens", None) is not None:
+            response["x-ratelimit-remaining-tokens"] = completion.headers[
+                "x-ratelimit-remaining-tokens"
+            ]
+        return response
diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py
index c887eb405..b1947bad2 100644
--- a/litellm/llms/openai.py
+++ b/litellm/llms/openai.py
@@ -717,6 +717,63 @@ class OpenAIChatCompletion(BaseLLM):
 
             raise OpenAIError(status_code=500, message=traceback.format_exc())
 
+    async def ahealth_check(
+        self,
+        model: Optional[str],
+        api_key: str,
+        timeout: float,
+        mode: str,
+        messages: Optional[list] = None,
+        input: Optional[list] = None,
+        prompt: Optional[str] = None,
+    ):
+        client = AsyncOpenAI(api_key=api_key, timeout=timeout)
+        if model is None and mode != "image_generation":
+            raise Exception("model is not set")
+
+        completion = None
+
+        if mode == "completion":
+            if messages is None:
+                raise Exception("messages is not set")
+            completion = await client.chat.completions.with_raw_response.create(
+                model=model,  # type: ignore
+                messages=messages,  # type: ignore
+            )
+        elif mode == "embedding":
+            if input is None:
+                raise Exception("input is not set")
+            completion = await client.embeddings.with_raw_response.create(
+                model=model,  # type: ignore
+                input=input,  # type: ignore
+            )
+        elif mode == "image_generation":
+            if prompt is None:
+                raise Exception("prompt is not set")
+            completion = await client.images.with_raw_response.generate(
+                model=model,  # type: ignore
+                prompt=prompt,  # type: ignore
+            )
+        else:
+            raise Exception("mode not set")
+        response = {}
+
+        if completion is None or not hasattr(completion, "headers"):
+            raise Exception("invalid completion response")
+
+        if (
+            completion.headers.get("x-ratelimit-remaining-requests", None) is not None
+        ):  # not provided for dall-e requests
+            response["x-ratelimit-remaining-requests"] = completion.headers[
+                "x-ratelimit-remaining-requests"
+            ]
+
+        if completion.headers.get("x-ratelimit-remaining-tokens", None) is not None:
+            response["x-ratelimit-remaining-tokens"] = completion.headers[
+                "x-ratelimit-remaining-tokens"
+            ]
+        return response
+
 
 class OpenAITextCompletion(BaseLLM):
     _client_session: httpx.Client
diff --git a/litellm/main.py b/litellm/main.py
index 9b6e57d7c..f352f19c6 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -8,7 +8,7 @@
 #  Thank you ! We ❤️ you! - Krrish & Ishaan
 
 import os, openai, sys, json, inspect, uuid, datetime, threading
-from typing import Any
+from typing import Any, Literal, Union
 from functools import partial
 import dotenv, traceback, random, asyncio, time, contextvars
 from copy import deepcopy
@@ -2885,6 +2885,108 @@ def image_generation(
     return model_response
 
 
+##### Health Endpoints #######################
+
+
+async def ahealth_check(
+    model_params: dict,
+    mode: Optional[Literal["completion", "embedding", "image_generation"]] = None,
+    prompt: Optional[str] = None,
+    input: Optional[List] = None,
+    default_timeout: float = 6000,
+):
+    """
+    Support health checks for different providers. Return remaining rate limit, etc.
+
+    For azure/openai -> completion.with_raw_response
+    For rest -> litellm.acompletion()
+    """
+    try:
+        model: Optional[str] = model_params.get("model", None)
+
+        if model is None:
+            raise Exception("model not set")
+
+        model, custom_llm_provider, _, _ = get_llm_provider(model=model)
+        mode = mode or "completion"  # default to completion calls
+
+        if custom_llm_provider == "azure":
+            api_key = (
+                model_params.get("api_key")
+                or get_secret("AZURE_API_KEY")
+                or get_secret("AZURE_OPENAI_API_KEY")
+            )
+
+            api_base = (
+                model_params.get("api_base")
+                or get_secret("AZURE_API_BASE")
+                or get_secret("AZURE_OPENAI_API_BASE")
+            )
+
+            api_version = (
+                model_params.get("api_version")
+                or get_secret("AZURE_API_VERSION")
+                or get_secret("AZURE_OPENAI_API_VERSION")
+            )
+
+            timeout = (
+                model_params.get("timeout")
+                or litellm.request_timeout
+                or default_timeout
+            )
+
+            response = await azure_chat_completions.ahealth_check(
+                model=model,
+                messages=model_params.get(
+                    "messages", None
+                ),  # Replace with your actual messages list
+                api_key=api_key,
+                api_base=api_base,
+                api_version=api_version,
+                timeout=timeout,
+                mode=mode,
+                prompt=prompt,
+                input=input,
+            )
+        elif custom_llm_provider == "openai":
+            api_key = model_params.get("api_key") or get_secret("OPENAI_API_KEY")
+
+            timeout = (
+                model_params.get("timeout")
+                or litellm.request_timeout
+                or default_timeout
+            )
+
+            response = await openai_chat_completions.ahealth_check(
+                model=model,
+                messages=model_params.get(
+                    "messages", None
+                ),  # Replace with your actual messages list
+                api_key=api_key,
+                timeout=timeout,
+                mode=mode,
+                prompt=prompt,
+                input=input,
+            )
+        else:
+            if mode == "embedding":
+                model_params.pop("messages", None)
+                model_params["input"] = input
+                await litellm.aembedding(**model_params)
+                response = {}
+            elif mode == "image_generation":
+                model_params.pop("messages", None)
+                model_params["prompt"] = prompt
+                await litellm.aimage_generation(**model_params)
+                response = {}
+            else:  # default to completion calls
+                await acompletion(**model_params)
+                response = {}  # args like remaining ratelimit etc.
+        return response
+    except Exception as e:
+        return {"error": str(e)}
+
+
 ####### HELPER FUNCTIONS ################
 ## Set verbose to true -> ```litellm.set_verbose = True```
 def print_verbose(print_statement):
diff --git a/litellm/proxy/health_check.py b/litellm/proxy/health_check.py
index 53dc2cf72..b05bd4b6a 100644
--- a/litellm/proxy/health_check.py
+++ b/litellm/proxy/health_check.py
@@ -12,7 +12,7 @@ from litellm._logging import print_verbose
 
 logger = logging.getLogger(__name__)
 
-ILLEGAL_DISPLAY_PARAMS = ["messages", "api_key"]
+ILLEGAL_DISPLAY_PARAMS = ["messages", "api_key", "prompt", "input"]
 
 
 def _get_random_llm_message():
@@ -35,55 +35,20 @@ async def _perform_health_check(model_list: list):
     """
     Perform a health check for each model in the list.
     """
-
-    async def _check_img_gen_model(model_params: dict):
-        model_params.pop("messages", None)
-        model_params["prompt"] = "test from litellm"
-        try:
-            await litellm.aimage_generation(**model_params)
-        except Exception as e:
-            print_verbose(
-                f"Health check failed for model {model_params['model']}. Error: {e}"
-            )
-            return False
-        return True
-
-    async def _check_embedding_model(model_params: dict):
-        model_params.pop("messages", None)
-        model_params["input"] = ["test from litellm"]
-        try:
-            await litellm.aembedding(**model_params)
-        except Exception as e:
-            print_verbose(
-                f"Health check failed for model {model_params['model']}. Error: {e}"
-            )
-            return False
-        return True
-
-    async def _check_model(model_params: dict):
-        try:
-            await litellm.acompletion(**model_params)
-        except Exception as e:
-            print_verbose(
-                f"Health check failed for model {model_params['model']}. Error: {e}"
-            )
-            return False
-
-        return True
-
     tasks = []
     for model in model_list:
         litellm_params = model["litellm_params"]
         model_info = model.get("model_info", {})
         litellm_params["messages"] = _get_random_llm_message()
-
-        if model_info.get("mode", None) == "embedding":
-            # this is an embedding model
-            tasks.append(_check_embedding_model(litellm_params))
-        elif model_info.get("mode", None) == "image_generation":
-            tasks.append(_check_img_gen_model(litellm_params))
-        else:
-            tasks.append(_check_model(litellm_params))
+        mode = model_info.get("mode", None)
+        tasks.append(
+            litellm.ahealth_check(
+                litellm_params,
+                mode=mode,
+                prompt="test from litellm",
+                input=["test from litellm"],
+            )
+        )
 
     results = await asyncio.gather(*tasks)
 
@@ -93,8 +58,10 @@ async def _perform_health_check(model_list: list):
     for is_healthy, model in zip(results, model_list):
         cleaned_litellm_params = _clean_litellm_params(model["litellm_params"])
 
-        if is_healthy:
-            healthy_endpoints.append(cleaned_litellm_params)
+        if isinstance(is_healthy, dict) and "error" not in is_healthy:
+            healthy_endpoints.append({**cleaned_litellm_params, **is_healthy})
+        elif isinstance(is_healthy, dict):
+            unhealthy_endpoints.append({**cleaned_litellm_params, **is_healthy})
         else:
             unhealthy_endpoints.append(cleaned_litellm_params)
 
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 754230701..3b5d8c6ac 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -672,6 +672,10 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
         router_params["model_list"] = model_list
         print(f"\033[32mLiteLLM: Proxy initialized with Config, Set models:\033[0m")
         for model in model_list:
+            ### LOAD FROM os.environ/ ###
+            for k, v in model["litellm_params"].items():
+                if isinstance(v, str) and v.startswith("os.environ/"):
+                    model["litellm_params"][k] = litellm.get_secret(v)
             print(f"\033[32m    {model.get('model_name', '')}\033[0m")
             litellm_model_name = model["litellm_params"]["model"]
             litellm_model_api_base = model["litellm_params"].get("api_base", None)
diff --git a/litellm/tests/test_health_check.py b/litellm/tests/test_health_check.py
new file mode 100644
index 000000000..9a59ca315
--- /dev/null
+++ b/litellm/tests/test_health_check.py
@@ -0,0 +1,107 @@
+#### What this tests ####
+#  This tests if ahealth_check() actually works
+
+import sys, os
+import traceback
+import pytest
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import litellm, asyncio
+
+
+@pytest.mark.asyncio
+async def test_azure_health_check():
+    response = await litellm.ahealth_check(
+        model_params={
+            "model": "azure/chatgpt-v-2",
+            "messages": [{"role": "user", "content": "Hey, how's it going?"}],
+            "api_key": os.getenv("AZURE_API_KEY"),
+            "api_base": os.getenv("AZURE_API_BASE"),
+            "api_version": os.getenv("AZURE_API_VERSION"),
+        },
+        mode="completion",
+    )
+    print(f"response: {response}")
+
+    assert "x-ratelimit-remaining-tokens" in response
+    return response
+
+
+# asyncio.run(test_azure_health_check())
+
+
+@pytest.mark.asyncio
+async def test_azure_embedding_health_check():
+    response = await litellm.ahealth_check(
+        model_params={
+            "model": "azure/azure-embedding-model",
+            "messages": [{"role": "user", "content": "Hey, how's it going?"}],
+            "api_key": os.getenv("AZURE_API_KEY"),
+            "api_base": os.getenv("AZURE_API_BASE"),
+            "api_version": os.getenv("AZURE_API_VERSION"),
+        },
+        mode="embedding",
+    )
+    print(f"response: {response}")
+
+    assert "x-ratelimit-remaining-tokens" in response
+    return response
+
+
+@pytest.mark.asyncio
+async def test_openai_img_gen_health_check():
+    response = await litellm.ahealth_check(
+        model_params={
+            "model": "dall-e-3",
+            "api_key": os.getenv("OPENAI_API_KEY"),
+        },
+        mode="image_generation",
+        prompt="cute baby sea otter",
+    )
+    print(f"response: {response}")
+
+    assert isinstance(response, dict) and "error" not in response
+    return response
+
+
+# asyncio.run(test_openai_img_gen_health_check())
+
+
+async def test_azure_img_gen_health_check():
+    response = await litellm.ahealth_check(
+        model_params={
+            "model": "azure/",
+            "api_base": os.getenv("AZURE_API_BASE"),
+            "api_key": os.getenv("AZURE_API_KEY"),
+            "api_version": "2023-06-01-preview",
+        },
+        mode="image_generation",
+        prompt="cute baby sea otter",
+    )
+
+    assert isinstance(response, dict) and "error" not in response
+    return response
+
+
+# asyncio.run(test_azure_img_gen_health_check())
+
+
+@pytest.mark.asyncio
+async def test_sagemaker_embedding_health_check():
+    response = await litellm.ahealth_check(
+        model_params={
+            "model": "sagemaker/berri-benchmarking-gpt-j-6b-fp16",
+            "messages": [{"role": "user", "content": "Hey, how's it going?"}],
+        },
+        mode="embedding",
+        input=["test from litellm"],
+    )
+    print(f"response: {response}")
+
+    assert isinstance(response, dict)
+    return response
+
+
+# asyncio.run(test_sagemaker_embedding_health_check())
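
For reference, a minimal usage sketch of the new `litellm.ahealth_check()` entrypoint added in `litellm/main.py` (the model name and environment variable below are illustrative placeholders, not values taken from this patch):

    import asyncio, os
    import litellm


    async def main():
        # mode defaults to "completion"; "embedding" and "image_generation" are also supported
        status = await litellm.ahealth_check(
            model_params={
                "model": "gpt-3.5-turbo",  # placeholder model name
                "messages": [{"role": "user", "content": "ping"}],
                "api_key": os.getenv("OPENAI_API_KEY"),
            },
            mode="completion",
        )
        # Healthy endpoints return a dict (for OpenAI/Azure it may include the
        # x-ratelimit-remaining-requests / x-ratelimit-remaining-tokens headers);
        # failures return {"error": "<exception message>"}.
        print(status)


    asyncio.run(main())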