forked from phoenix/litellm-mirror
fix(huggingface_restapi.py): fix hf streaming to raise exceptions
parent 34fce00960
commit 1b844aafdc
1 changed file with 29 additions and 20 deletions
huggingface_restapi.py
@@ -49,9 +49,9 @@ class HuggingfaceConfig:
     details: Optional[bool] = True  # enables returning logprobs + best of
     max_new_tokens: Optional[int] = None
     repetition_penalty: Optional[float] = None
-    return_full_text: Optional[
-        bool
-    ] = False  # by default don't return the input as part of the output
+    return_full_text: Optional[bool] = (
+        False  # by default don't return the input as part of the output
+    )
     seed: Optional[int] = None
     temperature: Optional[float] = None
     top_k: Optional[int] = None
@@ -188,9 +188,9 @@ class Huggingface(BaseLLM):
             "content-type": "application/json",
         }
         if api_key and headers is None:
-            default_headers[
-                "Authorization"
-            ] = f"Bearer {api_key}"  # Huggingface Inference Endpoint default is to accept bearer tokens
+            default_headers["Authorization"] = (
+                f"Bearer {api_key}"  # Huggingface Inference Endpoint default is to accept bearer tokens
+            )
             headers = default_headers
         elif headers:
             headers = headers
@@ -399,9 +399,12 @@ class Huggingface(BaseLLM):
             data = {
                 "inputs": prompt,
                 "parameters": optional_params,
-                "stream": True
-                if "stream" in optional_params and optional_params["stream"] == True
-                else False,
+                "stream": (
+                    True
+                    if "stream" in optional_params
+                    and optional_params["stream"] == True
+                    else False
+                ),
             }
             input_text = prompt
         else:
@@ -430,9 +433,12 @@ class Huggingface(BaseLLM):
             data = {
                 "inputs": prompt,
                 "parameters": inference_params,
-                "stream": True
-                if "stream" in optional_params and optional_params["stream"] == True
-                else False,
+                "stream": (
+                    True
+                    if "stream" in optional_params
+                    and optional_params["stream"] == True
+                    else False
+                ),
             }
             input_text = prompt
         ## LOGGING
@@ -561,14 +567,12 @@ class Huggingface(BaseLLM):
         input_text: str,
         model: str,
         optional_params: dict,
-        timeout: float
+        timeout: float,
     ):
         response = None
         try:
             async with httpx.AsyncClient(timeout=timeout) as client:
-                response = await client.post(
-                    url=api_base, json=data, headers=headers
-                )
+                response = await client.post(url=api_base, json=data, headers=headers)
                 response_json = response.json()
                 if response.status_code != 200:
                     raise HuggingfaceError(
@@ -607,7 +611,7 @@ class Huggingface(BaseLLM):
         headers: dict,
         model_response: ModelResponse,
         model: str,
-        timeout: float
+        timeout: float,
    ):
         async with httpx.AsyncClient(timeout=timeout) as client:
             response = client.stream(
@@ -615,9 +619,10 @@ class Huggingface(BaseLLM):
             )
             async with response as r:
                 if r.status_code != 200:
+                    text = await r.aread()
                     raise HuggingfaceError(
                         status_code=r.status_code,
-                        message="An error occurred while streaming",
+                        message=str(text),
                     )
                 streamwrapper = CustomStreamWrapper(
                     completion_stream=r.aiter_lines(),
@@ -625,8 +630,12 @@ class Huggingface(BaseLLM):
                     custom_llm_provider="huggingface",
                     logging_obj=logging_obj,
                 )
-                async for transformed_chunk in streamwrapper:
-                    yield transformed_chunk
+
+                async def generator():
+                    async for transformed_chunk in streamwrapper:
+                        yield transformed_chunk
+
+                return generator()

     def embedding(
         self,
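For context, the behavior this commit targets can be reproduced outside litellm: when a Hugging Face text-generation endpoint rejects a streaming request, the error details only surface if the response body is read before raising. The sketch below shows that pattern with httpx alone; the endpoint URL, payload, and the generic RuntimeError standing in for HuggingfaceError are illustrative assumptions, not litellm's actual wiring.

# Minimal sketch of the raise-on-error streaming pattern, assuming a
# placeholder endpoint and payload; RuntimeError stands in for litellm's
# HuggingfaceError exception class.
import asyncio

import httpx


async def stream_generation(api_base: str, data: dict, headers: dict, timeout: float):
    async with httpx.AsyncClient(timeout=timeout) as client:
        async with client.stream("POST", api_base, json=data, headers=headers) as r:
            if r.status_code != 200:
                # Read the error body first so the raised exception carries the
                # server's message instead of a generic placeholder string.
                text = await r.aread()
                raise RuntimeError(f"status={r.status_code}, message={text!r}")
            async for line in r.aiter_lines():
                yield line


async def main():
    headers = {"content-type": "application/json"}
    data = {"inputs": "Hello", "parameters": {"max_new_tokens": 16}, "stream": True}
    async for chunk in stream_generation(
        "https://example.invalid/generate", data, headers, timeout=30.0
    ):
        print(chunk)


if __name__ == "__main__":
    asyncio.run(main())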