diff --git a/litellm/llms/huggingface_restapi.py b/litellm/llms/huggingface_restapi.py
index 205bad7ee..61a6ac040 100644
--- a/litellm/llms/huggingface_restapi.py
+++ b/litellm/llms/huggingface_restapi.py
@@ -634,15 +634,60 @@ class Huggingface(BaseLLM):
                         status_code=r.status_code,
                         message=str(text),
                     )
-                streamwrapper = CustomStreamWrapper(
-                    completion_stream=r.aiter_lines(),
+                """
+                Check first chunk for error message.
+                If error message, raise error.
+                If not - add back to stream
+                """
+                # Async iterator over the lines in the response body
+                response_iterator = r.aiter_lines()
+
+                # Attempt to get the first line/chunk from the response
+                try:
+                    first_chunk = await response_iterator.__anext__()
+                except StopAsyncIteration:
+                    # Handle the case where there are no lines to read (empty response)
+                    first_chunk = ""
+
+                # Check the first chunk for an error message
+                if (
+                    "error" in first_chunk.lower()
+                ):  # Adjust this condition based on how error messages are structured
+                    raise HuggingfaceError(
+                        status_code=400,
+                        message=first_chunk,
+                    )
+
+                return self.async_streaming_generator(
+                    first_chunk=first_chunk,
+                    response_iterator=response_iterator,
                     model=model,
-                    custom_llm_provider="huggingface",
                     logging_obj=logging_obj,
                 )
-                async for transformed_chunk in streamwrapper:
-                    yield transformed_chunk
+    async def async_streaming_generator(
+        self, first_chunk, response_iterator, model, logging_obj
+    ):
+        # Create a new async generator that begins with the first_chunk and includes the remaining items
+        async def custom_stream_with_first_chunk():
+            yield first_chunk  # Yield back the first chunk
+            async for (
+                chunk
+            ) in response_iterator:  # Continue yielding the rest of the chunks
+                yield chunk
+
+        # Creating a new completion stream that starts with the first chunk
+        completion_stream = custom_stream_with_first_chunk()
+
+        streamwrapper = CustomStreamWrapper(
+            completion_stream=completion_stream,
+            model=model,
+            custom_llm_provider="huggingface",
+            logging_obj=logging_obj,
+        )
+
+        async for transformed_chunk in streamwrapper:
+            yield transformed_chunk
 
 
     def embedding(
         self,
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index f1dec3881..7dbb068bc 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -167,6 +167,15 @@ class ProxyException(Exception):
         self.param = param
         self.code = code
 
+    def to_dict(self) -> dict:
+        """Converts the ProxyException instance to a dictionary."""
+        return {
+            "message": self.message,
+            "type": self.type,
+            "param": self.param,
+            "code": self.code,
+        }
+
 
 @app.exception_handler(ProxyException)
 async def openai_exception_handler(request: Request, exc: ProxyException):
@@ -2241,12 +2250,14 @@ async def async_data_generator(response, user_api_key_dict):
             error_traceback = traceback.format_exc()
             error_msg = f"{str(e)}\n\n{error_traceback}"
 
-        raise ProxyException(
+        proxy_exception = ProxyException(
             message=getattr(e, "message", error_msg),
             type=getattr(e, "type", "None"),
             param=getattr(e, "param", "None"),
             code=getattr(e, "status_code", 500),
         )
+        error_returned = json.dumps({"error": proxy_exception.to_dict()})
+        yield f"data: {error_returned}\n\n"
 
 
 def select_data_generator(response, user_api_key_dict):
diff --git a/litellm/utils.py b/litellm/utils.py
index 53e6e8245..233fd6bae 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -8117,7 +8117,8 @@ class CustomStreamWrapper:
                     text = ""  # don't return the final bos token
                     is_finished = True
                     finish_reason = "stop"
-
+                elif data_json.get("error", False):
+                    raise Exception(data_json.get("error"))
                 return {
                     "text": text,
                     "is_finished": is_finished,
@@ -8132,7 +8133,7 @@ class CustomStreamWrapper:
             }
         except Exception as e:
             traceback.print_exc()
-            # raise(e)
+            raise e
 
     def handle_ai21_chunk(self, chunk):  # fake streaming
         chunk = chunk.decode("utf-8")
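
Note: the following is a standalone sketch, not part of the patch above. It illustrates the "peek at the first chunk, then stitch it back onto the stream" pattern that the huggingface_restapi.py change relies on; fake_stream and peek_and_stream are illustrative names, not litellm APIs. On the proxy side, the related change means a streaming error reaches the client as a final data: {"error": {...}} SSE line instead of an unhandled exception mid-stream.

# Standalone sketch (not part of the patch): peek at the first chunk of an
# async iterator, raise if it looks like an error, otherwise re-yield it
# followed by the rest of the stream. Names here are hypothetical.
import asyncio


async def fake_stream():
    # Stand-in for r.aiter_lines(): yields SSE-style text lines.
    for line in [
        'data: {"token": {"text": "Hello"}}',
        'data: {"token": {"text": " world"}}',
    ]:
        yield line


async def peek_and_stream(response_iterator):
    # Pull the first chunk so it can be inspected before streaming begins.
    try:
        first_chunk = await response_iterator.__anext__()
    except StopAsyncIteration:
        first_chunk = ""  # empty response body
    if "error" in first_chunk.lower():
        # The patch raises HuggingfaceError(status_code=400, ...) at this point.
        raise RuntimeError(first_chunk)
    # Re-yield the inspected chunk, then the remainder of the stream.
    yield first_chunk
    async for chunk in response_iterator:
        yield chunk


async def main():
    async for chunk in peek_and_stream(fake_stream()):
        print(chunk)


asyncio.run(main())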