Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 11:14:04 +00:00)
fix(huggingface_restapi.py): fix huggingface streaming error raising
parent 766e8cba84
commit 873ddde924

3 changed files with 65 additions and 8 deletions
@@ -634,8 +634,53 @@ class Huggingface(BaseLLM):
                         status_code=r.status_code,
                         message=str(text),
                     )
+                """
+                Check first chunk for error message.
+                If error message, raise error.
+                If not - add back to stream
+                """
+                # Async iterator over the lines in the response body
+                response_iterator = r.aiter_lines()
+
+                # Attempt to get the first line/chunk from the response
+                try:
+                    first_chunk = await response_iterator.__anext__()
+                except StopAsyncIteration:
+                    # Handle the case where there are no lines to read (empty response)
+                    first_chunk = ""
+
+                # Check the first chunk for an error message
+                if (
+                    "error" in first_chunk.lower()
+                ):  # Adjust this condition based on how error messages are structured
+                    raise HuggingfaceError(
+                        status_code=400,
+                        message=first_chunk,
+                    )
+
+                return self.async_streaming_generator(
+                    first_chunk=first_chunk,
+                    response_iterator=response_iterator,
+                    model=model,
+                    logging_obj=logging_obj,
+                )
+
+    async def async_streaming_generator(
+        self, first_chunk, response_iterator, model, logging_obj
+    ):
+        # Create a new async generator that begins with the first_chunk and includes the remaining items
+        async def custom_stream_with_first_chunk():
+            yield first_chunk  # Yield back the first chunk
+            async for (
+                chunk
+            ) in response_iterator:  # Continue yielding the rest of the chunks
+                yield chunk
+
+        # Creating a new completion stream that starts with the first chunk
+        completion_stream = custom_stream_with_first_chunk()
+
         streamwrapper = CustomStreamWrapper(
-            completion_stream=r.aiter_lines(),
+            completion_stream=completion_stream,
             model=model,
             custom_llm_provider="huggingface",
             logging_obj=logging_obj,
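The Huggingface hunk above boils down to a peek-then-replay pattern: pull the first line off the async response iterator so it can be inspected for an error payload, and if it is clean, stitch it back onto the front of the stream before wrapping it. A minimal standalone sketch of that pattern, assuming nothing beyond async generators and StopAsyncIteration from the language itself (the fake_hf_stream source and its payloads are made up for illustration):

import asyncio


async def fake_hf_stream():
    # Stand-in for r.aiter_lines(): yields SSE-style lines (illustrative payloads).
    yield 'data: {"token": {"text": "Hello"}}'
    yield 'data: {"token": {"text": " world"}}'


async def peek_then_replay(source):
    # Pull the first chunk so it can be inspected before streaming begins.
    try:
        first_chunk = await source.__anext__()
    except StopAsyncIteration:
        first_chunk = ""  # empty response: nothing to inspect

    if "error" in first_chunk.lower():
        raise RuntimeError(f"upstream returned an error: {first_chunk}")

    async def replay():
        yield first_chunk  # put the inspected chunk back at the front
        async for chunk in source:
            yield chunk

    return replay()


async def main():
    stream = await peek_then_replay(fake_hf_stream())
    async for line in stream:
        print(line)


asyncio.run(main())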
@@ -167,6 +167,15 @@ class ProxyException(Exception):
         self.param = param
         self.code = code
 
+    def to_dict(self) -> dict:
+        """Converts the ProxyException instance to a dictionary."""
+        return {
+            "message": self.message,
+            "type": self.type,
+            "param": self.param,
+            "code": self.code,
+        }
+
 
 @app.exception_handler(ProxyException)
 async def openai_exception_handler(request: Request, exc: ProxyException):
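The new to_dict method gives the proxy a JSON-safe view of the exception so it can be embedded in a response body rather than only raised. A small self-contained sketch of how such a serializable exception is used (a re-implementation for illustration, not the proxy's actual class; the sample field values are invented):

import json


class ProxyException(Exception):
    def __init__(self, message, type, param, code):
        self.message = message
        self.type = type
        self.param = param
        self.code = code

    def to_dict(self) -> dict:
        """Converts the ProxyException instance to a dictionary."""
        return {
            "message": self.message,
            "type": self.type,
            "param": self.param,
            "code": self.code,
        }


exc = ProxyException(
    message="model not found",
    type="invalid_request_error",
    param="model",
    code=400,
)
print(json.dumps({"error": exc.to_dict()}))
# -> {"error": {"message": "model not found", "type": "invalid_request_error", "param": "model", "code": 400}}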
@@ -2241,12 +2250,14 @@ async def async_data_generator(response, user_api_key_dict):
         error_traceback = traceback.format_exc()
         error_msg = f"{str(e)}\n\n{error_traceback}"
 
-        raise ProxyException(
+        proxy_exception = ProxyException(
             message=getattr(e, "message", error_msg),
             type=getattr(e, "type", "None"),
             param=getattr(e, "param", "None"),
             code=getattr(e, "status_code", 500),
         )
+        error_returned = json.dumps({"error": proxy_exception.to_dict()})
+        yield f"data: {error_returned}\n\n"
 
 
 def select_data_generator(response, user_api_key_dict):
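With this change, async_data_generator no longer raises out of the generator mid-stream; it serializes the failure and yields it as one final SSE data: event, so a streaming client receives a structured error instead of an abruptly closed connection. A minimal sketch of that yield-the-error pattern, with an invented failing generator and an error shape mirroring to_dict():

import asyncio
import json


async def data_generator():
    # Illustrative stream that fails partway through.
    yield 'data: {"choices": [{"delta": {"content": "Hel"}}]}\n\n'
    raise ValueError("upstream provider error")


async def safe_data_generator():
    try:
        async for chunk in data_generator():
            yield chunk
    except Exception as e:
        error_body = {
            "message": getattr(e, "message", str(e)),
            "type": getattr(e, "type", "None"),
            "param": getattr(e, "param", "None"),
            "code": getattr(e, "status_code", 500),
        }
        # Emit the error as a final SSE event instead of raising mid-stream.
        yield f"data: {json.dumps({'error': error_body})}\n\n"


async def main():
    async for line in safe_data_generator():
        print(line, end="")


asyncio.run(main())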
@@ -8117,7 +8117,8 @@ class CustomStreamWrapper:
                     text = ""  # don't return the final bos token
                     is_finished = True
                     finish_reason = "stop"
-
+                elif data_json.get("error", False):
+                    raise Exception(data_json.get("error"))
                 return {
                     "text": text,
                     "is_finished": is_finished,
@@ -8132,7 +8133,7 @@ class CustomStreamWrapper:
                 }
             except Exception as e:
                 traceback.print_exc()
-                # raise(e)
+                raise e
 
     def handle_ai21_chunk(self, chunk):  # fake streaming
         chunk = chunk.decode("utf-8")
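Together, the two CustomStreamWrapper hunks make HuggingFace error payloads fatal: a chunk carrying an "error" field now raises instead of parsing as empty text, and the restored raise e lets that exception propagate out of the chunk handler instead of only being printed. A rough, simplified sketch of the parsing logic these hunks touch (the real handler tracks more state; the sample payloads are invented):

import json


def handle_huggingface_chunk(chunk: str) -> dict:
    # Simplified take on the text-generation-inference SSE chunk format.
    text, is_finished, finish_reason = "", False, ""
    if chunk.startswith("data:"):
        data_json = json.loads(chunk[5:])
        if "token" in data_json and "text" in data_json["token"]:
            text = data_json["token"]["text"]
        if data_json.get("generated_text", False):
            text, is_finished, finish_reason = "", True, "stop"
        elif data_json.get("error", False):
            # Surface upstream errors instead of returning an empty chunk.
            raise Exception(data_json.get("error"))
    return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason}


print(handle_huggingface_chunk('data: {"token": {"text": "Hi"}}'))
try:
    handle_huggingface_chunk('data: {"error": "Model is overloaded"}')
except Exception as e:
    print("raised:", e)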