fix(huggingface_restapi.py): fix hf streaming to raise exceptions

Krrish Dholakia 2024-02-15 21:25:12 -08:00
parent 34fce00960
commit 1b844aafdc
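What changed: on a non-200 response from the Hugging Face streaming endpoint, async_streaming now reads the error body with await r.aread() and raises HuggingfaceError with that text instead of the generic "An error occurred while streaming" message. The chunk loop also moves into an inner generator() that the method returns, so the HTTP call and status check run when async_streaming is awaited and the error propagates to the caller instead of being deferred until the stream is iterated. The remaining hunks are formatting-only (parenthesized assignments, trailing commas, a joined client.post call).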


@@ -49,9 +49,9 @@ class HuggingfaceConfig:
     details: Optional[bool] = True  # enables returning logprobs + best of
     max_new_tokens: Optional[int] = None
     repetition_penalty: Optional[float] = None
-    return_full_text: Optional[
-        bool
-    ] = False  # by default don't return the input as part of the output
+    return_full_text: Optional[bool] = (
+        False  # by default don't return the input as part of the output
+    )
     seed: Optional[int] = None
     temperature: Optional[float] = None
     top_k: Optional[int] = None
@@ -188,9 +188,9 @@ class Huggingface(BaseLLM):
             "content-type": "application/json",
         }
         if api_key and headers is None:
-            default_headers[
-                "Authorization"
-            ] = f"Bearer {api_key}"  # Huggingface Inference Endpoint default is to accept bearer tokens
+            default_headers["Authorization"] = (
+                f"Bearer {api_key}"  # Huggingface Inference Endpoint default is to accept bearer tokens
+            )
             headers = default_headers
         elif headers:
             headers = headers
@@ -399,9 +399,12 @@ class Huggingface(BaseLLM):
             data = {
                 "inputs": prompt,
                 "parameters": optional_params,
-                "stream": True
-                if "stream" in optional_params and optional_params["stream"] == True
-                else False,
+                "stream": (
+                    True
+                    if "stream" in optional_params
+                    and optional_params["stream"] == True
+                    else False
+                ),
             }
             input_text = prompt
         else:
@@ -430,9 +433,12 @@ class Huggingface(BaseLLM):
             data = {
                 "inputs": prompt,
                 "parameters": inference_params,
-                "stream": True
-                if "stream" in optional_params and optional_params["stream"] == True
-                else False,
+                "stream": (
+                    True
+                    if "stream" in optional_params
+                    and optional_params["stream"] == True
+                    else False
+                ),
             }
             input_text = prompt
             ## LOGGING
@@ -561,14 +567,12 @@ class Huggingface(BaseLLM):
         input_text: str,
         model: str,
         optional_params: dict,
-        timeout: float
+        timeout: float,
     ):
         response = None
         try:
             async with httpx.AsyncClient(timeout=timeout) as client:
-                response = await client.post(
-                    url=api_base, json=data, headers=headers
-                )
+                response = await client.post(url=api_base, json=data, headers=headers)
                 response_json = response.json()
                 if response.status_code != 200:
                     raise HuggingfaceError(
@@ -607,7 +611,7 @@ class Huggingface(BaseLLM):
         headers: dict,
         model_response: ModelResponse,
         model: str,
-        timeout: float
+        timeout: float,
     ):
         async with httpx.AsyncClient(timeout=timeout) as client:
             response = client.stream(
@@ -615,9 +619,10 @@ class Huggingface(BaseLLM):
             )
             async with response as r:
                 if r.status_code != 200:
+                    text = await r.aread()
                     raise HuggingfaceError(
                         status_code=r.status_code,
-                        message="An error occurred while streaming",
+                        message=str(text),
                     )
                 streamwrapper = CustomStreamWrapper(
                     completion_stream=r.aiter_lines(),
@@ -625,8 +630,12 @@ class Huggingface(BaseLLM):
                     custom_llm_provider="huggingface",
                     logging_obj=logging_obj,
                 )
-                async for transformed_chunk in streamwrapper:
-                    yield transformed_chunk
+                async def generator():
+                    async for transformed_chunk in streamwrapper:
+                        yield transformed_chunk
+
+                return generator()

     def embedding(
         self,
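Note: below is a minimal, self-contained sketch of the same pattern, open the stream, raise immediately with the server's error body on a non-200 status, and return an inner generator for the chunks. It is illustrative only: StreamingError and stream_completion are hypothetical stand-ins rather than LiteLLM's actual classes, and it uses contextlib.AsyncExitStack to keep the httpx client and response open for the generator's lifetime, which the commit itself handles differently.

import httpx
from contextlib import AsyncExitStack


class StreamingError(Exception):
    # hypothetical stand-in for HuggingfaceError
    def __init__(self, status_code: int, message: str):
        self.status_code = status_code
        super().__init__(message)


async def stream_completion(url: str, payload: dict, headers: dict, timeout: float):
    # open the client and the streaming response eagerly so errors surface now
    stack = AsyncExitStack()
    client = await stack.enter_async_context(httpx.AsyncClient(timeout=timeout))
    response = await stack.enter_async_context(
        client.stream("POST", url, json=payload, headers=headers)
    )
    if response.status_code != 200:
        # read the full error body so the exception carries the server's message
        text = await response.aread()
        await stack.aclose()
        raise StreamingError(status_code=response.status_code, message=str(text))

    async def generator():
        # only the happy path is deferred; iteration yields the raw response lines
        try:
            async for line in response.aiter_lines():
                yield line
        finally:
            await stack.aclose()

    # no top-level yield in stream_completion, so awaiting it performs the
    # request and the status check right away instead of deferring them
    return generator()

Awaiting stream_completion(...) therefore raises StreamingError on a 4xx/5xx before any chunk is consumed; only the returned generator is iterated for output.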