diff --git a/litellm/llms/fireworks_ai/common_utils.py b/litellm/llms/fireworks_ai/common_utils.py index 293403b133..0914af8c9a 100644 --- a/litellm/llms/fireworks_ai/common_utils.py +++ b/litellm/llms/fireworks_ai/common_utils.py @@ -20,6 +20,15 @@ class FireworksAIMixin: def get_error_class( self, error_message: str, status_code: int, headers: Union[dict, Headers] ) -> BaseLLMException: + # Check if it's a rate limit error (status code 429) + if status_code == 429: + from litellm.exceptions import RateLimitError + return RateLimitError( + message=f"Fireworks_aiException - {error_message}", + llm_provider="fireworks_ai", + model="", # This will be set later in the exception mapping + response=None, # This will be set later in the exception mapping + ) return FireworksAIException( status_code=status_code, message=error_message, diff --git a/tests/test_fireworks_ai/test_fireworks_ai_exceptions.py b/tests/test_fireworks_ai/test_fireworks_ai_exceptions.py new file mode 100644 index 0000000000..119c6c2400 --- /dev/null +++ b/tests/test_fireworks_ai/test_fireworks_ai_exceptions.py @@ -0,0 +1,52 @@ +import os +import sys +import unittest +from unittest.mock import patch, MagicMock + +import pytest +import httpx + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path + +import litellm +from litellm.llms.fireworks_ai.common_utils import FireworksAIMixin +from litellm.exceptions import RateLimitError, BadRequestError + + +class TestFireworksAIExceptions(unittest.TestCase): + def setUp(self): + self.fireworks_mixin = FireworksAIMixin() + + def test_rate_limit_error(self): + """Test that a 429 error is properly translated to a RateLimitError""" + error_message = "server overloaded, please try again later" + status_code = 429 + headers = {} + + exception = self.fireworks_mixin.get_error_class( + error_message=error_message, + status_code=status_code, + headers=headers, + ) + + self.assertIsInstance(exception, RateLimitError) + self.assertEqual(exception.llm_provider, "fireworks_ai") + self.assertIn("Fireworks_aiException", exception.message) + self.assertIn(error_message, exception.message) + + def test_bad_request_error(self): + """Test that a 400 error is properly translated to a FireworksAIException""" + error_message = "Invalid request" + status_code = 400 + headers = {} + + exception = self.fireworks_mixin.get_error_class( + error_message=error_message, + status_code=status_code, + headers=headers, + ) + + self.assertEqual(exception.status_code, 400) + self.assertEqual(exception.message, error_message) \ No newline at end of file