Fix Fireworks AI 429 error mapping to use RateLimitError instead of BadRequestError

This commit is contained in:
openhands 2025-04-05 21:06:48 +00:00
parent 7262606411
commit cf8d75b7fd
2 changed files with 61 additions and 0 deletions

View file

@ -20,6 +20,15 @@ class FireworksAIMixin:
def get_error_class( def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, Headers] self, error_message: str, status_code: int, headers: Union[dict, Headers]
) -> BaseLLMException: ) -> BaseLLMException:
# Check if it's a rate limit error (status code 429)
if status_code == 429:
from litellm.exceptions import RateLimitError
return RateLimitError(
message=f"Fireworks_aiException - {error_message}",
llm_provider="fireworks_ai",
model="", # This will be set later in the exception mapping
response=None, # This will be set later in the exception mapping
)
return FireworksAIException( return FireworksAIException(
status_code=status_code, status_code=status_code,
message=error_message, message=error_message,

View file

@ -0,0 +1,52 @@
import os
import sys
import unittest
from unittest.mock import patch, MagicMock
import pytest
import httpx
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import litellm
from litellm.llms.fireworks_ai.common_utils import FireworksAIMixin
from litellm.exceptions import RateLimitError, BadRequestError
class TestFireworksAIExceptions(unittest.TestCase):
def setUp(self):
self.fireworks_mixin = FireworksAIMixin()
def test_rate_limit_error(self):
"""Test that a 429 error is properly translated to a RateLimitError"""
error_message = "server overloaded, please try again later"
status_code = 429
headers = {}
exception = self.fireworks_mixin.get_error_class(
error_message=error_message,
status_code=status_code,
headers=headers,
)
self.assertIsInstance(exception, RateLimitError)
self.assertEqual(exception.llm_provider, "fireworks_ai")
self.assertIn("Fireworks_aiException", exception.message)
self.assertIn(error_message, exception.message)
def test_bad_request_error(self):
"""Test that a 400 error is properly translated to a FireworksAIException"""
error_message = "Invalid request"
status_code = 400
headers = {}
exception = self.fireworks_mixin.get_error_class(
error_message=error_message,
status_code=status_code,
headers=headers,
)
self.assertEqual(exception.status_code, 400)
self.assertEqual(exception.message, error_message)