mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
fix(test_streaming.py): handle hf tgi zephyr not loading for streaming issue
This commit is contained in:
parent
31c8d62ac2
commit
eb45df16f1
3 changed files with 27 additions and 12 deletions
|
@ -399,9 +399,12 @@ class Huggingface(BaseLLM):
|
|||
data = {
|
||||
"inputs": prompt,
|
||||
"parameters": optional_params,
|
||||
"stream": True
|
||||
if "stream" in optional_params and optional_params["stream"] == True
|
||||
else False,
|
||||
"stream": (
|
||||
True
|
||||
if "stream" in optional_params
|
||||
and optional_params["stream"] == True
|
||||
else False
|
||||
),
|
||||
}
|
||||
input_text = prompt
|
||||
else:
|
||||
|
@ -430,9 +433,12 @@ class Huggingface(BaseLLM):
|
|||
data = {
|
||||
"inputs": prompt,
|
||||
"parameters": inference_params,
|
||||
"stream": True
|
||||
if "stream" in optional_params and optional_params["stream"] == True
|
||||
else False,
|
||||
"stream": (
|
||||
True
|
||||
if "stream" in optional_params
|
||||
and optional_params["stream"] == True
|
||||
else False
|
||||
),
|
||||
}
|
||||
input_text = prompt
|
||||
## LOGGING
|
||||
|
@ -561,14 +567,12 @@ class Huggingface(BaseLLM):
|
|||
input_text: str,
|
||||
model: str,
|
||||
optional_params: dict,
|
||||
timeout: float
|
||||
timeout: float,
|
||||
):
|
||||
response = None
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
response = await client.post(
|
||||
url=api_base, json=data, headers=headers
|
||||
)
|
||||
response = await client.post(url=api_base, json=data, headers=headers)
|
||||
response_json = response.json()
|
||||
if response.status_code != 200:
|
||||
raise HuggingfaceError(
|
||||
|
@ -607,7 +611,7 @@ class Huggingface(BaseLLM):
|
|||
headers: dict,
|
||||
model_response: ModelResponse,
|
||||
model: str,
|
||||
timeout: float
|
||||
timeout: float,
|
||||
):
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
response = client.stream(
|
||||
|
@ -615,9 +619,10 @@ class Huggingface(BaseLLM):
|
|||
)
|
||||
async with response as r:
|
||||
if r.status_code != 200:
|
||||
text = await r.aread()
|
||||
raise HuggingfaceError(
|
||||
status_code=r.status_code,
|
||||
message="An error occurred while streaming",
|
||||
message=str(text),
|
||||
)
|
||||
streamwrapper = CustomStreamWrapper(
|
||||
completion_stream=r.aiter_lines(),
|
||||
|
|
|
@ -1031,6 +1031,8 @@ async def test_hf_completion_tgi_stream():
|
|||
if complete_response.strip() == "":
|
||||
raise Exception("Empty response received")
|
||||
print(f"completion_response: {complete_response}")
|
||||
except litellm.ServiceUnavailableError as e:
|
||||
pass
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
|
|
@ -6975,6 +6975,14 @@ def exception_type(
|
|||
model=model,
|
||||
response=original_exception.response,
|
||||
)
|
||||
elif original_exception.status_code == 503:
|
||||
exception_mapping_worked = True
|
||||
raise ServiceUnavailableError(
|
||||
message=f"HuggingfaceException - {original_exception.message}",
|
||||
llm_provider="huggingface",
|
||||
model=model,
|
||||
response=original_exception.response,
|
||||
)
|
||||
else:
|
||||
exception_mapping_worked = True
|
||||
raise APIError(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue