forked from phoenix/litellm-mirror
fix(huggingface_restapi.py): fix huggingface streaming error raising
This commit is contained in:
parent 766e8cba84
commit 873ddde924
3 changed files with 65 additions and 8 deletions
@@ -634,15 +634,60 @@ class Huggingface(BaseLLM):
                         status_code=r.status_code,
                         message=str(text),
                     )
-                streamwrapper = CustomStreamWrapper(
-                    completion_stream=r.aiter_lines(),
+                """
+                Check first chunk for error message.
+                If error message, raise error.
+                If not - add back to stream
+                """
+                # Async iterator over the lines in the response body
+                response_iterator = r.aiter_lines()
+
+                # Attempt to get the first line/chunk from the response
+                try:
+                    first_chunk = await response_iterator.__anext__()
+                except StopAsyncIteration:
+                    # Handle the case where there are no lines to read (empty response)
+                    first_chunk = ""
+
+                # Check the first chunk for an error message
+                if (
+                    "error" in first_chunk.lower()
+                ):  # Adjust this condition based on how error messages are structured
+                    raise HuggingfaceError(
+                        status_code=400,
+                        message=first_chunk,
+                    )
+
+                return self.async_streaming_generator(
+                    first_chunk=first_chunk,
+                    response_iterator=response_iterator,
                     model=model,
-                    custom_llm_provider="huggingface",
                     logging_obj=logging_obj,
                 )
-                async for transformed_chunk in streamwrapper:
-                    yield transformed_chunk
+
+    async def async_streaming_generator(
+        self, first_chunk, response_iterator, model, logging_obj
+    ):
+        # Create a new async generator that begins with the first_chunk and includes the remaining items
+        async def custom_stream_with_first_chunk():
+            yield first_chunk  # Yield back the first chunk
+            async for (
+                chunk
+            ) in response_iterator:  # Continue yielding the rest of the chunks
+                yield chunk
+
+        # Creating a new completion stream that starts with the first chunk
+        completion_stream = custom_stream_with_first_chunk()
+
+        streamwrapper = CustomStreamWrapper(
+            completion_stream=completion_stream,
+            model=model,
+            custom_llm_provider="huggingface",
+            logging_obj=logging_obj,
+        )
+
+        async for transformed_chunk in streamwrapper:
+            yield transformed_chunk
 
     def embedding(
         self,
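For context, the hunk above replaces the old "wrap r.aiter_lines() directly" approach with a peek-then-rechain pattern: read the first streamed line, raise immediately if it carries an error, and otherwise put that line back in front of the remaining ones. Below is a minimal standalone sketch of that pattern, not the litellm code itself: fake_hf_stream, stream_with_first_chunk_check, and the plain RuntimeError are hypothetical stand-ins for the httpx response, the method body, and HuggingfaceError.

import asyncio


async def fake_hf_stream():
    # Stand-in for r.aiter_lines(): yields raw SSE-style lines from the server.
    yield 'data: {"token": {"text": "Hello"}}'
    yield 'data: {"token": {"text": " world"}}'


async def stream_with_first_chunk_check(response_iterator):
    # Peek at the first line; an empty body simply yields an empty string.
    try:
        first_chunk = await response_iterator.__anext__()
    except StopAsyncIteration:
        first_chunk = ""

    # Fail fast if the server answered with an error payload instead of tokens.
    if "error" in first_chunk.lower():
        raise RuntimeError(first_chunk)  # the real code raises HuggingfaceError here

    async def chained():
        yield first_chunk  # put the peeked line back at the front
        async for chunk in response_iterator:
            yield chunk

    return chained()


async def main():
    stream = await stream_with_first_chunk_check(fake_hf_stream())
    async for line in stream:
        print(line)


asyncio.run(main())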
@@ -167,6 +167,15 @@ class ProxyException(Exception):
         self.param = param
         self.code = code
 
+    def to_dict(self) -> dict:
+        """Converts the ProxyException instance to a dictionary."""
+        return {
+            "message": self.message,
+            "type": self.type,
+            "param": self.param,
+            "code": self.code,
+        }
+
 
 @app.exception_handler(ProxyException)
 async def openai_exception_handler(request: Request, exc: ProxyException):
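A self-contained sketch of how the new helper is meant to be used, under the assumption that the only goal is a JSON-serializable error body; FakeProxyException is a stand-in defined here for illustration, not the proxy's real class.

import json


class FakeProxyException(Exception):
    # Mirrors the fields the real ProxyException carries.
    def __init__(self, message, type, param, code):
        self.message = message
        self.type = type
        self.param = param
        self.code = code

    def to_dict(self) -> dict:
        # Same shape as the helper added in the hunk above.
        return {
            "message": self.message,
            "type": self.type,
            "param": self.param,
            "code": self.code,
        }


exc = FakeProxyException("upstream timed out", "None", "None", 500)
print(json.dumps({"error": exc.to_dict()}))
# {"error": {"message": "upstream timed out", "type": "None", "param": "None", "code": 500}}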
@@ -2241,12 +2250,14 @@ async def async_data_generator(response, user_api_key_dict):
         error_traceback = traceback.format_exc()
         error_msg = f"{str(e)}\n\n{error_traceback}"
 
-        raise ProxyException(
+        proxy_exception = ProxyException(
             message=getattr(e, "message", error_msg),
             type=getattr(e, "type", "None"),
             param=getattr(e, "param", "None"),
             code=getattr(e, "status_code", 500),
         )
+        error_returned = json.dumps({"error": proxy_exception.to_dict()})
+        yield f"data: {error_returned}\n\n"
 
 
 def select_data_generator(response, user_api_key_dict):
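The practical effect of this hunk is that a streaming client no longer sees the generator raise mid-response; the failure arrives as an ordinary SSE data: line carrying the serialized exception. A rough sketch of parsing such a line on the client side follows; the payload shape comes from to_dict() above, while the concrete values and the sse_line variable are made up for illustration.

import json

# Example of the kind of line the proxy now yields when streaming fails.
sse_line = 'data: {"error": {"message": "boom", "type": "None", "param": "None", "code": 500}}'

payload = json.loads(sse_line[len("data: "):])
if "error" in payload:
    err = payload["error"]
    print(f"stream failed with code {err['code']}: {err['message']}")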
@@ -8117,7 +8117,8 @@ class CustomStreamWrapper:
                     text = ""  # don't return the final bos token
                     is_finished = True
                     finish_reason = "stop"
+                elif data_json.get("error", False):
+                    raise Exception(data_json.get("error"))
                 return {
                     "text": text,
                     "is_finished": is_finished,
@@ -8132,7 +8133,7 @@ class CustomStreamWrapper:
             }
         except Exception as e:
             traceback.print_exc()
-            # raise(e)
+            raise e
 
     def handle_ai21_chunk(self, chunk):  # fake streaming
         chunk = chunk.decode("utf-8")
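Taken together, the last two hunks make CustomStreamWrapper surface Hugging Face error chunks instead of swallowing them: a streamed JSON body with an "error" key now raises, and the exception is re-raised rather than only printed. A stripped-down sketch of that branch is below; it assumes an error payload shaped like {"error": "..."} and handle_chunk is a simplified stand-in, not the real handle_huggingface_chunk.

import json


def handle_chunk(chunk: str) -> dict:
    text, is_finished, finish_reason = "", False, ""
    if chunk.startswith("data:"):
        data_json = json.loads(chunk[len("data:"):])
        if "token" in data_json and "text" in data_json["token"]:
            text = data_json["token"]["text"]
        elif data_json.get("error", False):
            # Previously this fell through and produced an empty chunk; now it raises.
            raise Exception(data_json.get("error"))
    return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason}


print(handle_chunk('data: {"token": {"text": "hi"}}'))
# handle_chunk('data: {"error": "request failed"}') raises Exception instead of returning.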