From 9fd7c5b3438c38dad79c93acb9a46a689cbd808d Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Wed, 15 Nov 2023 17:41:04 -0800
Subject: [PATCH] test: set request timeout at request level

---
 litellm/__init__.py              |  2 +-
 litellm/llms/azure.py            | 40 +++++++++++++++++---------
 litellm/llms/base.py             |  4 +--
 litellm/llms/openai.py           | 49 ++++++++++++++++++++------------
 litellm/tests/test_completion.py |  2 +-
 litellm/utils.py                 |  4 +--
 6 files changed, 63 insertions(+), 38 deletions(-)

diff --git a/litellm/__init__.py b/litellm/__init__.py
index f2c3e779d8..d4e72c5d94 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -19,7 +19,7 @@ telemetry = True
 max_tokens = 256 # OpenAI Defaults
 drop_params = False
 retry = True
-request_timeout: Optional[float] = None
+request_timeout: Optional[float] = 600
 api_key: Optional[str] = None
 openai_key: Optional[str] = None
 azure_key: Optional[str] = None
diff --git a/litellm/llms/azure.py b/litellm/llms/azure.py
index 2991a64158..0dd23e13d9 100644
--- a/litellm/llms/azure.py
+++ b/litellm/llms/azure.py
@@ -4,6 +4,7 @@ from .base import BaseLLM
 from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper, convert_to_model_response_object
 from typing import Callable, Optional
 from litellm import OpenAIConfig
+import litellm
 import httpx
 
 class AzureOpenAIError(Exception):
@@ -105,6 +106,8 @@ class AzureChatCompletion(BaseLLM):
                acompletion: bool = False,
                headers: Optional[dict]=None):
         super().completion()
+        if self._client_session is None:
+            self._client_session = self.create_client_session()
         exception_mapping_worked = False
         try:
             if headers is None:
@@ -140,10 +143,11 @@ class AzureChatCompletion(BaseLLM):
             elif "stream" in optional_params and optional_params["stream"] == True:
                 return self.streaming(logging_obj=logging_obj, api_base=api_base, data=data, headers=headers, model_response=model_response, model=model)
             else:
-                response = httpx.post(
+                response = self._client_session.post(
                     url=api_base,
                     json=data,
                     headers=headers,
+                    timeout=litellm.request_timeout
                 )
                 if response.status_code != 200:
                     raise AzureOpenAIError(status_code=response.status_code, message=response.text)
@@ -157,15 +161,17 @@ class AzureChatCompletion(BaseLLM):
             raise e
 
     async def acompletion(self, api_base: str, data: dict, headers: dict, model_response: ModelResponse):
+        if self._aclient_session is None:
+            self._aclient_session = self.create_aclient_session()
+        client = self._aclient_session
         try:
-            async with httpx.AsyncClient() as client:
-                response = await client.post(api_base, json=data, headers=headers)
-                response_json = response.json()
-                if response.status_code != 200:
-                    raise AzureOpenAIError(status_code=response.status_code, message=response.text, request=response.request, response=response)
-
-                ## RESPONSE OBJECT
-                return convert_to_model_response_object(response_object=response_json, model_response_object=model_response)
+            response = await client.post(api_base, json=data, headers=headers, timeout=litellm.request_timeout)
+            response_json = response.json()
+            if response.status_code != 200:
+                raise AzureOpenAIError(status_code=response.status_code, message=response.text, request=response.request, response=response)
+
+            ## RESPONSE OBJECT
+            return convert_to_model_response_object(response_object=response_json, model_response_object=model_response)
         except Exception as e:
             if isinstance(e,httpx.TimeoutException):
                 raise AzureOpenAIError(status_code=500, message="Request Timeout Error")
@@ -182,11 +188,14 @@ class AzureChatCompletion(BaseLLM):
                   model_response: ModelResponse,
                  model: str
    ):
-        with httpx.stream(
+        if self._client_session is None:
+            self._client_session = self.create_client_session()
+        with self._client_session.stream(
                     url=f"{api_base}",
                     json=data,
                     headers=headers,
-                    method="POST"
+                    method="POST",
+                    timeout=litellm.request_timeout
                 ) as response:
             if response.status_code != 200:
                 raise AzureOpenAIError(status_code=response.status_code, message="An error occurred while streaming")
@@ -203,12 +212,15 @@ class AzureChatCompletion(BaseLLM):
                           headers: dict,
                           model_response: ModelResponse,
                           model: str):
-        client = httpx.AsyncClient()
+        if self._aclient_session is None:
+            self._aclient_session = self.create_aclient_session()
+        client = self._aclient_session
         async with client.stream(
                     url=f"{api_base}",
                     json=data,
                     headers=headers,
-                    method="POST"
+                    method="POST",
+                    timeout=litellm.request_timeout
                 ) as response:
             if response.status_code != 200:
                 raise AzureOpenAIError(status_code=response.status_code, message=response.text)
@@ -253,7 +265,7 @@ class AzureChatCompletion(BaseLLM):
             )
             ## COMPLETION CALL
             response = self._client_session.post(
-                api_base, headers=headers, json=data
+                api_base, headers=headers, json=data, timeout=litellm.request_timeout
             )
             ## LOGGING
             logging_obj.post_call(
diff --git a/litellm/llms/base.py b/litellm/llms/base.py
index d93b5a3f63..a4c056e1af 100644
--- a/litellm/llms/base.py
+++ b/litellm/llms/base.py
@@ -9,7 +9,7 @@ class BaseLLM:
         if litellm.client_session:
             _client_session = litellm.client_session
         else:
-            _client_session = httpx.Client(timeout=litellm.request_timeout)
+            _client_session = httpx.Client()
 
         return _client_session
 
@@ -17,7 +17,7 @@ class BaseLLM:
         if litellm.aclient_session:
             _aclient_session = litellm.aclient_session
         else:
-            _aclient_session = httpx.AsyncClient(timeout=litellm.request_timeout)
+            _aclient_session = httpx.AsyncClient()
 
         return _aclient_session
 
diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py
index e025c66efb..aaa6271074 100644
--- a/litellm/llms/openai.py
+++ b/litellm/llms/openai.py
@@ -5,6 +5,7 @@ from .base import BaseLLM
 from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper, convert_to_model_response_object, Usage
 from typing import Callable, Optional
 import aiohttp
+import litellm
 
 class OpenAIError(Exception):
     def __init__(self, status_code, message, request: Optional[httpx.Request]=None, response: Optional[httpx.Response]=None):
@@ -223,10 +224,11 @@ class OpenAIChatCompletion(BaseLLM):
             elif optional_params.get("stream", False):
                 return self.streaming(logging_obj=logging_obj, api_base=api_base, data=data, headers=headers, model_response=model_response, model=model)
             else:
-                response = httpx.post(
+                response = self._client_session.post(
                     url=api_base,
                     json=data,
                     headers=headers,
+                    timeout=litellm.request_timeout
                 )
                 if response.status_code != 200:
                     raise OpenAIError(status_code=response.status_code, message=response.text, request=response.request, response=response)
@@ -262,15 +264,18 @@ class OpenAIChatCompletion(BaseLLM):
                          api_base: str,
                          data: dict, headers: dict,
                          model_response: ModelResponse):
+        kwargs = locals()
+        if self._aclient_session is None:
+            self._aclient_session = self.create_aclient_session()
+        client = self._aclient_session
         try:
-            async with httpx.AsyncClient() as client:
-                response = await client.post(api_base, json=data, headers=headers)
-                response_json = response.json()
-                if response.status_code != 200:
-                    raise OpenAIError(status_code=response.status_code, message=response.text, request=response.request, response=response)
-
-                ## RESPONSE OBJECT
-                return convert_to_model_response_object(response_object=response_json, model_response_object=model_response)
+            response = await client.post(api_base, json=data, headers=headers, timeout=litellm.request_timeout)
+            response_json = response.json()
+            if response.status_code != 200:
+                raise OpenAIError(status_code=response.status_code, message=response.text, request=response.request, response=response)
+
+            ## RESPONSE OBJECT
+            return convert_to_model_response_object(response_object=response_json, model_response_object=model_response)
         except Exception as e:
             if isinstance(e, httpx.TimeoutException):
                 raise OpenAIError(status_code=500, message="Request Timeout Error")
@@ -287,11 +292,14 @@ class OpenAIChatCompletion(BaseLLM):
                   model_response: ModelResponse,
                   model: str
    ):
-        with httpx.stream(
+        if self._client_session is None:
+            self._client_session = self.create_client_session()
+        with self._client_session.stream(
                     url=f"{api_base}", # type: ignore
                     json=data,
                     headers=headers,
-                    method="POST"
+                    method="POST",
+                    timeout=litellm.request_timeout
                 ) as response:
             if response.status_code != 200:
                 raise OpenAIError(status_code=response.status_code, message=response.text()) # type: ignore
@@ -308,12 +316,14 @@ class OpenAIChatCompletion(BaseLLM):
                           headers: dict,
                           model_response: ModelResponse,
                           model: str):
-        client = httpx.AsyncClient()
-        async with client.stream(
+        if self._aclient_session is None:
+            self._aclient_session = self.create_aclient_session()
+        async with self._aclient_session.stream(
                     url=f"{api_base}",
                     json=data,
                     headers=headers,
-                    method="POST"
+                    method="POST",
+                    timeout=litellm.request_timeout
                 ) as response:
             if response.status_code != 200:
                 raise OpenAIError(status_code=response.status_code, message=response.text()) # type: ignore
@@ -352,7 +362,7 @@ class OpenAIChatCompletion(BaseLLM):
             )
             ## COMPLETION CALL
             response = self._client_session.post(
-                api_base, headers=headers, json=data
+                api_base, headers=headers, json=data, timeout=litellm.request_timeout
             )
             ## LOGGING
             logging_obj.post_call(
@@ -483,6 +493,7 @@ class OpenAITextCompletion(BaseLLM):
                 url=f"{api_base}",
                 json=data,
                 headers=headers,
+                timeout=litellm.request_timeout
             )
             if response.status_code != 200:
                 raise OpenAIError(status_code=response.status_code, message=response.text)
@@ -513,7 +524,7 @@ class OpenAITextCompletion(BaseLLM):
                          api_key: str,
                          model: str):
         async with httpx.AsyncClient() as client:
-            response = await client.post(api_base, json=data, headers=headers)
+            response = await client.post(api_base, json=data, headers=headers, timeout=litellm.request_timeout)
             response_json = response.json()
             if response.status_code != 200:
                 raise OpenAIError(status_code=response.status_code, message=response.text)
@@ -544,7 +555,8 @@ class OpenAITextCompletion(BaseLLM):
                 url=f"{api_base}",
                 json=data,
                 headers=headers,
-                method="POST"
+                method="POST",
+                timeout=litellm.request_timeout
             ) as response:
             if response.status_code != 200:
                 raise OpenAIError(status_code=response.status_code, message=response.text)
@@ -565,7 +577,8 @@ class OpenAITextCompletion(BaseLLM):
                 url=f"{api_base}",
                 json=data,
                 headers=headers,
-                method="POST"
+                method="POST",
+                timeout=litellm.request_timeout
             ) as response:
             if response.status_code != 200:
                 raise OpenAIError(status_code=response.status_code, message=response.text)
diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index ec5c3ef5b2..72202be5c7 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -419,7 +419,7 @@ def test_completion_openai():
         pass
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
-# test_completion_openai()
+test_completion_openai()
 
 def test_completion_text_openai():
     try:
diff --git a/litellm/utils.py b/litellm/utils.py
index 3f1002155e..a79dc8e8fe 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -1188,7 +1188,7 @@ def client(original_function):
                 litellm.cache.add_cache(result, *args, **kwargs)
             # LOG SUCCESS - handle streaming success logging in the _next_ object, remove `handle_success` once it's deprecated
-            logging_obj.success_handler(result, start_time, end_time)
+            threading.Thread(target=logging_obj.success_handler, args=(result, start_time, end_time)).start()
             # threading.Thread(target=logging_obj.success_handler, args=(result, start_time, end_time)).start()
             my_thread = threading.Thread(
                 target=handle_success, args=(args, kwargs, result, start_time, end_time)
             )
@@ -1292,7 +1292,7 @@ def client(original_function):
                 litellm.cache.add_cache(result, *args, **kwargs)

             # LOG SUCCESS - handle streaming success logging in the _next_ object, remove `handle_success` once it's deprecated
-            logging_obj.success_handler(result, start_time, end_time)
+            threading.Thread(target=logging_obj.success_handler, args=(result, start_time, end_time)).start()
             # RETURN RESULT
             return result
         except Exception as e:
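
Reviewer note: a minimal sketch of the request-level timeout pattern this patch moves to. The shared httpx session is now created without a client-wide timeout (see BaseLLM.create_client_session), and litellm.request_timeout (defaulting to 600 seconds) is passed on each call instead, so changing the module-level value affects the very next request. The helper name below is illustrative only, not part of the patch.

import httpx
import litellm  # litellm.request_timeout now defaults to 600 seconds

# One shared session, mirroring BaseLLM.create_client_session(); note that
# it no longer bakes a timeout into the client itself.
client_session = httpx.Client()

def post_with_request_timeout(url: str, data: dict, headers: dict) -> httpx.Response:
    # The timeout is read at call time, so updating litellm.request_timeout
    # between calls changes the behavior of the next request.
    return client_session.post(
        url=url,
        json=data,
        headers=headers,
        timeout=litellm.request_timeout,  # per-request timeout, as in the diff
    )

# Example: a latency-sensitive caller can lower the module-level value first.
litellm.request_timeout = 30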