fix(huggingface_restapi.py): support timeouts for huggingface + openai text completions
https://github.com/BerriAI/litellm/issues/1334
parent c720870f80
commit b1fd0a164b
5 changed files with 41 additions and 14 deletions
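
For context, the change threads a caller-supplied timeout from the public completion entrypoint down into the underlying httpx client. A minimal usage sketch of what this enables (the model name and prompt are illustrative, not from this commit):

import litellm

# timeout is forwarded through Huggingface.completion into
# acompletion/async_streaming and finally httpx.AsyncClient(timeout=...).
response = litellm.completion(
    model="huggingface/bigcode/starcoder",  # illustrative model
    messages=[{"role": "user", "content": "Hello"}],
    timeout=5.0,  # seconds; the request fails fast instead of hanging
)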
@@ -318,6 +318,7 @@ class Huggingface(BaseLLM):
         headers: Optional[dict],
         model_response: ModelResponse,
         print_verbose: Callable,
+        timeout: float,
         encoding,
         api_key,
         logging_obj,
@@ -450,10 +451,10 @@ class Huggingface(BaseLLM):
         if acompletion is True:
             ### ASYNC STREAMING
             if optional_params.get("stream", False):
-                return self.async_streaming(logging_obj=logging_obj, api_base=completion_url, data=data, headers=headers, model_response=model_response, model=model) # type: ignore
+                return self.async_streaming(logging_obj=logging_obj, api_base=completion_url, data=data, headers=headers, model_response=model_response, model=model, timeout=timeout) # type: ignore
             else:
                 ### ASYNC COMPLETION
-                return self.acompletion(api_base=completion_url, data=data, headers=headers, model_response=model_response, task=task, encoding=encoding, input_text=input_text, model=model, optional_params=optional_params) # type: ignore
+                return self.acompletion(api_base=completion_url, data=data, headers=headers, model_response=model_response, task=task, encoding=encoding, input_text=input_text, model=model, optional_params=optional_params, timeout=timeout) # type: ignore
         ### SYNC STREAMING
         if "stream" in optional_params and optional_params["stream"] == True:
             response = requests.post(
@@ -560,12 +561,13 @@ class Huggingface(BaseLLM):
         input_text: str,
         model: str,
         optional_params: dict,
+        timeout: float
     ):
         response = None
         try:
-            async with httpx.AsyncClient() as client:
+            async with httpx.AsyncClient(timeout=timeout) as client:
                 response = await client.post(
-                    url=api_base, json=data, headers=headers, timeout=None
+                    url=api_base, json=data, headers=headers
                 )
                 response_json = response.json()
                 if response.status_code != 200:
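
A note on the httpx change above: a timeout passed to the AsyncClient constructor becomes the default for every request that client makes, whereas the old request-level timeout=None explicitly disabled it, which is why the per-request argument is removed rather than updated. A standalone sketch of that behavior (URL and payload are illustrative):

import asyncio
import httpx

async def main():
    # Client-level timeout applies to connect, write, and read phases
    # of every request issued through this client.
    async with httpx.AsyncClient(timeout=5.0) as client:
        try:
            await client.post("https://example.com/slow", json={})
        except httpx.TimeoutException:
            print("request exceeded the 5s client-level timeout")

asyncio.run(main())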
@@ -605,8 +607,9 @@ class Huggingface(BaseLLM):
         headers: dict,
         model_response: ModelResponse,
         model: str,
+        timeout: float
     ):
-        async with httpx.AsyncClient() as client:
+        async with httpx.AsyncClient(timeout=timeout) as client:
             response = client.stream(
                 "POST", url=f"{api_base}", json=data, headers=headers
             )
@@ -616,7 +619,6 @@ class Huggingface(BaseLLM):
                     status_code=r.status_code,
                     message="An error occurred while streaming",
                 )
-
             streamwrapper = CustomStreamWrapper(
                 completion_stream=r.aiter_lines(),
                 model=model,
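
For reference, the streaming path above follows httpx's async streaming pattern, so the client-level timeout also bounds each chunk read while iterating the body. A self-contained sketch of that pattern (URL, payload, and error handling are illustrative, not the commit's exact code):

import asyncio
import httpx

async def stream_lines(url: str, data: dict, timeout: float):
    async with httpx.AsyncClient(timeout=timeout) as client:
        # client.stream returns an async context manager; iterating
        # aiter_lines() yields decoded lines as they arrive.
        async with client.stream("POST", url, json=data) as r:
            if r.status_code != 200:
                raise RuntimeError("An error occurred while streaming")
            async for line in r.aiter_lines():
                print(line)

asyncio.run(stream_lines("https://example.com/stream", {}, 5.0))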