From a0c5fb81b8c449ddbab807aacac2040b51b6654c Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Tue, 18 Mar 2025 10:23:30 -0700 Subject: [PATCH] fix logic for initializing openai clients --- litellm/llms/azure/azure.py | 3 +++ litellm/llms/azure/common_utils.py | 4 ++-- litellm/llms/openai/common_utils.py | 29 ++++++++++++++++++++++++++++ litellm/llms/openai/openai.py | 30 ++++++----------------------- 4 files changed, 40 insertions(+), 26 deletions(-) diff --git a/litellm/llms/azure/azure.py b/litellm/llms/azure/azure.py index 6a16b50c31..94e3d4732a 100644 --- a/litellm/llms/azure/azure.py +++ b/litellm/llms/azure/azure.py @@ -234,6 +234,7 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM): api_base=api_base, model_name=model, api_version=api_version, + is_async=False, ) ### CHECK IF CLOUDFLARE AI GATEWAY ### ### if so - set the model as part of the base url @@ -749,6 +750,7 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM): model_name=model, api_version=api_version, api_base=api_base, + is_async=False, ) ## LOGGING logging_obj.pre_call( @@ -1152,6 +1154,7 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM): model_name=model or "", api_version=api_version, api_base=api_base, + is_async=False, ) if aimg_generation is True: return self.aimage_generation(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_key=api_key, client=client, azure_client_params=azure_client_params, timeout=timeout, headers=headers) # type: ignore diff --git a/litellm/llms/azure/common_utils.py b/litellm/llms/azure/common_utils.py index a7a1954e8f..fa41b07973 100644 --- a/litellm/llms/azure/common_utils.py +++ b/litellm/llms/azure/common_utils.py @@ -9,7 +9,7 @@ import litellm from litellm._logging import verbose_logger from litellm.caching.caching import DualCache from litellm.llms.base_llm.chat.transformation import BaseLLMException -from litellm.llms.openai.openai import OpenAIChatCompletion +from litellm.llms.openai.common_utils import BaseOpenAILLM from 
litellm.secret_managers.get_azure_ad_token_provider import ( get_azure_ad_token_provider, ) @@ -245,7 +245,7 @@ def select_azure_base_url_or_endpoint(azure_client_params: dict): return azure_client_params -class BaseAzureLLM(OpenAIChatCompletion): +class BaseAzureLLM(BaseOpenAILLM): def get_azure_openai_client( self, api_key: Optional[str], diff --git a/litellm/llms/openai/common_utils.py b/litellm/llms/openai/common_utils.py index a8412f867b..99a6c8837a 100644 --- a/litellm/llms/openai/common_utils.py +++ b/litellm/llms/openai/common_utils.py @@ -8,6 +8,7 @@ from typing import Any, Dict, List, Optional, Union import httpx import openai +import litellm from litellm.llms.base_llm.chat.transformation import BaseLLMException @@ -92,3 +93,31 @@ def drop_params_from_unprocessable_entity_error( new_data = {k: v for k, v in data.items() if k not in invalid_params} return new_data + + +class BaseOpenAILLM: + """ + Base class for OpenAI LLMs for getting their httpx clients and SSL verification settings + """ + + @staticmethod + def _get_async_http_client() -> Optional[httpx.AsyncClient]: + if litellm.ssl_verify: + return httpx.AsyncClient( + limits=httpx.Limits( + max_connections=1000, max_keepalive_connections=100 + ), + verify=litellm.ssl_verify, + ) + return litellm.aclient_session + + @staticmethod + def _get_sync_http_client() -> Optional[httpx.Client]: + if litellm.ssl_verify: + return httpx.Client( + limits=httpx.Limits( + max_connections=1000, max_keepalive_connections=100 + ), + verify=litellm.ssl_verify, + ) + return litellm.client_session diff --git a/litellm/llms/openai/openai.py b/litellm/llms/openai/openai.py index a5e33795c8..8045b5ef32 100644 --- a/litellm/llms/openai/openai.py +++ b/litellm/llms/openai/openai.py @@ -50,7 +50,11 @@ from litellm.utils import ( from ...types.llms.openai import * from ..base import BaseLLM from .chat.o_series_transformation import OpenAIOSeriesConfig -from .common_utils import OpenAIError, 
drop_params_from_unprocessable_entity_error +from .common_utils import ( + BaseOpenAILLM, + OpenAIError, + drop_params_from_unprocessable_entity_error, +) openaiOSeriesConfig = OpenAIOSeriesConfig() @@ -317,7 +321,7 @@ class OpenAIChatCompletionResponseIterator(BaseModelResponseIterator): raise e -class OpenAIChatCompletion(BaseLLM): +class OpenAIChatCompletion(BaseLLM, BaseOpenAILLM): def __init__(self) -> None: super().__init__() @@ -401,28 +405,6 @@ class OpenAIChatCompletion(BaseLLM): ) return client - @staticmethod - def _get_async_http_client() -> Optional[httpx.AsyncClient]: - if litellm.ssl_verify: - return httpx.AsyncClient( - limits=httpx.Limits( - max_connections=1000, max_keepalive_connections=100 - ), - verify=litellm.ssl_verify, - ) - return litellm.aclient_session - - @staticmethod - def _get_sync_http_client() -> Optional[httpx.Client]: - if litellm.ssl_verify: - return httpx.Client( - limits=httpx.Limits( - max_connections=1000, max_keepalive_connections=100 - ), - verify=litellm.ssl_verify, - ) - return litellm.client_session - @track_llm_api_timing() async def make_openai_chat_completion_request( self,