diff --git a/litellm/proxy/health_check.py b/litellm/proxy/health_check.py index b05bd4b6a..a20ec06e5 100644 --- a/litellm/proxy/health_check.py +++ b/litellm/proxy/health_check.py @@ -86,7 +86,12 @@ async def perform_health_check( return [], [] if model is not None: - model_list = [x for x in model_list if x["litellm_params"]["model"] == model] + _new_model_list = [ + x for x in model_list if x["litellm_params"]["model"] == model + ] + if _new_model_list == []: + _new_model_list = [x for x in model_list if x["model_name"] == model] + model_list = _new_model_list healthy_endpoints, unhealthy_endpoints = await _perform_health_check(model_list) diff --git a/tests/test_models.py b/tests/test_models.py index eb0cbbe5f..b1e8fab59 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -4,6 +4,11 @@ import pytest import asyncio import aiohttp +import os +import dotenv +from dotenv import load_dotenv + +load_dotenv() async def generate_key(session, models=[]): @@ -102,14 +107,14 @@ async def get_model_info(session, key): return await response.json() -async def chat_completion(session, key): +async def chat_completion(session, key, model="azure-gpt-3.5"): url = "http://0.0.0.0:4000/chat/completions" headers = { "Authorization": f"Bearer {key}", "Content-Type": "application/json", } data = { - "model": "azure-gpt-3.5", + "model": model, "messages": [ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Hello!"}, @@ -177,3 +182,130 @@ async def test_add_and_delete_models(): await asyncio.sleep(60) await chat_completion(session=session, key=key) await delete_model(session=session, model_id=model_id) + + +async def add_model_for_health_checking(session, model_id="123"): + url = "http://0.0.0.0:4000/model/new" + headers = { + "Authorization": f"Bearer sk-1234", + "Content-Type": "application/json", + } + + data = { + "model_name": f"azure-model-health-check-{model_id}", + "litellm_params": { + "model": "azure/chatgpt-v-2", + "api_key": os.getenv("AZURE_API_KEY"), + "api_base": "https://openai-gpt-4-test-v-1.openai.azure.com/", + "api_version": "2023-05-15", + }, + "model_info": {"id": model_id}, + } + + async with session.post(url, headers=headers, json=data) as response: + status = response.status + response_text = await response.text() + + print(f"Add models {response_text}") + print() + + if status != 200: + raise Exception(f"Request did not return a 200 status code: {status}") + + +async def get_model_info(session, key): + url = "http://0.0.0.0:4000/model/info" + headers = { + "Authorization": f"Bearer {key}", + "Content-Type": "application/json", + } + + async with session.get(url, headers=headers) as response: + status = response.status + response_text = await response.text() + print("response from /model/info") + print(response_text) + print() + + if status != 200: + raise Exception(f"Request did not return a 200 status code: {status}") + + +async def get_model_info_v2(session, key): + url = "http://0.0.0.0:4000/v2/model/info" + headers = { + "Authorization": f"Bearer {key}", + "Content-Type": "application/json", + } + + async with session.get(url, headers=headers) as response: + status = response.status + response_text = await response.text() + print("response from v2/model/info") + print(response_text) + print() + + if status != 200: + raise Exception(f"Request did not return a 200 status code: {status}") + + +async def get_model_health(session, key, model_name): + url = "http://0.0.0.0:4000/health?model=" + model_name + headers = { + "Authorization": f"Bearer {key}", + "Content-Type": "application/json", + } + + async with session.get(url, headers=headers) as response: + status = response.status + response_text = await response.json() + print("response from /health?model=", model_name) + print(response_text) + print() + + if status != 200: + raise Exception(f"Request did not return a 200 status code: {status}") + return response_text + + +@pytest.mark.asyncio +async def test_add_model_run_health(): + """ + Add model + Call /model/info and v2/model/info + -> Admin UI calls v2/model/info + Call /chat/completions + Call /health + -> Ensure the health check for the endpoint is working as expected + """ + import uuid + + async with aiohttp.ClientSession() as session: + key_gen = await generate_key(session=session) + key = key_gen["key"] + model_id = str(uuid.uuid4()) + model_name = f"azure-model-health-check-{model_id}" + print("adding model", model_name) + await add_model_for_health_checking(session=session, model_id=model_id) + await asyncio.sleep(10) + print("calling /model/info") + await get_model_info(session=session, key=key) + print("calling v2/model/info") + await get_model_info_v2(session=session, key=key) + + print("calling /chat/completions -> expect to work") + await chat_completion(session=session, key=key, model=model_name) + + print("calling /health?model=", model_name) + _health_info = await get_model_health( + session=session, key=key, model_name="azure/chatgpt-v-2" + ) + _healthy_endpooint = _health_info["healthy_endpoints"][0] + + assert _health_info["healthy_count"] == 1 + assert ( + _healthy_endpooint["model"] == "azure/chatgpt-v-2" + ) # this is the model that got added + + # cleanup + await delete_model(session=session, model_id=model_id)