diff --git a/litellm/__init__.py b/litellm/__init__.py
index 27b4fc4408..74a93b7666 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -363,6 +363,7 @@ ai21_models: List = []
 nlp_cloud_models: List = []
 aleph_alpha_models: List = []
 bedrock_models: List = []
+fireworks_ai_models: List = []
 deepinfra_models: List = []
 perplexity_models: List = []
 watsonx_models: List = []
@@ -423,6 +424,8 @@ for key, value in model_cost.items():
         watsonx_models.append(key)
     elif value.get("litellm_provider") == "gemini":
         gemini_models.append(key)
+    elif value.get("litellm_provider") == "fireworks_ai":
+        fireworks_ai_models.append(key)
 # known openai compatible endpoints - we'll eventually move this list to the model_prices_and_context_window.json dictionary
 openai_compatible_endpoints: List = [
     "api.perplexity.ai",
@@ -726,6 +729,7 @@ models_by_provider: dict = {
     "maritalk": maritalk_models,
     "watsonx": watsonx_models,
     "gemini": gemini_models,
+    "fireworks_ai": fireworks_ai_models,
 }
 
 # mapping for those models which have larger equivalents
diff --git a/litellm/litellm_core_utils/llm_request_utils.py b/litellm/litellm_core_utils/llm_request_utils.py
index 557d73b0ab..ab0c231123 100644
--- a/litellm/litellm_core_utils/llm_request_utils.py
+++ b/litellm/litellm_core_utils/llm_request_utils.py
@@ -1,5 +1,7 @@
 from typing import Dict, Optional
 
+import litellm
+
 
 def _ensure_extra_body_is_safe(extra_body: Optional[Dict]) -> Optional[Dict]:
     """
@@ -26,3 +28,26 @@
                 extra_body["metadata"]["prompt"] = _prompt.__dict__
 
     return extra_body
+
+
+def pick_cheapest_model_from_llm_provider(custom_llm_provider: str):
+    """
+    Pick the cheapest known model from the LLM provider.
+    """
+    if custom_llm_provider not in litellm.models_by_provider:
+        raise ValueError(f"Unknown LLM provider: {custom_llm_provider}")
+
+    known_models = litellm.models_by_provider.get(custom_llm_provider, [])
+    min_cost = float("inf")
+    cheapest_model = None
+    for model in known_models:
+        model_info = litellm.get_model_info(
+            model=model, custom_llm_provider=custom_llm_provider
+        )
+        _cost = model_info.get("input_cost_per_token", 0) + model_info.get(
+            "output_cost_per_token", 0
+        )
+        if _cost < min_cost:
+            min_cost = _cost
+            cheapest_model = model
+    return cheapest_model
diff --git a/litellm/main.py b/litellm/main.py
index f20cfa9966..ca9d145f1b 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -5076,6 +5076,18 @@ async def ahealth_check(
             model_params["prompt"] = prompt
             await litellm.aimage_generation(**model_params)
             response = {}
+        elif "*" in model:
+            from litellm.litellm_core_utils.llm_request_utils import (
+                pick_cheapest_model_from_llm_provider,
+            )
+
+            # this is a wildcard model, so pick the provider's cheapest known model
+            cheapest_model = pick_cheapest_model_from_llm_provider(
+                custom_llm_provider=custom_llm_provider
+            )
+            model_params["model"] = cheapest_model
+            await acompletion(**model_params)
+            response = {}  # args like remaining ratelimit etc.
         else:  # default to completion calls
             await acompletion(**model_params)
             response = {}  # args like remaining ratelimit etc.
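For reference, the new helper can be exercised on its own; a minimal sketch, assuming the diff above is applied and litellm's model cost map is loaded (the printed values depend on the local model_prices_and_context_window data):

import litellm
from litellm.litellm_core_utils.llm_request_utils import (
    pick_cheapest_model_from_llm_provider,
)

# Resolve a wildcard such as "fireworks_ai/*" to a concrete model by picking
# the provider's cheapest known model (lowest input + output cost per token).
cheapest = pick_cheapest_model_from_llm_provider(custom_llm_provider="fireworks_ai")
info = litellm.get_model_info(model=cheapest, custom_llm_provider="fireworks_ai")
print(cheapest, info.get("input_cost_per_token"), info.get("output_cost_per_token"))

This is the model that the wildcard branch in ahealth_check substitutes before issuing the completion call.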
diff --git a/litellm/tests/test_health_check.py b/litellm/tests/test_health_check.py
index f632e76921..75f40541a7 100644
--- a/litellm/tests/test_health_check.py
+++ b/litellm/tests/test_health_check.py
@@ -1,14 +1,18 @@
 #### What this tests ####
 # This tests if ahealth_check() actually works
 
-import sys, os
+import os
+import sys
 import traceback
+
 import pytest
 
 sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
-import litellm, asyncio
+import asyncio
+
+import litellm
 
 
 @pytest.mark.asyncio
@@ -105,3 +109,27 @@ async def test_sagemaker_embedding_health_check():
 
 
 # asyncio.run(test_sagemaker_embedding_health_check())
+
+
+@pytest.mark.asyncio
+async def test_fireworks_health_check():
+    """
+    This should not fail
+
+    ensure that provider wildcard model passes health check
+    """
+    response = await litellm.ahealth_check(
+        model_params={
+            "api_key": os.environ.get("FIREWORKS_AI_API_KEY"),
+            "model": "fireworks_ai/*",
+            "messages": [{"role": "user", "content": "What's 1 + 1?"}],
+        },
+        mode=None,
+        prompt="What's 1 + 1?",
+        input=["test from litellm"],
+        default_timeout=6000,
+    )
+    print(f"response: {response}")
+    assert response == {}
+
+    return response
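Outside pytest, the same wildcard health check can be run directly; a minimal sketch mirroring the test above, assuming a valid FIREWORKS_AI_API_KEY is set in the environment:

import asyncio
import os

import litellm


async def main():
    # "fireworks_ai/*" takes the new wildcard branch in ahealth_check, which
    # substitutes the provider's cheapest known model before calling acompletion.
    response = await litellm.ahealth_check(
        model_params={
            "api_key": os.environ.get("FIREWORKS_AI_API_KEY"),
            "model": "fireworks_ai/*",
            "messages": [{"role": "user", "content": "What's 1 + 1?"}],
        },
        mode=None,
        prompt="What's 1 + 1?",
        input=["test from litellm"],
        default_timeout=6000,
    )
    print(response)  # an empty dict ({}) indicates the check passed


asyncio.run(main())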