diff --git a/litellm/llms/azure.py b/litellm/llms/azure.py
index 9379f5042..c7613017e 100644
--- a/litellm/llms/azure.py
+++ b/litellm/llms/azure.py
@@ -247,7 +247,7 @@ class AzureChatCompletion(BaseLLM):
                     azure_client = AzureOpenAI(**azure_client_params)
                 else:
                     azure_client = client
-                response = azure_client.chat.completions.create(**data)  # type: ignore
+                response = azure_client.chat.completions.create(**data, timeout=timeout)  # type: ignore
                 stringified_response = response.model_dump_json()
                 ## LOGGING
                 logging_obj.post_call(
@@ -290,6 +290,7 @@ class AzureChatCompletion(BaseLLM):
                 raise AzureOpenAIError(
                     status_code=422, message="max retries must be an int"
                 )
+
             # init AzureOpenAI Client
             azure_client_params = {
                 "api_version": api_version,
@@ -318,7 +319,9 @@ class AzureChatCompletion(BaseLLM):
                     "complete_input_dict": data,
                 },
             )
-            response = await azure_client.chat.completions.create(**data)
+            response = await azure_client.chat.completions.create(
+                **data, timeout=timeout
+            )
             return convert_to_model_response_object(
                 response_object=json.loads(response.model_dump_json()),
                 model_response_object=model_response,
@@ -377,7 +380,7 @@ class AzureChatCompletion(BaseLLM):
                 "complete_input_dict": data,
             },
         )
-        response = azure_client.chat.completions.create(**data)
+        response = azure_client.chat.completions.create(**data, timeout=timeout)
         streamwrapper = CustomStreamWrapper(
             completion_stream=response,
             model=model,
@@ -427,7 +430,9 @@ class AzureChatCompletion(BaseLLM):
                 "complete_input_dict": data,
             },
         )
-        response = await azure_client.chat.completions.create(**data)
+        response = await azure_client.chat.completions.create(
+            **data, timeout=timeout
+        )
         # return response
         streamwrapper = CustomStreamWrapper(
             completion_stream=response,
@@ -451,6 +456,7 @@ class AzureChatCompletion(BaseLLM):
         input: list,
         client=None,
         logging_obj=None,
+        timeout=None,
     ):
         response = None
         try:
@@ -458,7 +464,7 @@ class AzureChatCompletion(BaseLLM):
                 openai_aclient = AsyncAzureOpenAI(**azure_client_params)
             else:
                 openai_aclient = client
-            response = await openai_aclient.embeddings.create(**data)
+            response = await openai_aclient.embeddings.create(**data, timeout=timeout)
             stringified_response = response.model_dump_json()
             ## LOGGING
             logging_obj.post_call(
@@ -541,6 +547,7 @@ class AzureChatCompletion(BaseLLM):
                 api_key=api_key,
                 model_response=model_response,
                 azure_client_params=azure_client_params,
+                timeout=timeout,
             )
             return response
         if client is None:
@@ -548,7 +555,7 @@ class AzureChatCompletion(BaseLLM):
         else:
             azure_client = client
         ## COMPLETION CALL
-        response = azure_client.embeddings.create(**data)  # type: ignore
+        response = azure_client.embeddings.create(**data, timeout=timeout)  # type: ignore
         ## LOGGING
         logging_obj.post_call(
             input=input,
@@ -578,6 +585,7 @@ class AzureChatCompletion(BaseLLM):
         input: list,
         client=None,
         logging_obj=None,
+        timeout=None,
     ):
         response = None
         try:
@@ -590,7 +598,7 @@ class AzureChatCompletion(BaseLLM):
                 )
             else:
                 openai_aclient = client
-            response = await openai_aclient.images.generate(**data)
+            response = await openai_aclient.images.generate(**data, timeout=timeout)
             stringified_response = response.model_dump_json()
             ## LOGGING
             logging_obj.post_call(
@@ -656,7 +664,7 @@ class AzureChatCompletion(BaseLLM):
             azure_client_params["azure_ad_token"] = azure_ad_token
 
         if aimg_generation == True:
-            response = self.aimage_generation(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_key=api_key, client=client, azure_client_params=azure_client_params)  # type: ignore
+            response = self.aimage_generation(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_key=api_key, client=client, azure_client_params=azure_client_params, timeout=timeout)  # type: ignore
             return response
 
         if client is None:
@@ -680,7 +688,7 @@ class AzureChatCompletion(BaseLLM):
             )
 
         ## COMPLETION CALL
-        response = azure_client.images.generate(**data)  # type: ignore
+        response = azure_client.images.generate(**data, timeout=timeout)  # type: ignore
         ## LOGGING
         logging_obj.post_call(
             input=input,
diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py
index cc2a3889a..bbdbd591d 100644
--- a/litellm/llms/openai.py
+++ b/litellm/llms/openai.py
@@ -306,7 +306,7 @@ class OpenAIChatCompletion(BaseLLM):
                     )
                 else:
                     openai_client = client
