Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-25 18:54:30 +00:00
fix(test_streaming.py): handle hf tgi zephyr not loading for streaming issue

commit eb45df16f1 (parent 31c8d62ac2)
3 changed files with 27 additions and 12 deletions
@@ -399,9 +399,12 @@ class Huggingface(BaseLLM):
             data = {
                 "inputs": prompt,
                 "parameters": optional_params,
-                "stream": True
-                if "stream" in optional_params and optional_params["stream"] == True
-                else False,
+                "stream": (
+                    True
+                    if "stream" in optional_params
+                    and optional_params["stream"] == True
+                    else False
+                ),
             }
             input_text = prompt
         else:
@@ -430,9 +433,12 @@ class Huggingface(BaseLLM):
             data = {
                 "inputs": prompt,
                 "parameters": inference_params,
-                "stream": True
-                if "stream" in optional_params and optional_params["stream"] == True
-                else False,
+                "stream": (
+                    True
+                    if "stream" in optional_params
+                    and optional_params["stream"] == True
+                    else False
+                ),
             }
             input_text = prompt
             ## LOGGING
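Both hunks above make the same change: black's parenthesized conditional replaces the hanging ternary, and the "stream" flag posted to the TGI endpoint is True only when the caller passed stream=True in optional_params. A minimal sketch of the resulting behavior (the sample optional_params below are illustrative, not taken from the repo):

# Illustrative sketch: how the "stream" flag is computed from optional_params.
optional_params = {"stream": True, "max_new_tokens": 100}  # hypothetical caller input

data = {
    "inputs": "Hello, world",
    "parameters": optional_params,
    "stream": (
        True
        if "stream" in optional_params and optional_params["stream"] == True
        else False
    ),
}
assert data["stream"] is True  # same result as optional_params.get("stream") == True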
@@ -561,14 +567,12 @@ class Huggingface(BaseLLM):
         input_text: str,
         model: str,
         optional_params: dict,
-        timeout: float
+        timeout: float,
     ):
         response = None
         try:
             async with httpx.AsyncClient(timeout=timeout) as client:
-                response = await client.post(
-                    url=api_base, json=data, headers=headers
-                )
+                response = await client.post(url=api_base, json=data, headers=headers)
                 response_json = response.json()
                 if response.status_code != 200:
                     raise HuggingfaceError(
@@ -607,7 +611,7 @@ class Huggingface(BaseLLM):
         headers: dict,
         model_response: ModelResponse,
         model: str,
-        timeout: float
+        timeout: float,
     ):
         async with httpx.AsyncClient(timeout=timeout) as client:
             response = client.stream(
@@ -615,9 +619,10 @@ class Huggingface(BaseLLM):
             )
             async with response as r:
                 if r.status_code != 200:
+                    text = await r.aread()
                     raise HuggingfaceError(
                         status_code=r.status_code,
-                        message="An error occurred while streaming",
+                        message=str(text),
                     )
                 streamwrapper = CustomStreamWrapper(
                     completion_stream=r.aiter_lines(),
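httpx keeps the body of a streamed response unread until you ask for it, so the old code could only raise a generic "An error occurred while streaming" message. Calling r.aread() first lets the HuggingfaceError carry the server's actual error text (e.g. a TGI "model is currently loading" payload). A rough sketch of the pattern, with a placeholder URL:

# Sketch only: read a streamed httpx error body before raising.
import asyncio
import httpx

async def stream_or_raise(url: str, payload: dict, headers: dict) -> None:
    async with httpx.AsyncClient(timeout=30.0) as client:
        async with client.stream("POST", url, json=payload, headers=headers) as r:
            if r.status_code != 200:
                text = await r.aread()  # body of a streamed response must be read explicitly
                raise RuntimeError(f"HTTP {r.status_code}: {str(text)}")
            async for line in r.aiter_lines():
                print(line)

# asyncio.run(stream_or_raise("https://example.com/generate", {"inputs": "hi"}, {}))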
@@ -1031,6 +1031,8 @@ async def test_hf_completion_tgi_stream():
         if complete_response.strip() == "":
             raise Exception("Empty response received")
         print(f"completion_response: {complete_response}")
+    except litellm.ServiceUnavailableError as e:
+        pass
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")

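The test now tolerates the hosted zephyr TGI endpoint being cold: a 503 surfaces as litellm.ServiceUnavailableError (see the exception mapping in the next hunk) and is swallowed instead of failing the run. A hedged sketch of that pattern; the model name and pytest-asyncio usage here are assumptions, not copied from the test file:

# Sketch of the tolerance pattern used by the test (names are illustrative).
import pytest
import litellm

@pytest.mark.asyncio
async def test_tgi_stream_tolerates_cold_endpoint():
    try:
        response = await litellm.acompletion(
            model="huggingface/HuggingFaceH4/zephyr-7b-beta",  # hypothetical TGI model
            messages=[{"role": "user", "content": "hi"}],
            stream=True,
        )
        async for chunk in response:
            print(chunk)
    except litellm.ServiceUnavailableError:
        pass  # endpoint not loaded yet -- acceptable, don't fail the test
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")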
@@ -6975,6 +6975,14 @@ def exception_type(
                         model=model,
                         response=original_exception.response,
                     )
+                elif original_exception.status_code == 503:
+                    exception_mapping_worked = True
+                    raise ServiceUnavailableError(
+                        message=f"HuggingfaceException - {original_exception.message}",
+                        llm_provider="huggingface",
+                        model=model,
+                        response=original_exception.response,
+                    )
                 else:
                     exception_mapping_worked = True
                     raise APIError(
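The new branch maps a Hugging Face 503 to litellm's ServiceUnavailableError, so callers (and the test above) can tell a cold or overloaded endpoint apart from a hard failure. A minimal usage sketch, assuming a caller that simply wants one retry after the model finishes loading:

# Illustrative sketch: retry once when Hugging Face answers 503 (model still loading).
import time
import litellm

def completion_with_one_retry(model: str, messages: list, wait_seconds: float = 20.0):
    try:
        return litellm.completion(model=model, messages=messages)
    except litellm.ServiceUnavailableError:
        # exception_type() now maps the provider's 503 to ServiceUnavailableError
        time.sleep(wait_seconds)
        return litellm.completion(model=model, messages=messages)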