diff --git a/litellm/tests/test_get_llm_provider.py b/litellm/tests/test_get_llm_provider.py index 6f53b0f8f..d3b4302ac 100644 --- a/litellm/tests/test_get_llm_provider.py +++ b/litellm/tests/test_get_llm_provider.py @@ -25,6 +25,15 @@ def test_get_llm_provider(): # test_get_llm_provider() +def test_get_llm_provider_fireworks(): # tests finetuned fireworks models - https://github.com/BerriAI/litellm/issues/4923 + model, custom_llm_provider, _, _ = litellm.get_llm_provider( + model="fireworks_ai/accounts/my-test-1234" + ) + + assert custom_llm_provider == "fireworks_ai" + assert model == "accounts/my-test-1234" + + def test_get_llm_provider_catch_all(): _, response, _, _ = litellm.get_llm_provider(model="*") assert response == "openai" diff --git a/litellm/utils.py b/litellm/utils.py index 780148059..4e3a4e60a 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -4463,7 +4463,7 @@ def get_llm_provider( dynamic_api_key = api_key or get_secret("DEEPSEEK_API_KEY") elif custom_llm_provider == "fireworks_ai": # fireworks is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.fireworks.ai/inference/v1 - if not model.startswith("accounts/fireworks/models"): + if not model.startswith("accounts/"): model = f"accounts/fireworks/models/{model}" api_base = api_base or "https://api.fireworks.ai/inference/v1" dynamic_api_key = api_key or (