Merge pull request #5431 from BerriAI/litellm_Add_fireworks_ai_health_check

[Fix-Proxy] /health check for provider wildcard models (fireworks/*)
Ishaan Jaff 2024-08-29 14:25:05 -07:00 committed by GitHub
commit 5851a8f901
4 changed files with 71 additions and 2 deletions


@@ -363,6 +363,7 @@ ai21_models: List = []
 nlp_cloud_models: List = []
 aleph_alpha_models: List = []
 bedrock_models: List = []
+fireworks_ai_models: List = []
 deepinfra_models: List = []
 perplexity_models: List = []
 watsonx_models: List = []
@@ -423,6 +424,8 @@ for key, value in model_cost.items():
         watsonx_models.append(key)
     elif value.get("litellm_provider") == "gemini":
         gemini_models.append(key)
+    elif value.get("litellm_provider") == "fireworks_ai":
+        fireworks_ai_models.append(key)
 # known openai compatible endpoints - we'll eventually move this list to the model_prices_and_context_window.json dictionary
 openai_compatible_endpoints: List = [
     "api.perplexity.ai",
@@ -726,6 +729,7 @@ models_by_provider: dict = {
     "maritalk": maritalk_models,
     "watsonx": watsonx_models,
     "gemini": gemini_models,
+    "fireworks_ai": fireworks_ai_models,
 }
 # mapping for those models which have larger equivalents

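As a quick sketch of what the registry change above enables (assuming a litellm build that includes this commit), fireworks_ai models from the cost map can now be looked up like any other provider:

    import litellm

    # fireworks_ai entries from the model cost map are now collected into
    # models_by_provider alongside the other providers
    fireworks_models = litellm.models_by_provider.get("fireworks_ai", [])
    print(f"{len(fireworks_models)} fireworks_ai models registered")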

@@ -1,5 +1,7 @@
 from typing import Dict, Optional
+
+import litellm
 def _ensure_extra_body_is_safe(extra_body: Optional[Dict]) -> Optional[Dict]:
     """
@@ -26,3 +28,26 @@ def _ensure_extra_body_is_safe(extra_body: Optional[Dict]) -> Optional[Dict]:
             extra_body["metadata"]["prompt"] = _prompt.__dict__
 
     return extra_body
+
+
+def pick_cheapest_model_from_llm_provider(custom_llm_provider: str):
+    """
+    Pick the cheapest model from the LLM provider.
+    """
+    if custom_llm_provider not in litellm.models_by_provider:
+        raise ValueError(f"Unknown LLM provider: {custom_llm_provider}")
+
+    known_models = litellm.models_by_provider.get(custom_llm_provider, [])
+    min_cost = float("inf")
+    cheapest_model = None
+    for model in known_models:
+        model_info = litellm.get_model_info(
+            model=model, custom_llm_provider=custom_llm_provider
+        )
+        _cost = model_info.get("input_cost_per_token", 0) + model_info.get(
+            "output_cost_per_token", 0
+        )
+        if _cost < min_cost:
+            min_cost = _cost
+            cheapest_model = model
+    return cheapest_model

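A minimal usage sketch for the new helper (assuming a litellm install that includes this commit; the provider name is only an example):

    from litellm.litellm_core_utils.llm_request_utils import (
        pick_cheapest_model_from_llm_provider,
    )

    # resolves a provider to its lowest-cost model, ranked by
    # input_cost_per_token + output_cost_per_token from the cost map
    cheapest = pick_cheapest_model_from_llm_provider(custom_llm_provider="fireworks_ai")
    print(cheapest)

This is what the wildcard branch of ahealth_check (below) relies on to turn "fireworks_ai/*" into a concrete model to probe.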

@@ -5076,6 +5076,18 @@ async def ahealth_check(
             model_params["prompt"] = prompt
             await litellm.aimage_generation(**model_params)
             response = {}
+        elif "*" in model:
+            from litellm.litellm_core_utils.llm_request_utils import (
+                pick_cheapest_model_from_llm_provider,
+            )
+
+            # this is a wildcard model - pick the cheapest model from the provider
+            cheapest_model = pick_cheapest_model_from_llm_provider(
+                custom_llm_provider=custom_llm_provider
+            )
+            model_params["model"] = cheapest_model
+            await acompletion(**model_params)
+            response = {}  # args like remaining ratelimit etc.
         else:  # default to completion calls
             await acompletion(**model_params)
             response = {}  # args like remaining ratelimit etc.


@@ -1,14 +1,18 @@
 #### What this tests ####
 # This tests if ahealth_check() actually works
-import sys, os
+import os
+import sys
 import traceback
+
 import pytest
 
 sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
-import litellm, asyncio
+import asyncio
+
+import litellm
 
 
 @pytest.mark.asyncio
@@ -105,3 +109,27 @@ async def test_sagemaker_embedding_health_check():
 # asyncio.run(test_sagemaker_embedding_health_check())
 
 
+@pytest.mark.asyncio
+async def test_fireworks_health_check():
+    """
+    This should not fail; ensures that a provider wildcard model passes the health check.
+    """
+    response = await litellm.ahealth_check(
+        model_params={
+            "api_key": os.environ.get("FIREWORKS_AI_API_KEY"),
+            "model": "fireworks_ai/*",
+            "messages": [{"role": "user", "content": "What's 1 + 1?"}],
+        },
+        mode=None,
+        prompt="What's 1 + 1?",
+        input=["test from litellm"],
+        default_timeout=6000,
+    )
+    print(f"response: {response}")
+    assert response == {}
+
+    return response