diff --git a/litellm/llms/huggingface_restapi.py b/litellm/llms/huggingface_restapi.py
index eb8ce38b97..e945242bf2 100644
--- a/litellm/llms/huggingface_restapi.py
+++ b/litellm/llms/huggingface_restapi.py
@@ -399,9 +399,12 @@ class Huggingface(BaseLLM):
                 data = {
                     "inputs": prompt,
                     "parameters": optional_params,
-                    "stream": True
-                    if "stream" in optional_params and optional_params["stream"] == True
-                    else False,
+                    "stream": (
+                        True
+                        if "stream" in optional_params
+                        and optional_params["stream"] == True
+                        else False
+                    ),
                 }
                 input_text = prompt
             else:
@@ -430,9 +433,12 @@ class Huggingface(BaseLLM):
                 data = {
                     "inputs": prompt,
                     "parameters": inference_params,
-                    "stream": True
-                    if "stream" in optional_params and optional_params["stream"] == True
-                    else False,
+                    "stream": (
+                        True
+                        if "stream" in optional_params
+                        and optional_params["stream"] == True
+                        else False
+                    ),
                 }
                 input_text = prompt
             ## LOGGING
@@ -561,14 +567,12 @@ class Huggingface(BaseLLM):
         input_text: str,
         model: str,
         optional_params: dict,
-        timeout: float
+        timeout: float,
     ):
         response = None
         try:
             async with httpx.AsyncClient(timeout=timeout) as client:
-                response = await client.post(
-                    url=api_base, json=data, headers=headers
-                )
+                response = await client.post(url=api_base, json=data, headers=headers)
                 response_json = response.json()
                 if response.status_code != 200:
                     raise HuggingfaceError(
@@ -607,7 +611,7 @@ class Huggingface(BaseLLM):
         headers: dict,
         model_response: ModelResponse,
         model: str,
-        timeout: float
+        timeout: float,
     ):
         async with httpx.AsyncClient(timeout=timeout) as client:
             response = client.stream(
@@ -615,9 +619,10 @@ class Huggingface(BaseLLM):
             )
             async with response as r:
                 if r.status_code != 200:
+                    text = await r.aread()
                     raise HuggingfaceError(
                         status_code=r.status_code,
-                        message="An error occurred while streaming",
+                        message=str(text),
                     )
                 streamwrapper = CustomStreamWrapper(
                     completion_stream=r.aiter_lines(),
diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py
index a5497b539c..30d777d799 100644
--- a/litellm/tests/test_streaming.py
+++ b/litellm/tests/test_streaming.py
@@ -1031,6 +1031,8 @@ async def test_hf_completion_tgi_stream():
         if complete_response.strip() == "":
             raise Exception("Empty response received")
         print(f"completion_response: {complete_response}")
+    except litellm.ServiceUnavailableError as e:
+        pass
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
diff --git a/litellm/utils.py b/litellm/utils.py
index 018308e563..f1289d8ff9 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -6975,6 +6975,14 @@ def exception_type(
                             model=model,
                             response=original_exception.response,
                         )
+                    elif original_exception.status_code == 503:
+                        exception_mapping_worked = True
+                        raise ServiceUnavailableError(
+                            message=f"HuggingfaceException - {original_exception.message}",
+                            llm_provider="huggingface",
+                            model=model,
+                            response=original_exception.response,
+                        )
                     else:
                         exception_mapping_worked = True
                         raise APIError(