fix(huggingface_restapi.py): support timeouts for huggingface + openai text completions
https://github.com/BerriAI/litellm/issues/1334
parent c720870f80
commit b1fd0a164b

5 changed files with 41 additions and 14 deletions
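
Intent of the change: a caller-supplied timeout is threaded through to the Hugging Face and OpenAI text-completion HTTP calls instead of being dropped. A minimal usage sketch, assuming litellm is installed; the model name mirrors the tests added below, while the timeout value and script wrapper are illustrative and not part of this commit:

import asyncio

import litellm
from litellm import acompletion


async def main():
    try:
        # Per-request timeout in seconds; illustrative value.
        response = await acompletion(
            model="huggingface/HuggingFaceH4/zephyr-7b-beta",
            messages=[{"content": "Hello, how are you?", "role": "user"}],
            timeout=5.0,
        )
        print(response)
    except litellm.Timeout:
        # Raised when the request exceeds the configured timeout.
        print("request timed out")


asyncio.run(main())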
@@ -318,6 +318,7 @@ class Huggingface(BaseLLM):
         headers: Optional[dict],
         model_response: ModelResponse,
         print_verbose: Callable,
+        timeout: float,
         encoding,
         api_key,
         logging_obj,
@@ -450,10 +451,10 @@ class Huggingface(BaseLLM):
         if acompletion is True:
             ### ASYNC STREAMING
             if optional_params.get("stream", False):
-                return self.async_streaming(logging_obj=logging_obj, api_base=completion_url, data=data, headers=headers, model_response=model_response, model=model) # type: ignore
+                return self.async_streaming(logging_obj=logging_obj, api_base=completion_url, data=data, headers=headers, model_response=model_response, model=model, timeout=timeout) # type: ignore
             else:
                 ### ASYNC COMPLETION
-                return self.acompletion(api_base=completion_url, data=data, headers=headers, model_response=model_response, task=task, encoding=encoding, input_text=input_text, model=model, optional_params=optional_params) # type: ignore
+                return self.acompletion(api_base=completion_url, data=data, headers=headers, model_response=model_response, task=task, encoding=encoding, input_text=input_text, model=model, optional_params=optional_params, timeout=timeout) # type: ignore
         ### SYNC STREAMING
         if "stream" in optional_params and optional_params["stream"] == True:
             response = requests.post(
@@ -560,12 +561,13 @@ class Huggingface(BaseLLM):
         input_text: str,
         model: str,
         optional_params: dict,
+        timeout: float
     ):
         response = None
         try:
-            async with httpx.AsyncClient() as client:
+            async with httpx.AsyncClient(timeout=timeout) as client:
                 response = await client.post(
-                    url=api_base, json=data, headers=headers, timeout=None
+                    url=api_base, json=data, headers=headers
                 )
                 response_json = response.json()
                 if response.status_code != 200:
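
Note on the hunk above: the timeout now lives on the httpx.AsyncClient itself, so every request made through that client inherits it and the per-request timeout=None is dropped. A minimal sketch of the same httpx pattern, with a placeholder endpoint and payload:

import asyncio

import httpx


async def post_with_timeout():
    # Client-level timeout applies to all requests made with this client,
    # so the individual post() call needs no timeout kwarg of its own.
    async with httpx.AsyncClient(timeout=10.0) as client:
        response = await client.post(
            "https://example.com/generate",  # placeholder endpoint
            json={"inputs": "Hello"},        # placeholder payload
        )
        return response.status_code


print(asyncio.run(post_with_timeout()))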
@@ -605,8 +607,9 @@ class Huggingface(BaseLLM):
         headers: dict,
         model_response: ModelResponse,
         model: str,
+        timeout: float
     ):
-        async with httpx.AsyncClient() as client:
+        async with httpx.AsyncClient(timeout=timeout) as client:
             response = client.stream(
                 "POST", url=f"{api_base}", json=data, headers=headers
             )
@@ -616,7 +619,6 @@ class Huggingface(BaseLLM):
                         status_code=r.status_code,
                         message="An error occurred while streaming",
                     )
-
                 streamwrapper = CustomStreamWrapper(
                     completion_stream=r.aiter_lines(),
                     model=model,
@@ -836,6 +836,7 @@ class OpenAITextCompletion(BaseLLM):
         api_key: str,
         model: str,
         messages: list,
+        timeout: float,
         print_verbose: Optional[Callable] = None,
         api_base: Optional[str] = None,
         logging_obj=None,
@@ -887,9 +888,10 @@ class OpenAITextCompletion(BaseLLM):
                     headers=headers,
                     model_response=model_response,
                     model=model,
+                    timeout=timeout
                 )
             else:
-                return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response, prompt=prompt, api_key=api_key, logging_obj=logging_obj, model=model) # type: ignore
+                return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response, prompt=prompt, api_key=api_key, logging_obj=logging_obj, model=model, timeout=timeout) # type: ignore
         elif optional_params.get("stream", False):
             return self.streaming(
                 logging_obj=logging_obj,
@@ -898,12 +900,14 @@ class OpenAITextCompletion(BaseLLM):
                 headers=headers,
                 model_response=model_response,
                 model=model,
+                timeout=timeout
             )
         else:
             response = httpx.post(
                 url=f"{api_base}",
                 json=data,
                 headers=headers,
+                timeout=timeout
             )
             if response.status_code != 200:
                 raise OpenAIError(
@@ -939,8 +943,9 @@ class OpenAITextCompletion(BaseLLM):
         prompt: str,
         api_key: str,
         model: str,
+        timeout: float
     ):
-        async with httpx.AsyncClient() as client:
+        async with httpx.AsyncClient(timeout=timeout) as client:
             try:
                 response = await client.post(
                     api_base,
@@ -980,13 +985,14 @@ class OpenAITextCompletion(BaseLLM):
         headers: dict,
         model_response: ModelResponse,
         model: str,
+        timeout: float
     ):
         with httpx.stream(
             url=f"{api_base}",
             json=data,
             headers=headers,
             method="POST",
-            timeout=litellm.request_timeout,
+            timeout=timeout,
         ) as response:
             if response.status_code != 200:
                 raise OpenAIError(
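
In the streaming paths, the hard-coded litellm.request_timeout is swapped for the caller-supplied value. httpx.stream accepts the same timeout argument; a small sketch of that pattern, with a placeholder endpoint and body:

import httpx

# Placeholder request; a float timeout bounds the connect, read, write,
# and pool phases of the streamed request.
with httpx.stream(
    "POST",
    "https://example.com/generate",  # placeholder endpoint
    json={"inputs": "Hello", "stream": True},  # placeholder body
    timeout=10.0,
) as response:
    for line in response.iter_lines():
        print(line)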
@@ -1010,6 +1016,7 @@ class OpenAITextCompletion(BaseLLM):
         headers: dict,
         model_response: ModelResponse,
         model: str,
+        timeout: float
     ):
         client = httpx.AsyncClient()
         async with client.stream(
@@ -1017,7 +1024,7 @@ class OpenAITextCompletion(BaseLLM):
             json=data,
             headers=headers,
             method="POST",
-            timeout=litellm.request_timeout,
+            timeout=timeout,
         ) as response:
             try:
                 if response.status_code != 200:
@@ -814,6 +814,7 @@ def completion(
             optional_params=optional_params,
             litellm_params=litellm_params,
             logger_fn=logger_fn,
+            timeout=timeout,
         )

         if optional_params.get("stream", False) or acompletion == True:
@@ -1116,6 +1117,7 @@ def completion(
             acompletion=acompletion,
             logging_obj=logging,
             custom_prompt_dict=custom_prompt_dict,
+            timeout=timeout
         )
         if (
             "stream" in optional_params
@@ -197,6 +197,20 @@ def test_get_cloudflare_response_streaming():

     asyncio.run(test_async_call())

+@pytest.mark.asyncio
+async def test_hf_completion_tgi():
+    # litellm.set_verbose=True
+    try:
+        response = await acompletion(
+            model="huggingface/HuggingFaceH4/zephyr-7b-beta",
+            messages=[{"content": "Hello, how are you?", "role": "user"}],
+        )
+        # Add any assertions here to check the response
+        print(response)
+    except litellm.Timeout as e:
+        pass
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")

 # test_get_cloudflare_response_streaming()

@@ -923,10 +923,10 @@ def ai21_completion_call_bad_key():

 # ai21_completion_call_bad_key()

-
-def hf_test_completion_tgi_stream():
+@pytest.mark.asyncio
+async def test_hf_completion_tgi_stream():
     try:
-        response = completion(
+        response = await acompletion(
             model="huggingface/HuggingFaceH4/zephyr-7b-beta",
             messages=[{"content": "Hello, how are you?", "role": "user"}],
             stream=True,
@@ -935,11 +935,13 @@ def hf_test_completion_tgi_stream():
         print(f"response: {response}")
         complete_response = ""
         start_time = time.time()
-        for idx, chunk in enumerate(response):
+        idx = 0
+        async for chunk in response:
             chunk, finished = streaming_format_tests(idx, chunk)
             complete_response += chunk
             if finished:
                 break
+            idx += 1
         if complete_response.strip() == "":
             raise Exception("Empty response received")
         print(f"completion_response: {complete_response}")
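
With the test now async, the stream is consumed with async for, matching how acompletion(..., stream=True) is iterated in the hunk above. A minimal consumption sketch outside pytest; the model name is taken from the tests and the timeout value is illustrative:

import asyncio

import litellm
from litellm import acompletion


async def stream_chunks():
    response = await acompletion(
        model="huggingface/HuggingFaceH4/zephyr-7b-beta",
        messages=[{"content": "Hello, how are you?", "role": "user"}],
        stream=True,
        timeout=30.0,  # illustrative value
    )
    try:
        # Streaming responses are iterated chunk by chunk.
        async for chunk in response:
            print(chunk)
    except litellm.Timeout:
        print("stream timed out")


asyncio.run(stream_chunks())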