fix(huggingface_restapi.py): support timeouts for huggingface + openai text completions

https://github.com/BerriAI/litellm/issues/1334
Krrish Dholakia 2024-01-08 11:40:56 +05:30
parent c720870f80
commit b1fd0a164b
5 changed files with 41 additions and 14 deletions
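For context, a minimal sketch of the caller-facing effect of this commit, assuming the standard litellm entrypoint; the model name and timeout value are illustrative, not taken from the commit:

import litellm

# Requires HUGGINGFACE_API_KEY in the environment. With this fix, the
# timeout below is forwarded to httpx.AsyncClient(timeout=...) inside the
# Hugging Face handler, so a slow endpoint raises a timeout error instead
# of hanging indefinitely.
response = litellm.completion(
    model="huggingface/mistralai/Mistral-7B-Instruct-v0.1",  # illustrative model
    messages=[{"role": "user", "content": "Hello"}],
    timeout=10.0,  # seconds
)
print(response.choices[0].message.content)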


@@ -318,6 +318,7 @@ class Huggingface(BaseLLM):
         headers: Optional[dict],
         model_response: ModelResponse,
         print_verbose: Callable,
+        timeout: float,
         encoding,
         api_key,
         logging_obj,
@@ -450,10 +451,10 @@
         if acompletion is True:
             ### ASYNC STREAMING
             if optional_params.get("stream", False):
-                return self.async_streaming(logging_obj=logging_obj, api_base=completion_url, data=data, headers=headers, model_response=model_response, model=model) # type: ignore
+                return self.async_streaming(logging_obj=logging_obj, api_base=completion_url, data=data, headers=headers, model_response=model_response, model=model, timeout=timeout) # type: ignore
             else:
                 ### ASYNC COMPLETION
-                return self.acompletion(api_base=completion_url, data=data, headers=headers, model_response=model_response, task=task, encoding=encoding, input_text=input_text, model=model, optional_params=optional_params) # type: ignore
+                return self.acompletion(api_base=completion_url, data=data, headers=headers, model_response=model_response, task=task, encoding=encoding, input_text=input_text, model=model, optional_params=optional_params, timeout=timeout) # type: ignore
         ### SYNC STREAMING
         if "stream" in optional_params and optional_params["stream"] == True:
             response = requests.post(
@@ -560,12 +561,13 @@
         input_text: str,
         model: str,
         optional_params: dict,
+        timeout: float
     ):
         response = None
         try:
-            async with httpx.AsyncClient() as client:
+            async with httpx.AsyncClient(timeout=timeout) as client:
                 response = await client.post(
-                    url=api_base, json=data, headers=headers, timeout=None
+                    url=api_base, json=data, headers=headers
                 )
                 response_json = response.json()
                 if response.status_code != 200:
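The hunk above moves the deadline from the individual request (previously timeout=None, i.e. wait forever) onto the client itself, where it applies to every request the client makes. A standalone sketch of those httpx semantics, with a placeholder URL and payload:

import asyncio
import httpx

async def main() -> None:
    # Client-level timeout: one deadline configured up front, applied to
    # connect/read/write on every request. This mirrors the change from a
    # per-request timeout=None to httpx.AsyncClient(timeout=timeout).
    async with httpx.AsyncClient(timeout=5.0) as client:
        try:
            response = await client.post(
                "https://example.com/generate",  # placeholder URL
                json={"inputs": "Hello"},
            )
            print(response.status_code)
        except httpx.TimeoutException as exc:
            # Raised once the 5s budget is exceeded; with timeout=None the
            # call would have blocked indefinitely.
            print(f"request timed out: {exc}")

asyncio.run(main())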
@@ -605,8 +607,9 @@
         headers: dict,
         model_response: ModelResponse,
         model: str,
+        timeout: float
     ):
-        async with httpx.AsyncClient() as client:
+        async with httpx.AsyncClient(timeout=timeout) as client:
             response = client.stream(
                 "POST", url=f"{api_base}", json=data, headers=headers
             )
@@ -616,7 +619,6 @@
                         status_code=r.status_code,
                         message="An error occurred while streaming",
                     )
-
                 streamwrapper = CustomStreamWrapper(
                     completion_stream=r.aiter_lines(),
                     model=model,
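The streaming path gets the same client-level timeout, so a stalled stream now errors out instead of hanging. A rough sketch of the pattern, assuming a hypothetical endpoint, payload, and token:

import asyncio
import httpx

async def stream_generation(
    api_base: str, data: dict, headers: dict, timeout: float
) -> None:
    # Same shape as async_streaming above: the client-level timeout also
    # bounds the streamed response, including gaps between chunks.
    async with httpx.AsyncClient(timeout=timeout) as client:
        async with client.stream("POST", api_base, json=data, headers=headers) as r:
            if r.status_code != 200:
                raise RuntimeError(f"streaming failed: {r.status_code}")
            async for line in r.aiter_lines():
                print(line)

# Hypothetical endpoint, payload, and token, for illustration only.
asyncio.run(
    stream_generation(
        "https://api-inference.huggingface.co/models/gpt2",
        {"inputs": "Hello", "stream": True},
        {"Authorization": "Bearer hf_xxx"},
        timeout=30.0,
    )
)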