Merge pull request #5431 from BerriAI/litellm_Add_fireworks_ai_health_check
[Fix-Proxy] /health check for provider wildcard models (fireworks/*)
Commit: 5851a8f901
4 changed files with 71 additions and 2 deletions
@@ -363,6 +363,7 @@ ai21_models: List = []
 nlp_cloud_models: List = []
 aleph_alpha_models: List = []
 bedrock_models: List = []
+fireworks_ai_models: List = []
 deepinfra_models: List = []
 perplexity_models: List = []
 watsonx_models: List = []
@@ -423,6 +424,8 @@ for key, value in model_cost.items():
         watsonx_models.append(key)
     elif value.get("litellm_provider") == "gemini":
         gemini_models.append(key)
+    elif value.get("litellm_provider") == "fireworks_ai":
+        fireworks_ai_models.append(key)
 # known openai compatible endpoints - we'll eventually move this list to the model_prices_and_context_window.json dictionary
 openai_compatible_endpoints: List = [
     "api.perplexity.ai",
@@ -726,6 +729,7 @@ models_by_provider: dict = {
     "maritalk": maritalk_models,
     "watsonx": watsonx_models,
     "gemini": gemini_models,
+    "fireworks_ai": fireworks_ai_models,
 }

 # mapping for those models which have larger equivalents
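Taken together, the three hunks above register Fireworks AI as a provider bucket: every entry in the cost map whose litellm_provider is "fireworks_ai" is collected into fireworks_ai_models, and that list is exposed through models_by_provider. A minimal sketch of inspecting the mapping (assumes a litellm install that includes this change; the actual model ids depend on the local model_prices_and_context_window.json):

import litellm

# The wildcard health check below relies on this bucket: "fireworks_ai/*" can only
# be resolved if litellm knows which concrete models belong to the provider.
fireworks_models = litellm.models_by_provider.get("fireworks_ai", [])
print(f"{len(fireworks_models)} fireworks_ai models known to litellm")
print(fireworks_models[:3])  # a few example model ids from the local cost map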
@@ -1,5 +1,7 @@
 from typing import Dict, Optional

+import litellm
+

 def _ensure_extra_body_is_safe(extra_body: Optional[Dict]) -> Optional[Dict]:
     """
@@ -26,3 +28,26 @@ def _ensure_extra_body_is_safe(extra_body: Optional[Dict]) -> Optional[Dict]:
             extra_body["metadata"]["prompt"] = _prompt.__dict__

     return extra_body
+
+
+def pick_cheapest_model_from_llm_provider(custom_llm_provider: str):
+    """
+    Pick the cheapest known model for the given LLM provider.
+    """
+    if custom_llm_provider not in litellm.models_by_provider:
+        raise ValueError(f"Unknown LLM provider: {custom_llm_provider}")
+
+    known_models = litellm.models_by_provider.get(custom_llm_provider, [])
+    min_cost = float("inf")
+    cheapest_model = None
+    for model in known_models:
+        model_info = litellm.get_model_info(
+            model=model, custom_llm_provider=custom_llm_provider
+        )
+        _cost = model_info.get("input_cost_per_token", 0) + model_info.get(
+            "output_cost_per_token", 0
+        )
+        if _cost < min_cost:
+            min_cost = _cost
+            cheapest_model = model
+    return cheapest_model
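The new helper scans a provider's known models and returns the one with the lowest combined input and output cost per token, raising ValueError for providers litellm does not know about. A rough usage sketch (import path taken from the health-check hunk below; the returned model id depends on the local cost map):

from litellm.litellm_core_utils.llm_request_utils import (
    pick_cheapest_model_from_llm_provider,
)

# Resolve a provider name to its cheapest concrete model.
cheapest = pick_cheapest_model_from_llm_provider(custom_llm_provider="fireworks_ai")
print(cheapest)  # the lowest-priced fireworks_ai model in the local cost map

# Unknown providers fail loudly instead of silently returning None.
try:
    pick_cheapest_model_from_llm_provider(custom_llm_provider="not-a-provider")
except ValueError as err:
    print(err)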
@@ -5076,6 +5076,18 @@ async def ahealth_check(
             model_params["prompt"] = prompt
             await litellm.aimage_generation(**model_params)
             response = {}
+        elif "*" in model:
+            from litellm.litellm_core_utils.llm_request_utils import (
+                pick_cheapest_model_from_llm_provider,
+            )
+
+            # this is a wildcard model, so pick the cheapest model from the provider
+            cheapest_model = pick_cheapest_model_from_llm_provider(
+                custom_llm_provider=custom_llm_provider
+            )
+            model_params["model"] = cheapest_model
+            await acompletion(**model_params)
+            response = {}  # args like remaining ratelimit etc.
         else:  # default to completion calls
             await acompletion(**model_params)
             response = {}  # args like remaining ratelimit etc.
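The net effect: a health check for a wildcard model such as fireworks_ai/* no longer sends the literal "*" to the provider; the new branch resolves the wildcard to the provider's cheapest concrete model and runs an ordinary completion. A minimal sketch of that resolution outside ahealth_check (using litellm.get_llm_provider to split the provider prefix off the wildcard is an assumption for illustration; inside ahealth_check the provider is already available):

import litellm
from litellm.litellm_core_utils.llm_request_utils import (
    pick_cheapest_model_from_llm_provider,
)

wildcard = "fireworks_ai/*"
# Split the provider prefix off the wildcard, then pick a concrete model for it.
_, custom_llm_provider, _, _ = litellm.get_llm_provider(model=wildcard)
concrete_model = pick_cheapest_model_from_llm_provider(
    custom_llm_provider=custom_llm_provider
)
print(concrete_model)  # the model the health check will actually call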
@@ -1,14 +1,18 @@
 #### What this tests ####
 # This tests if ahealth_check() actually works

-import sys, os
+import os
+import sys
 import traceback

 import pytest

 sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
-import litellm, asyncio
+import asyncio
+
+import litellm


 @pytest.mark.asyncio
@@ -105,3 +109,27 @@ async def test_sagemaker_embedding_health_check():


 # asyncio.run(test_sagemaker_embedding_health_check())
+
+
+@pytest.mark.asyncio
+async def test_fireworks_health_check():
+    """
+    This should not fail
+
+    ensure that provider wildcard model passes health check
+    """
+    response = await litellm.ahealth_check(
+        model_params={
+            "api_key": os.environ.get("FIREWORKS_AI_API_KEY"),
+            "model": "fireworks_ai/*",
+            "messages": [{"role": "user", "content": "What's 1 + 1?"}],
+        },
+        mode=None,
+        prompt="What's 1 + 1?",
+        input=["test from litellm"],
+        default_timeout=6000,
+    )
+    print(f"response: {response}")
+    assert response == {}
+
+    return response
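Running only this test locally needs a FIREWORKS_AI_API_KEY in the environment; invoking pytest with -k test_fireworks_health_check on the health-check test module should exercise the wildcard path end to end, and the assertion expects the empty dict that ahealth_check returns on success.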