mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 18:54:30 +00:00)

Merge pull request #5431 from BerriAI/litellm_Add_fireworks_ai_health_check

[Fix-Proxy] /health check for provider wildcard models (fireworks/*)

Commit: 5851a8f901
4 changed files with 71 additions and 2 deletions
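For orientation, a minimal sketch of how the new behavior can be exercised directly; it simply mirrors the test added at the end of this commit and assumes FIREWORKS_AI_API_KEY is set in the environment.

# Illustrative sketch only: call the async health check with a provider wildcard
# model, which is the case this PR fixes (mirrors test_fireworks_health_check below).
import asyncio
import os

import litellm


async def main():
    response = await litellm.ahealth_check(
        model_params={
            "api_key": os.environ.get("FIREWORKS_AI_API_KEY"),
            "model": "fireworks_ai/*",
            "messages": [{"role": "user", "content": "What's 1 + 1?"}],
        },
        mode=None,
        prompt="What's 1 + 1?",
        input=["test from litellm"],
        default_timeout=6000,
    )
    print(response)  # expected: {} when the provider responds successfully


asyncio.run(main())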
@@ -363,6 +363,7 @@ ai21_models: List = []
 nlp_cloud_models: List = []
 aleph_alpha_models: List = []
 bedrock_models: List = []
+fireworks_ai_models: List = []
 deepinfra_models: List = []
 perplexity_models: List = []
 watsonx_models: List = []
@@ -423,6 +424,8 @@ for key, value in model_cost.items():
         watsonx_models.append(key)
     elif value.get("litellm_provider") == "gemini":
         gemini_models.append(key)
+    elif value.get("litellm_provider") == "fireworks_ai":
+        fireworks_ai_models.append(key)
 # known openai compatible endpoints - we'll eventually move this list to the model_prices_and_context_window.json dictionary
 openai_compatible_endpoints: List = [
     "api.perplexity.ai",
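To make the classification loop above concrete, a small self-contained sketch; the model name and per-token prices below are invented for illustration and are not taken from model_prices_and_context_window.json.

# Hypothetical pricing entry; only the "litellm_provider" key matters for the loop above.
model_cost = {
    "fireworks_ai/accounts/fireworks/models/example-model": {
        "litellm_provider": "fireworks_ai",
        "input_cost_per_token": 2e-07,
        "output_cost_per_token": 2e-07,
    }
}

fireworks_ai_models = []
for key, value in model_cost.items():
    if value.get("litellm_provider") == "fireworks_ai":
        fireworks_ai_models.append(key)

print(fireworks_ai_models)
# ['fireworks_ai/accounts/fireworks/models/example-model']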
@@ -726,6 +729,7 @@ models_by_provider: dict = {
     "maritalk": maritalk_models,
     "watsonx": watsonx_models,
     "gemini": gemini_models,
+    "fireworks_ai": fireworks_ai_models,
 }
 
 # mapping for those models which have larger equivalents
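With the mapping above in place, expanding a provider wildcard reduces to a dictionary lookup. A quick check (the concrete model names depend on litellm's bundled pricing data, so the output will vary):

import litellm

fireworks_models = litellm.models_by_provider.get("fireworks_ai", [])
print(len(fireworks_models), fireworks_models[:3])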
@@ -1,5 +1,7 @@
 from typing import Dict, Optional
 
+import litellm
+
 
 def _ensure_extra_body_is_safe(extra_body: Optional[Dict]) -> Optional[Dict]:
     """
@@ -26,3 +28,26 @@ def _ensure_extra_body_is_safe(extra_body: Optional[Dict]) -> Optional[Dict]:
         extra_body["metadata"]["prompt"] = _prompt.__dict__
 
     return extra_body
+
+
+def pick_cheapest_model_from_llm_provider(custom_llm_provider: str):
+    """
+    Pick the cheapest model from the LLM provider, ranked by input + output cost per token.
+    """
+    if custom_llm_provider not in litellm.models_by_provider:
+        raise ValueError(f"Unknown LLM provider: {custom_llm_provider}")
+
+    known_models = litellm.models_by_provider.get(custom_llm_provider, [])
+    min_cost = float("inf")
+    cheapest_model = None
+    for model in known_models:
+        model_info = litellm.get_model_info(
+            model=model, custom_llm_provider=custom_llm_provider
+        )
+        _cost = model_info.get("input_cost_per_token", 0) + model_info.get(
+            "output_cost_per_token", 0
+        )
+        if _cost < min_cost:
+            min_cost = _cost
+            cheapest_model = model
+    return cheapest_model
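A minimal usage sketch of the helper added above; selecting the cheapest known model keeps wildcard health probes inexpensive while still exercising real provider credentials. It assumes litellm's bundled pricing data has populated models_by_provider for fireworks_ai at import time.

import litellm
from litellm.litellm_core_utils.llm_request_utils import (
    pick_cheapest_model_from_llm_provider,
)

# Returns the fireworks_ai model with the lowest combined cost per token.
cheapest = pick_cheapest_model_from_llm_provider(custom_llm_provider="fireworks_ai")
print(cheapest)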
@@ -5076,6 +5076,18 @@ async def ahealth_check(
             model_params["prompt"] = prompt
             await litellm.aimage_generation(**model_params)
             response = {}
+        elif "*" in model:
+            from litellm.litellm_core_utils.llm_request_utils import (
+                pick_cheapest_model_from_llm_provider,
+            )
+
+            # this is a wildcard model - pick the cheapest model from the provider
+            cheapest_model = pick_cheapest_model_from_llm_provider(
+                custom_llm_provider=custom_llm_provider
+            )
+            model_params["model"] = cheapest_model
+            await acompletion(**model_params)
+            response = {}  # args like remaining ratelimit etc.
         else:  # default to completion calls
             await acompletion(**model_params)
             response = {}  # args like remaining ratelimit etc.
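A condensed, stand-alone restatement of the new branch above (not the full ahealth_check flow): when the configured model contains "*", the health check swaps in the provider's cheapest concrete model before issuing an ordinary completion call.

from litellm.litellm_core_utils.llm_request_utils import (
    pick_cheapest_model_from_llm_provider,
)

model = "fireworks_ai/*"
custom_llm_provider = "fireworks_ai"
model_params = {"model": model, "messages": [{"role": "user", "content": "ping"}]}

if "*" in model:
    # substitute a concrete model so the provider accepts the request
    model_params["model"] = pick_cheapest_model_from_llm_provider(
        custom_llm_provider=custom_llm_provider
    )
# litellm.acompletion(**model_params) would then probe the concrete model.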
@@ -1,14 +1,18 @@
 #### What this tests ####
 # This tests if ahealth_check() actually works
-import sys, os
+
+import os
+import sys
 import traceback
 
 import pytest
 
 sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
-import litellm, asyncio
+import asyncio
+
+import litellm
 
 
 @pytest.mark.asyncio

@@ -105,3 +109,27 @@ async def test_sagemaker_embedding_health_check():
 
 
 # asyncio.run(test_sagemaker_embedding_health_check())
+
+
+@pytest.mark.asyncio
+async def test_fireworks_health_check():
+    """
+    This should not fail
+
+    ensure that provider wildcard model passes health check
+    """
+    response = await litellm.ahealth_check(
+        model_params={
+            "api_key": os.environ.get("FIREWORKS_AI_API_KEY"),
+            "model": "fireworks_ai/*",
+            "messages": [{"role": "user", "content": "What's 1 + 1?"}],
+        },
+        mode=None,
+        prompt="What's 1 + 1?",
+        input=["test from litellm"],
+        default_timeout=6000,
+    )
+    print(f"response: {response}")
+    assert response == {}
+
+    return response
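The added test can be run on its own with pytest, e.g. "pytest -k test_fireworks_health_check" from the directory containing the health-check test module (the file path is not shown in this diff), with FIREWORKS_AI_API_KEY exported; per the assertion above, a passing run means ahealth_check returned an empty dict.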