-                response = openai_client.chat.completions.create(**data)  # type: ignore
+                response = openai_client.chat.completions.create(**data, timeout=timeout)  # type: ignore
                 stringified_response = response.model_dump_json()
                 logging_obj.post_call(
                     input=messages,
@@ -383,7 +383,9 @@ class OpenAIChatCompletion(BaseLLM):
                 },
             )
 
-            response = await openai_aclient.chat.completions.create(**data)
+            response = await openai_aclient.chat.completions.create(
+                **data, timeout=timeout
+            )
             stringified_response = response.model_dump_json()
             logging_obj.post_call(
                 input=data["messages"],
@@ -431,7 +433,7 @@ class OpenAIChatCompletion(BaseLLM):
                 "complete_input_dict": data,
             },
         )
-        response = openai_client.chat.completions.create(**data)
+        response = openai_client.chat.completions.create(**data, timeout=timeout)
         streamwrapper = CustomStreamWrapper(
             completion_stream=response,
             model=model,
@@ -476,7 +478,9 @@ class OpenAIChatCompletion(BaseLLM):
                 },
             )
 
-            response = await openai_aclient.chat.completions.create(**data)
+            response = await openai_aclient.chat.completions.create(
+                **data, timeout=timeout
+            )
             streamwrapper = CustomStreamWrapper(
                 completion_stream=response,
                 model=model,
@@ -522,7 +526,7 @@ class OpenAIChatCompletion(BaseLLM):
             )
         else:
             openai_aclient = client
-        response = await openai_aclient.embeddings.create(**data)  # type: ignore
+        response = await openai_aclient.embeddings.create(**data, timeout=timeout)  # type: ignore
         stringified_response = response.model_dump_json()
         ## LOGGING
         logging_obj.post_call(
@@ -584,7 +588,7 @@ class OpenAIChatCompletion(BaseLLM):
             openai_client = client
 
         ## COMPLETION CALL
-        response = openai_client.embeddings.create(**data)  # type: ignore
+        response = openai_client.embeddings.create(**data, timeout=timeout)  # type: ignore
         ## LOGGING
         logging_obj.post_call(
             input=input,
@@ -629,7 +633,7 @@ class OpenAIChatCompletion(BaseLLM):
             )
         else:
             openai_aclient = client
-        response = await openai_aclient.images.generate(**data)  # type: ignore
+        response = await openai_aclient.images.generate(**data, timeout=timeout)  # type: ignore
         stringified_response = response.model_dump_json()
         ## LOGGING
         logging_obj.post_call(
@@ -669,9 +673,9 @@ class OpenAIChatCompletion(BaseLLM):
             if not isinstance(max_retries, int):
                 raise OpenAIError(status_code=422, message="max retries must be an int")
 
-            # if aembedding == True:
-            #     response = self.aembedding(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries)  # type: ignore
-            #     return response
+            if aimg_generation == True:
+                response = self.aimage_generation(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries)  # type: ignore
+                return response
 
             if client is None:
                 openai_client = OpenAI(
@@ -697,7 +701,7 @@ class OpenAIChatCompletion(BaseLLM):
             )
 
         ## COMPLETION CALL
-        response = openai_client.images.generate(**data)  # type: ignore
+        response = openai_client.images.generate(**data, timeout=timeout)  # type: ignore
         ## LOGGING
         logging_obj.post_call(
             input=input,
diff --git a/litellm/tests/test_timeout.py b/litellm/tests/test_timeout.py
index ae495c056..49eeff57c 100644
--- a/litellm/tests/test_timeout.py
+++ b/litellm/tests/test_timeout.py
@@ -10,7 +10,7 @@ sys.path.insert(
 import time
 import litellm
 import openai
-import pytest
+import pytest, uuid
 
 
 def test_timeout():
@@ -60,7 +60,7 @@ def test_hanging_request_azure():
     encoded = litellm.utils.encode(model="gpt-3.5-turbo", text="blue")[0]
     response = router.completion(
         model="azure-gpt",
-        messages=[{"role": "user", "content": "what color is red"}],
+        messages=[{"role": "user", "content": f"what color is red {uuid.uuid4()}"}],
         logit_bias={encoded: 100},
         timeout=0.01,
     )
@@ -126,7 +126,7 @@ def test_hanging_request_openai():
     )
 
 
-test_hanging_request_openai()
+# test_hanging_request_openai()
 
 
 # test_timeout()
@@ -155,4 +155,4 @@ def test_timeout_streaming():
     )
 
 
-test_timeout_streaming()
+# test_timeout_streaming()