fix(huggingface_restapi.py): support timeouts for huggingface + openai text completions
https://github.com/BerriAI/litellm/issues/1334
commit b1fd0a164b (parent c720870f80)
5 changed files with 41 additions and 14 deletions
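At a high level, the commit threads a caller-supplied timeout from litellm.completion() / acompletion() down into the HuggingFace and OpenAI text-completion handlers, where it is handed to the underlying httpx client instead of the previous hard-coded timeout=None or the global litellm.request_timeout. A minimal sketch of that pattern, with illustrative names (hf_acompletion is not litellm's actual function):

import httpx

async def hf_acompletion(api_base: str, data: dict, headers: dict, timeout: float):
    # httpx applies a bare float deadline to the connect, read, write and pool phases.
    async with httpx.AsyncClient(timeout=timeout) as client:
        response = await client.post(url=api_base, json=data, headers=headers)
        response.raise_for_status()
        return response.json()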
@@ -318,6 +318,7 @@ class Huggingface(BaseLLM):
         headers: Optional[dict],
         model_response: ModelResponse,
         print_verbose: Callable,
+        timeout: float,
         encoding,
         api_key,
         logging_obj,
@@ -450,10 +451,10 @@ class Huggingface(BaseLLM):
         if acompletion is True:
             ### ASYNC STREAMING
             if optional_params.get("stream", False):
-                return self.async_streaming(logging_obj=logging_obj, api_base=completion_url, data=data, headers=headers, model_response=model_response, model=model)  # type: ignore
+                return self.async_streaming(logging_obj=logging_obj, api_base=completion_url, data=data, headers=headers, model_response=model_response, model=model, timeout=timeout)  # type: ignore
             else:
                 ### ASYNC COMPLETION
-                return self.acompletion(api_base=completion_url, data=data, headers=headers, model_response=model_response, task=task, encoding=encoding, input_text=input_text, model=model, optional_params=optional_params)  # type: ignore
+                return self.acompletion(api_base=completion_url, data=data, headers=headers, model_response=model_response, task=task, encoding=encoding, input_text=input_text, model=model, optional_params=optional_params, timeout=timeout)  # type: ignore
         ### SYNC STREAMING
         if "stream" in optional_params and optional_params["stream"] == True:
             response = requests.post(
@@ -560,12 +561,13 @@ class Huggingface(BaseLLM):
         input_text: str,
         model: str,
         optional_params: dict,
+        timeout: float
     ):
         response = None
         try:
-            async with httpx.AsyncClient() as client:
+            async with httpx.AsyncClient(timeout=timeout) as client:
                 response = await client.post(
-                    url=api_base, json=data, headers=headers, timeout=None
+                    url=api_base, json=data, headers=headers
                 )
                 response_json = response.json()
                 if response.status_code != 200:
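For reference (not part of the diff): when the deadline is exceeded, httpx raises a subclass of httpx.TimeoutException (e.g. httpx.ReadTimeout), which litellm can then surface to callers as litellm.Timeout, as exercised by the tests added below. Finer-grained limits are also possible with httpx.Timeout; the values here are illustrative:

import httpx

# 600s overall budget, but only 5s allowed for establishing the connection.
granular_timeout = httpx.Timeout(600.0, connect=5.0)

async def post_with_timeout(url: str, payload: dict) -> dict:
    async with httpx.AsyncClient(timeout=granular_timeout) as client:
        try:
            response = await client.post(url, json=payload)
        except httpx.TimeoutException:
            # A litellm caller would see this surfaced as litellm.Timeout.
            raise
        return response.json()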
@@ -605,8 +607,9 @@ class Huggingface(BaseLLM):
         headers: dict,
         model_response: ModelResponse,
         model: str,
+        timeout: float
     ):
-        async with httpx.AsyncClient() as client:
+        async with httpx.AsyncClient(timeout=timeout) as client:
             response = client.stream(
                 "POST", url=f"{api_base}", json=data, headers=headers
             )
@@ -616,7 +619,6 @@ class Huggingface(BaseLLM):
                     status_code=r.status_code,
                     message="An error occurred while streaming",
                 )
-
             streamwrapper = CustomStreamWrapper(
                 completion_stream=r.aiter_lines(),
                 model=model,
@@ -836,6 +836,7 @@ class OpenAITextCompletion(BaseLLM):
         api_key: str,
         model: str,
         messages: list,
+        timeout: float,
         print_verbose: Optional[Callable] = None,
         api_base: Optional[str] = None,
         logging_obj=None,
@@ -887,9 +888,10 @@ class OpenAITextCompletion(BaseLLM):
                     headers=headers,
                     model_response=model_response,
                     model=model,
+                    timeout=timeout
                 )
             else:
-                return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response, prompt=prompt, api_key=api_key, logging_obj=logging_obj, model=model)  # type: ignore
+                return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response, prompt=prompt, api_key=api_key, logging_obj=logging_obj, model=model, timeout=timeout)  # type: ignore
         elif optional_params.get("stream", False):
             return self.streaming(
                 logging_obj=logging_obj,
@@ -898,12 +900,14 @@ class OpenAITextCompletion(BaseLLM):
                 headers=headers,
                 model_response=model_response,
                 model=model,
+                timeout=timeout
             )
         else:
             response = httpx.post(
                 url=f"{api_base}",
                 json=data,
                 headers=headers,
+                timeout=timeout
             )
             if response.status_code != 200:
                 raise OpenAIError(
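The non-streaming sync path uses the module-level httpx.post helper, which accepts the same timeout argument. A standalone sketch of that path (the function name and error type are illustrative, not litellm's OpenAIError):

import httpx

def post_completion(url: str, payload: dict, headers: dict, timeout: float) -> dict:
    response = httpx.post(url, json=payload, headers=headers, timeout=timeout)
    if response.status_code != 200:
        raise RuntimeError(f"request failed: {response.status_code} {response.text}")
    return response.json()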
@@ -939,8 +943,9 @@ class OpenAITextCompletion(BaseLLM):
         prompt: str,
         api_key: str,
         model: str,
+        timeout: float
     ):
-        async with httpx.AsyncClient() as client:
+        async with httpx.AsyncClient(timeout=timeout) as client:
             try:
                 response = await client.post(
                     api_base,
@@ -980,13 +985,14 @@ class OpenAITextCompletion(BaseLLM):
         headers: dict,
         model_response: ModelResponse,
         model: str,
+        timeout: float
     ):
         with httpx.stream(
             url=f"{api_base}",
             json=data,
             headers=headers,
             method="POST",
-            timeout=litellm.request_timeout,
+            timeout=timeout,
         ) as response:
             if response.status_code != 200:
                 raise OpenAIError(
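Previously this streaming path always used the library-wide litellm.request_timeout; after the change it honors the per-call value. One plausible way a fallback from per-call to global default could look (how litellm itself resolves a missing per-call value may differ):

from typing import Optional

import litellm

def resolve_timeout(timeout: Optional[float]) -> float:
    # Prefer the per-call value; otherwise fall back to the global default.
    return timeout if timeout is not None else litellm.request_timeout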
@@ -1010,6 +1016,7 @@ class OpenAITextCompletion(BaseLLM):
         headers: dict,
         model_response: ModelResponse,
         model: str,
+        timeout: float
     ):
         client = httpx.AsyncClient()
         async with client.stream(
@@ -1017,7 +1024,7 @@ class OpenAITextCompletion(BaseLLM):
             json=data,
             headers=headers,
             method="POST",
-            timeout=litellm.request_timeout,
+            timeout=timeout,
         ) as response:
             try:
                 if response.status_code != 200:
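The async streaming path now passes the per-call timeout when opening the streamed request. A self-contained sketch of the same pattern with httpx (names are assumptions, not litellm's internal helpers; here the deadline is set on the client rather than per request, both of which httpx supports):

import httpx

async def stream_completion(url: str, payload: dict, headers: dict, timeout: float):
    async with httpx.AsyncClient(timeout=timeout) as client:
        async with client.stream("POST", url, json=payload, headers=headers) as response:
            async for line in response.aiter_lines():
                yield line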
@@ -814,6 +814,7 @@ def completion(
             optional_params=optional_params,
             litellm_params=litellm_params,
             logger_fn=logger_fn,
+            timeout=timeout,
         )

         if optional_params.get("stream", False) or acompletion == True:
@@ -1116,6 +1117,7 @@ def completion(
             acompletion=acompletion,
             logging_obj=logging,
             custom_prompt_dict=custom_prompt_dict,
+            timeout=timeout
         )
         if (
             "stream" in optional_params
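On the caller side, litellm.completion() now forwards timeout to the HuggingFace handler as well. A usage sketch (the timeout value is illustrative; the model name is the one used in the tests below):

import litellm

response = litellm.completion(
    model="huggingface/HuggingFaceH4/zephyr-7b-beta",
    messages=[{"content": "Hello, how are you?", "role": "user"}],
    timeout=10.0,  # seconds; an expired deadline should surface as litellm.Timeout
)
print(response)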
@@ -197,6 +197,20 @@ def test_get_cloudflare_response_streaming():

     asyncio.run(test_async_call())

+@pytest.mark.asyncio
+async def test_hf_completion_tgi():
+    # litellm.set_verbose=True
+    try:
+        response = await acompletion(
+            model="huggingface/HuggingFaceH4/zephyr-7b-beta",
+            messages=[{"content": "Hello, how are you?", "role": "user"}],
+        )
+        # Add any assertions here to check the response
+        print(response)
+    except litellm.Timeout as e:
+        pass
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")

 # test_get_cloudflare_response_streaming()
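Not part of this commit, but a natural follow-up assertion: with an intentionally tiny timeout the call should raise litellm.Timeout. A hedged sketch of such a test (whether it passes depends on the exception mapping and on how quickly the endpoint responds):

import pytest

import litellm
from litellm import acompletion

@pytest.mark.asyncio
async def test_hf_completion_timeout_enforced():
    with pytest.raises(litellm.Timeout):
        await acompletion(
            model="huggingface/HuggingFaceH4/zephyr-7b-beta",
            messages=[{"content": "Hello, how are you?", "role": "user"}],
            timeout=0.001,  # virtually guaranteed to expire before the response
        )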
@@ -923,10 +923,10 @@ def ai21_completion_call_bad_key():

 # ai21_completion_call_bad_key()

-def hf_test_completion_tgi_stream():
+@pytest.mark.asyncio
+async def test_hf_completion_tgi_stream():
     try:
-        response = completion(
+        response = await acompletion(
             model="huggingface/HuggingFaceH4/zephyr-7b-beta",
             messages=[{"content": "Hello, how are you?", "role": "user"}],
             stream=True,
@@ -935,11 +935,13 @@ def hf_test_completion_tgi_stream():
         print(f"response: {response}")
         complete_response = ""
         start_time = time.time()
-        for idx, chunk in enumerate(response):
+        idx = 0
+        async for chunk in response:
             chunk, finished = streaming_format_tests(idx, chunk)
             complete_response += chunk
             if finished:
                 break
+            idx += 1
         if complete_response.strip() == "":
             raise Exception("Empty response received")
         print(f"completion_response: {complete_response}")
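Outside of the test suite, the async streaming path with a per-call timeout looks roughly like this (values illustrative):

import asyncio

import litellm
from litellm import acompletion

async def main():
    try:
        response = await acompletion(
            model="huggingface/HuggingFaceH4/zephyr-7b-beta",
            messages=[{"content": "Hello, how are you?", "role": "user"}],
            stream=True,
            timeout=30.0,
        )
        async for chunk in response:
            print(chunk)
    except litellm.Timeout:
        print("streamed request exceeded the 30s timeout")

asyncio.run(main())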