fix(huggingface_restapi.py): fix hf streaming to raise exceptions

Krrish Dholakia 2024-02-15 21:25:12 -08:00
parent 34fce00960
commit 1b844aafdc
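Most of the hunks below are formatting-only (assignments and ternaries wrapped in parentheses, trailing commas added). The behavioral change is confined to async_streaming: a failed streaming request now has its response body read and raised as a HuggingfaceError instead of the generic "An error occurred while streaming" message, and the stream is handed back from an inner async generator rather than yielded directly. A rough caller-level sketch of what this enables, assuming litellm's async streaming interface (the model id and prompt are placeholders, not part of this commit):

import asyncio

import litellm


async def main():
    try:
        stream = await litellm.acompletion(
            model="huggingface/HuggingFaceH4/zephyr-7b-beta",  # placeholder HF model id
            messages=[{"role": "user", "content": "Hello"}],
            stream=True,
        )
        async for chunk in stream:
            print(chunk)
    except Exception as err:
        # With this fix, a failing Hugging Face streaming call surfaces the upstream
        # error body here instead of a generic streaming-error message.
        print(f"HF streaming failed: {err}")


asyncio.run(main())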


@@ -49,9 +49,9 @@ class HuggingfaceConfig:
     details: Optional[bool] = True  # enables returning logprobs + best of
     max_new_tokens: Optional[int] = None
     repetition_penalty: Optional[float] = None
-    return_full_text: Optional[
-        bool
-    ] = False  # by default don't return the input as part of the output
+    return_full_text: Optional[bool] = (
+        False  # by default don't return the input as part of the output
+    )
     seed: Optional[int] = None
     temperature: Optional[float] = None
     top_k: Optional[int] = None
@@ -188,9 +188,9 @@ class Huggingface(BaseLLM):
             "content-type": "application/json",
         }
         if api_key and headers is None:
-            default_headers[
-                "Authorization"
-            ] = f"Bearer {api_key}"  # Huggingface Inference Endpoint default is to accept bearer tokens
+            default_headers["Authorization"] = (
+                f"Bearer {api_key}"  # Huggingface Inference Endpoint default is to accept bearer tokens
+            )
             headers = default_headers
         elif headers:
             headers = headers
@@ -399,9 +399,12 @@ class Huggingface(BaseLLM):
             data = {
                 "inputs": prompt,
                 "parameters": optional_params,
-                "stream": True
-                if "stream" in optional_params and optional_params["stream"] == True
-                else False,
+                "stream": (
+                    True
+                    if "stream" in optional_params
+                    and optional_params["stream"] == True
+                    else False
+                ),
             }
             input_text = prompt
         else:
@@ -430,9 +433,12 @@ class Huggingface(BaseLLM):
             data = {
                 "inputs": prompt,
                 "parameters": inference_params,
-                "stream": True
-                if "stream" in optional_params and optional_params["stream"] == True
-                else False,
+                "stream": (
+                    True
+                    if "stream" in optional_params
+                    and optional_params["stream"] == True
+                    else False
+                ),
             }
             input_text = prompt
         ## LOGGING
@@ -561,14 +567,12 @@ class Huggingface(BaseLLM):
         input_text: str,
         model: str,
         optional_params: dict,
-        timeout: float
+        timeout: float,
     ):
         response = None
         try:
             async with httpx.AsyncClient(timeout=timeout) as client:
-                response = await client.post(
-                    url=api_base, json=data, headers=headers
-                )
+                response = await client.post(url=api_base, json=data, headers=headers)
                 response_json = response.json()
                 if response.status_code != 200:
                     raise HuggingfaceError(
@@ -607,7 +611,7 @@ class Huggingface(BaseLLM):
         headers: dict,
         model_response: ModelResponse,
         model: str,
-        timeout: float
+        timeout: float,
     ):
         async with httpx.AsyncClient(timeout=timeout) as client:
             response = client.stream(
@@ -615,9 +619,10 @@ class Huggingface(BaseLLM):
             )
             async with response as r:
                 if r.status_code != 200:
+                    text = await r.aread()
                     raise HuggingfaceError(
                         status_code=r.status_code,
-                        message="An error occurred while streaming",
+                        message=str(text),
                     )
                 streamwrapper = CustomStreamWrapper(
                     completion_stream=r.aiter_lines(),
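The hunk above is the core of the fix: because httpx's streaming API does not load the body eagerly, the error payload has to be read explicitly with aread() before it can be attached to the raised exception. A minimal standalone sketch of that pattern, with a placeholder URL and error type (not litellm code):

import asyncio

import httpx


async def stream_or_raise(url: str, payload: dict):
    async with httpx.AsyncClient(timeout=30.0) as client:
        async with client.stream("POST", url, json=payload) as r:
            if r.status_code != 200:
                # Read the full body of the failed streaming response so the
                # exception carries the server's actual error message.
                text = await r.aread()
                raise RuntimeError(f"status {r.status_code}: {text!r}")
            async for line in r.aiter_lines():
                print(line)


# asyncio.run(stream_or_raise("https://example.com/generate", {"inputs": "hi"}))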
@@ -625,9 +630,13 @@ class Huggingface(BaseLLM):
                     custom_llm_provider="huggingface",
                     logging_obj=logging_obj,
                 )
-                async for transformed_chunk in streamwrapper:
-                    yield transformed_chunk
+
+                async def generator():
+                    async for transformed_chunk in streamwrapper:
+                        yield transformed_chunk
+
+                return generator()
 
     def embedding(
         self,
         model: str,
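The final hunk changes async_streaming from an async generator function into a coroutine that returns an inner async generator. The practical difference is when errors surface: an async generator function runs none of its body until the first chunk is requested, whereas with the inner-generator shape the request and status check run as soon as the call is awaited, so a bad response raises before the caller starts iterating. A generic sketch of the two shapes (illustrative names, not litellm code):

import asyncio


async def old_shape(fail: bool):
    # Async generator function: nothing here executes until the caller iterates,
    # so this error only appears when the first chunk is requested.
    if fail:
        raise RuntimeError("raised on first iteration")
    for i in range(3):
        yield i


async def new_shape(fail: bool):
    # Coroutine returning an inner async generator: the check runs when the call
    # is awaited, before any iteration happens.
    if fail:
        raise RuntimeError("raised as soon as the call is awaited")

    async def generator():
        for i in range(3):
            yield i

    return generator()


async def main():
    try:
        stream = await new_shape(fail=True)
        async for chunk in stream:
            print(chunk)
    except RuntimeError as err:
        print(f"caught before consuming any chunks: {err}")


asyncio.run(main())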