fix(huggingface_restapi.py): fix huggingface streaming error raising

Krrish Dholakia 2024-03-04 09:32:27 -08:00
parent 766e8cba84
commit 873ddde924
3 changed files with 65 additions and 8 deletions

@@ -634,8 +634,53 @@ class Huggingface(BaseLLM):
                         status_code=r.status_code,
                         message=str(text),
                     )
+                """
+                Check first chunk for error message.
+                If error message, raise error.
+                If not - add back to stream
+                """
+                # Async iterator over the lines in the response body
+                response_iterator = r.aiter_lines()
+
+                # Attempt to get the first line/chunk from the response
+                try:
+                    first_chunk = await response_iterator.__anext__()
+                except StopAsyncIteration:
+                    # Handle the case where there are no lines to read (empty response)
+                    first_chunk = ""
+
+                # Check the first chunk for an error message
+                if (
+                    "error" in first_chunk.lower()
+                ):  # Adjust this condition based on how error messages are structured
+                    raise HuggingfaceError(
+                        status_code=400,
+                        message=first_chunk,
+                    )
+
+                return self.async_streaming_generator(
+                    first_chunk=first_chunk,
+                    response_iterator=response_iterator,
+                    model=model,
+                    logging_obj=logging_obj,
+                )
+
+    async def async_streaming_generator(
+        self, first_chunk, response_iterator, model, logging_obj
+    ):
+        # Create a new async generator that begins with the first_chunk and includes the remaining items
+        async def custom_stream_with_first_chunk():
+            yield first_chunk  # Yield back the first chunk
+            async for (
+                chunk
+            ) in response_iterator:  # Continue yielding the rest of the chunks
+                yield chunk
+
+        # Creating a new completion stream that starts with the first chunk
+        completion_stream = custom_stream_with_first_chunk()
+
         streamwrapper = CustomStreamWrapper(
-            completion_stream=r.aiter_lines(),
+            completion_stream=completion_stream,
             model=model,
             custom_llm_provider="huggingface",
             logging_obj=logging_obj,
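The pattern in this hunk is: peek at the first line of the Hugging Face streaming response, raise immediately if it carries an error payload, and otherwise stitch the peeked line back onto the front of the iterator so no chunk is lost. A minimal, self-contained sketch of that peek-and-reattach idea (the `fake_response_lines` source, the error payload, and the use of `RuntimeError` are illustrative assumptions, not part of this diff):

```python
import asyncio


async def fake_response_lines():
    # Stand-in for r.aiter_lines(); the payloads are made up for illustration.
    yield '{"error": "Model meta-llama/Llama-2-7b is currently loading"}'
    yield 'data: {"token": {"text": "hello"}}'


async def stream_with_error_check(lines):
    # Peek at the first line; an empty stream simply yields nothing.
    try:
        first_chunk = await lines.__anext__()
    except StopAsyncIteration:
        return
    # Fail fast if the provider returned an error instead of tokens.
    if "error" in first_chunk.lower():
        raise RuntimeError(first_chunk)
    # Otherwise, put the peeked line back in front of the remaining stream.
    yield first_chunk
    async for line in lines:
        yield line


async def main():
    try:
        async for line in stream_with_error_check(fake_response_lines()):
            print(line)
    except RuntimeError as e:
        print(f"stream failed early: {e}")


asyncio.run(main())
```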

@@ -167,6 +167,15 @@ class ProxyException(Exception):
         self.param = param
         self.code = code
 
+    def to_dict(self) -> dict:
+        """Converts the ProxyException instance to a dictionary."""
+        return {
+            "message": self.message,
+            "type": self.type,
+            "param": self.param,
+            "code": self.code,
+        }
+
 
 @app.exception_handler(ProxyException)
 async def openai_exception_handler(request: Request, exc: ProxyException):
@@ -2241,12 +2250,14 @@ async def async_data_generator(response, user_api_key_dict):
         error_traceback = traceback.format_exc()
         error_msg = f"{str(e)}\n\n{error_traceback}"
-        raise ProxyException(
+        proxy_exception = ProxyException(
             message=getattr(e, "message", error_msg),
             type=getattr(e, "type", "None"),
             param=getattr(e, "param", "None"),
             code=getattr(e, "status_code", 500),
         )
+        error_returned = json.dumps({"error": proxy_exception.to_dict()})
+        yield f"data: {error_returned}\n\n"
 
 
 def select_data_generator(response, user_api_key_dict):
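Because `async_data_generator` feeds a streaming HTTP response, raising once streaming has started would just sever the connection; the change instead serializes the exception with `to_dict()` and emits it as a final SSE `data:` event. A simplified sketch of that error path (the `failing_stream` source and the flattened error dict are illustrative assumptions; the real code builds a `ProxyException` and omits `user_api_key_dict` handling here):

```python
import asyncio
import json


async def failing_stream():
    # Stand-in for the proxied model response; contents are made up.
    yield {"choices": [{"delta": {"content": "Hel"}}]}
    raise RuntimeError("upstream provider disconnected")


async def async_data_generator(response):
    try:
        async for chunk in response:
            yield f"data: {json.dumps(chunk)}\n\n"
    except Exception as e:
        # Surface the failure to the client as one last SSE event instead of
        # raising after the streaming response has already started.
        error_returned = json.dumps(
            {
                "error": {
                    "message": str(e),
                    "type": "None",
                    "param": "None",
                    "code": 500,
                }
            }
        )
        yield f"data: {error_returned}\n\n"


async def main():
    async for line in async_data_generator(failing_stream()):
        print(line, end="")


asyncio.run(main())
```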

@@ -8117,7 +8117,8 @@ class CustomStreamWrapper:
                     text = ""  # don't return the final bos token
                     is_finished = True
                     finish_reason = "stop"
+                elif data_json.get("error", False):
+                    raise Exception(data_json.get("error"))
                 return {
                     "text": text,
                     "is_finished": is_finished,
@@ -8132,7 +8133,7 @@
                 }
         except Exception as e:
             traceback.print_exc()
-            # raise(e)
+            raise e
 
     def handle_ai21_chunk(self, chunk):  # fake streaming
         chunk = chunk.decode("utf-8")
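With the parser now raising when a chunk carries an `error` field, provider-side failures propagate to whatever is iterating the wrapped stream instead of being swallowed. A self-contained sketch of that behaviour, deliberately simplified relative to `CustomStreamWrapper.handle_huggingface_chunk` (the sample chunks and the token/`generated_text` handling are illustrative assumptions):

```python
import json


def handle_huggingface_chunk(chunk: str) -> dict:
    # Hugging Face TGI streams server-sent events of the form "data: {...}".
    if not chunk.startswith("data:"):
        return {"text": "", "is_finished": False, "finish_reason": ""}
    data_json = json.loads(chunk[len("data:"):])
    if data_json.get("error", False):
        # Propagate provider-side errors instead of silently yielding nothing.
        raise Exception(data_json["error"])
    text = data_json.get("token", {}).get("text", "")
    finished = bool(data_json.get("generated_text"))
    return {
        "text": text,
        "is_finished": finished,
        "finish_reason": "stop" if finished else "",
    }


# A normal token chunk parses cleanly ...
print(handle_huggingface_chunk('data: {"token": {"text": "Hi"}}'))
# ... while an error chunk now raises instead of being dropped.
try:
    handle_huggingface_chunk('data: {"error": "Input validation error"}')
except Exception as e:
    print(f"raised: {e}")
```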