fix(huggingface_restapi.py): support timeouts for huggingface + openai text completions

https://github.com/BerriAI/litellm/issues/1334
Krrish Dholakia 2024-01-08 11:40:56 +05:30
parent c720870f80
commit b1fd0a164b
5 changed files with 41 additions and 14 deletions
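For context, after this change a caller-supplied `timeout` is forwarded down to the underlying `httpx` requests for both the Hugging Face and OpenAI text-completion paths. A minimal usage sketch, not part of the diff itself (the model name mirrors the tests added below; the 5-second timeout is an illustrative value):

import asyncio

import litellm
from litellm import acompletion


async def main():
    try:
        # timeout (in seconds) now reaches the provider's httpx calls
        response = await acompletion(
            model="huggingface/HuggingFaceH4/zephyr-7b-beta",
            messages=[{"content": "Hello, how are you?", "role": "user"}],
            timeout=5.0,  # illustrative value
        )
        print(response)
    except litellm.Timeout:
        # raised when the request exceeds the configured timeout
        print("request timed out")


if __name__ == "__main__":
    asyncio.run(main())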


@@ -318,6 +318,7 @@ class Huggingface(BaseLLM):
         headers: Optional[dict],
         model_response: ModelResponse,
         print_verbose: Callable,
+        timeout: float,
         encoding,
         api_key,
         logging_obj,
@@ -450,10 +451,10 @@ class Huggingface(BaseLLM):
         if acompletion is True:
             ### ASYNC STREAMING
             if optional_params.get("stream", False):
-                return self.async_streaming(logging_obj=logging_obj, api_base=completion_url, data=data, headers=headers, model_response=model_response, model=model) # type: ignore
+                return self.async_streaming(logging_obj=logging_obj, api_base=completion_url, data=data, headers=headers, model_response=model_response, model=model, timeout=timeout) # type: ignore
             else:
                 ### ASYNC COMPLETION
-                return self.acompletion(api_base=completion_url, data=data, headers=headers, model_response=model_response, task=task, encoding=encoding, input_text=input_text, model=model, optional_params=optional_params) # type: ignore
+                return self.acompletion(api_base=completion_url, data=data, headers=headers, model_response=model_response, task=task, encoding=encoding, input_text=input_text, model=model, optional_params=optional_params, timeout=timeout) # type: ignore
         ### SYNC STREAMING
         if "stream" in optional_params and optional_params["stream"] == True:
             response = requests.post(
@@ -560,12 +561,13 @@ class Huggingface(BaseLLM):
         input_text: str,
         model: str,
         optional_params: dict,
+        timeout: float
     ):
         response = None
         try:
-            async with httpx.AsyncClient() as client:
+            async with httpx.AsyncClient(timeout=timeout) as client:
                 response = await client.post(
-                    url=api_base, json=data, headers=headers, timeout=None
+                    url=api_base, json=data, headers=headers
                 )
                 response_json = response.json()
                 if response.status_code != 200:
@@ -605,8 +607,9 @@ class Huggingface(BaseLLM):
         headers: dict,
         model_response: ModelResponse,
         model: str,
+        timeout: float
     ):
-        async with httpx.AsyncClient() as client:
+        async with httpx.AsyncClient(timeout=timeout) as client:
             response = client.stream(
                 "POST", url=f"{api_base}", json=data, headers=headers
             )
@@ -616,7 +619,6 @@
                         status_code=r.status_code,
                         message="An error occurred while streaming",
                     )
-
                 streamwrapper = CustomStreamWrapper(
                     completion_stream=r.aiter_lines(),
                     model=model,


@@ -836,6 +836,7 @@ class OpenAITextCompletion(BaseLLM):
         api_key: str,
         model: str,
         messages: list,
+        timeout: float,
         print_verbose: Optional[Callable] = None,
         api_base: Optional[str] = None,
         logging_obj=None,
@@ -887,9 +888,10 @@ class OpenAITextCompletion(BaseLLM):
                         headers=headers,
                         model_response=model_response,
                         model=model,
+                        timeout=timeout
                     )
                 else:
-                    return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response, prompt=prompt, api_key=api_key, logging_obj=logging_obj, model=model) # type: ignore
+                    return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response, prompt=prompt, api_key=api_key, logging_obj=logging_obj, model=model, timeout=timeout) # type: ignore
             elif optional_params.get("stream", False):
                 return self.streaming(
                     logging_obj=logging_obj,
@@ -898,12 +900,14 @@ class OpenAITextCompletion(BaseLLM):
                     headers=headers,
                     model_response=model_response,
                     model=model,
+                    timeout=timeout
                 )
             else:
                 response = httpx.post(
                     url=f"{api_base}",
                     json=data,
                     headers=headers,
+                    timeout=timeout
                 )
                 if response.status_code != 200:
                     raise OpenAIError(
@@ -939,8 +943,9 @@ class OpenAITextCompletion(BaseLLM):
         prompt: str,
         api_key: str,
         model: str,
+        timeout: float
     ):
-        async with httpx.AsyncClient() as client:
+        async with httpx.AsyncClient(timeout=timeout) as client:
             try:
                 response = await client.post(
                     api_base,
@@ -980,13 +985,14 @@ class OpenAITextCompletion(BaseLLM):
         headers: dict,
         model_response: ModelResponse,
         model: str,
+        timeout: float
     ):
         with httpx.stream(
             url=f"{api_base}",
             json=data,
             headers=headers,
             method="POST",
-            timeout=litellm.request_timeout,
+            timeout=timeout,
         ) as response:
             if response.status_code != 200:
                 raise OpenAIError(
@@ -1010,6 +1016,7 @@ class OpenAITextCompletion(BaseLLM):
         headers: dict,
         model_response: ModelResponse,
         model: str,
+        timeout: float
     ):
         client = httpx.AsyncClient()
         async with client.stream(
@@ -1017,7 +1024,7 @@
             json=data,
             headers=headers,
             method="POST",
-            timeout=litellm.request_timeout,
+            timeout=timeout,
         ) as response:
             try:
                 if response.status_code != 200:


@@ -814,6 +814,7 @@ def completion(
                 optional_params=optional_params,
                 litellm_params=litellm_params,
                 logger_fn=logger_fn,
+                timeout=timeout,
             )

             if optional_params.get("stream", False) or acompletion == True:
@@ -1116,6 +1117,7 @@ def completion(
                 acompletion=acompletion,
                 logging_obj=logging,
                 custom_prompt_dict=custom_prompt_dict,
+                timeout=timeout
             )
             if (
                 "stream" in optional_params


@@ -197,6 +197,20 @@ def test_get_cloudflare_response_streaming():
     asyncio.run(test_async_call())


+@pytest.mark.asyncio
+async def test_hf_completion_tgi():
+    # litellm.set_verbose=True
+    try:
+        response = await acompletion(
+            model="huggingface/HuggingFaceH4/zephyr-7b-beta",
+            messages=[{"content": "Hello, how are you?", "role": "user"}],
+        )
+        # Add any assertions here to check the response
+        print(response)
+    except litellm.Timeout as e:
+        pass
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")


 # test_get_cloudflare_response_streaming()


@@ -923,10 +923,10 @@ def ai21_completion_call_bad_key():
 # ai21_completion_call_bad_key()


-def hf_test_completion_tgi_stream():
+@pytest.mark.asyncio
+async def test_hf_completion_tgi_stream():
     try:
-        response = completion(
+        response = await acompletion(
             model="huggingface/HuggingFaceH4/zephyr-7b-beta",
             messages=[{"content": "Hello, how are you?", "role": "user"}],
             stream=True,
@@ -935,11 +935,13 @@ def hf_test_completion_tgi_stream():
         print(f"response: {response}")
         complete_response = ""
         start_time = time.time()
-        for idx, chunk in enumerate(response):
+        idx = 0
+        async for chunk in response:
             chunk, finished = streaming_format_tests(idx, chunk)
             complete_response += chunk
             if finished:
                 break
+            idx += 1
         if complete_response.strip() == "":
             raise Exception("Empty response received")
         print(f"completion_response: {complete_response}")