diff --git a/litellm/llms/huggingface_restapi.py b/litellm/llms/huggingface_restapi.py
index 0cc8c5697..eb8ce38b9 100644
--- a/litellm/llms/huggingface_restapi.py
+++ b/litellm/llms/huggingface_restapi.py
@@ -318,6 +318,7 @@ class Huggingface(BaseLLM):
         headers: Optional[dict],
         model_response: ModelResponse,
         print_verbose: Callable,
+        timeout: float,
         encoding,
         api_key,
         logging_obj,
@@ -450,10 +451,10 @@ class Huggingface(BaseLLM):
         if acompletion is True:
             ### ASYNC STREAMING
             if optional_params.get("stream", False):
-                return self.async_streaming(logging_obj=logging_obj, api_base=completion_url, data=data, headers=headers, model_response=model_response, model=model) # type: ignore
+                return self.async_streaming(logging_obj=logging_obj, api_base=completion_url, data=data, headers=headers, model_response=model_response, model=model, timeout=timeout) # type: ignore
             else:
                 ### ASYNC COMPLETION
-                return self.acompletion(api_base=completion_url, data=data, headers=headers, model_response=model_response, task=task, encoding=encoding, input_text=input_text, model=model, optional_params=optional_params) # type: ignore
+                return self.acompletion(api_base=completion_url, data=data, headers=headers, model_response=model_response, task=task, encoding=encoding, input_text=input_text, model=model, optional_params=optional_params, timeout=timeout) # type: ignore
         ### SYNC STREAMING
         if "stream" in optional_params and optional_params["stream"] == True:
             response = requests.post(
@@ -560,12 +561,13 @@ class Huggingface(BaseLLM):
         input_text: str,
         model: str,
         optional_params: dict,
+        timeout: float
     ):
         response = None
         try:
-            async with httpx.AsyncClient() as client:
+            async with httpx.AsyncClient(timeout=timeout) as client:
                 response = await client.post(
-                    url=api_base, json=data, headers=headers, timeout=None
+                    url=api_base, json=data, headers=headers
                 )
                 response_json = response.json()
                 if response.status_code != 200:
@@ -605,8 +607,9 @@ class Huggingface(BaseLLM):
         headers: dict,
         model_response: ModelResponse,
         model: str,
+        timeout: float
     ):
-        async with httpx.AsyncClient() as client:
+        async with httpx.AsyncClient(timeout=timeout) as client:
             response = client.stream(
                 "POST", url=f"{api_base}", json=data, headers=headers
             )
@@ -616,7 +619,6 @@ class Huggingface(BaseLLM):
                         status_code=r.status_code,
                         message="An error occurred while streaming",
                     )
-
                 streamwrapper = CustomStreamWrapper(
                     completion_stream=r.aiter_lines(),
                     model=model,
diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py
index 3265b230f..229d92f6d 100644
--- a/litellm/llms/openai.py
+++ b/litellm/llms/openai.py
@@ -836,6 +836,7 @@ class OpenAITextCompletion(BaseLLM):
         api_key: str,
         model: str,
         messages: list,
+        timeout: float,
         print_verbose: Optional[Callable] = None,
         api_base: Optional[str] = None,
         logging_obj=None,
@@ -887,9 +888,10 @@ class OpenAITextCompletion(BaseLLM):
                     headers=headers,
                     model_response=model_response,
                     model=model,
+                    timeout=timeout
                 )
             else:
-                return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response, prompt=prompt, api_key=api_key, logging_obj=logging_obj, model=model) # type: ignore
+                return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response, prompt=prompt, api_key=api_key, logging_obj=logging_obj, model=model, timeout=timeout) # type: ignore
         elif optional_params.get("stream", False):
             return self.streaming(
                 logging_obj=logging_obj,
                 api_base=api_base,
                 data=data,
                 headers=headers,
                 model_response=model_response,
                 model=model,
+                timeout=timeout
             )
         else:
             response = httpx.post(
                 url=f"{api_base}",
                 json=data,
                 headers=headers,
+                timeout=timeout
             )
             if response.status_code != 200:
                 raise OpenAIError(
@@ -939,8 +943,9 @@
         prompt: str,
         api_key: str,
         model: str,
+        timeout: float
     ):
-        async with httpx.AsyncClient() as client:
+        async with httpx.AsyncClient(timeout=timeout) as client:
             try:
                 response = await client.post(
                     api_base,
@@ -980,13 +985,14 @@
         headers: dict,
         model_response: ModelResponse,
         model: str,
+        timeout: float
     ):
         with httpx.stream(
             url=f"{api_base}",
             json=data,
             headers=headers,
             method="POST",
-            timeout=litellm.request_timeout,
+            timeout=timeout,
         ) as response:
             if response.status_code != 200:
                 raise OpenAIError(
@@ -1010,6 +1016,7 @@
         headers: dict,
         model_response: ModelResponse,
         model: str,
+        timeout: float
     ):
         client = httpx.AsyncClient()
         async with client.stream(
             url=f"{api_base}",
             json=data,
             headers=headers,
             method="POST",
-            timeout=litellm.request_timeout,
+            timeout=timeout,
         ) as response:
             try:
                 if response.status_code != 200:
diff --git a/litellm/main.py b/litellm/main.py
index c5340e975..3f1902ec9 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -814,6 +814,7 @@
             optional_params=optional_params,
             litellm_params=litellm_params,
             logger_fn=logger_fn,
+            timeout=timeout,
         )

         if optional_params.get("stream", False) or acompletion == True:
@@ -1116,6 +1117,7 @@
             acompletion=acompletion,
             logging_obj=logging,
             custom_prompt_dict=custom_prompt_dict,
+            timeout=timeout
         )
         if (
             "stream" in optional_params
diff --git a/litellm/tests/test_async_fn.py b/litellm/tests/test_async_fn.py
index 5603e69b7..0218ddfe4 100644
--- a/litellm/tests/test_async_fn.py
+++ b/litellm/tests/test_async_fn.py
@@ -197,6 +197,20 @@ def test_get_cloudflare_response_streaming():
     asyncio.run(test_async_call())


+@pytest.mark.asyncio
+async def test_hf_completion_tgi():
+    # litellm.set_verbose=True
+    try:
+        response = await acompletion(
+            model="huggingface/HuggingFaceH4/zephyr-7b-beta",
+            messages=[{"content": "Hello, how are you?", "role": "user"}],
+        )
+        # Add any assertions here to check the response
+        print(response)
+    except litellm.Timeout as e:
+        pass
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")


 # test_get_cloudflare_response_streaming()
diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py
index 398704525..8ad256992 100644
--- a/litellm/tests/test_streaming.py
+++ b/litellm/tests/test_streaming.py
@@ -923,10 +923,10 @@ def ai21_completion_call_bad_key():
 # ai21_completion_call_bad_key()


-
-def hf_test_completion_tgi_stream():
+@pytest.mark.asyncio
+async def test_hf_completion_tgi_stream():
     try:
-        response = completion(
+        response = await acompletion(
             model="huggingface/HuggingFaceH4/zephyr-7b-beta",
             messages=[{"content": "Hello, how are you?", "role": "user"}],
             stream=True,
@@ -935,11 +935,13 @@
         print(f"response: {response}")
         complete_response = ""
         start_time = time.time()
-        for idx, chunk in enumerate(response):
+        idx = 0
+        async for chunk in response:
             chunk, finished = streaming_format_tests(idx, chunk)
             complete_response += chunk
             if finished:
                 break
+            idx += 1
         if complete_response.strip() == "":
             raise Exception("Empty response received")
         print(f"completion_response: {complete_response}")
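
Caller-side usage sketch (not part of the patch): with this change, the `timeout` argument passed to `completion`/`acompletion` is threaded through to the underlying `httpx` clients for the Hugging Face and OpenAI text-completion handlers, so a slow request raises `litellm.Timeout` instead of hanging. This mirrors the new tests above; the 5-second value and the `main` wrapper are illustrative, not from the diff.

import asyncio
import litellm
from litellm import acompletion

async def main():
    try:
        response = await acompletion(
            model="huggingface/HuggingFaceH4/zephyr-7b-beta",
            messages=[{"content": "Hello, how are you?", "role": "user"}],
            timeout=5.0,  # illustrative; forwarded to httpx.AsyncClient(timeout=...) per this patch
        )
        print(response)
    except litellm.Timeout:
        # raised once the request exceeds the 5s budget
        print("request timed out")

asyncio.run(main())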