fix logic for initializing openai clients

Ishaan Jaff 2025-03-18 10:23:30 -07:00
parent 0601768bb8
commit a0c5fb81b8
4 changed files with 40 additions and 26 deletions

litellm/llms/azure/azure.py

@@ -234,6 +234,7 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
             api_base=api_base,
             model_name=model,
             api_version=api_version,
+            is_async=False,
         )
         ### CHECK IF CLOUDFLARE AI GATEWAY ###
         ### if so - set the model as part of the base url

@@ -749,6 +750,7 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
             model_name=model,
             api_version=api_version,
             api_base=api_base,
+            is_async=False,
         )
         ## LOGGING
         logging_obj.pre_call(

@@ -1152,6 +1154,7 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
             model_name=model or "",
             api_version=api_version,
             api_base=api_base,
+            is_async=False,
         )
         if aimg_generation is True:
             return self.aimage_generation(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_key=api_key, client=client, azure_client_params=azure_client_params, timeout=timeout, headers=headers)  # type: ignore
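The three hunks above all thread an explicit is_async=False through the Azure client-initialization call (the callee itself is not shown in this diff). As a hypothetical illustration of why such a flag matters, here is a sketch with assumed names (example_init_azure_client is not litellm's API): the flag selects between the sync and async SDK clients, so a synchronous code path never ends up holding an AsyncAzureOpenAI.

from typing import Union

import openai


def example_init_azure_client(
    api_key: str,
    api_base: str,
    api_version: str,
    is_async: bool,
) -> Union[openai.AzureOpenAI, openai.AsyncAzureOpenAI]:
    # Hypothetical sketch: is_async picks the matching SDK client class,
    # mirroring the explicit is_async=False now passed at the sync call sites.
    if is_async:
        return openai.AsyncAzureOpenAI(
            api_key=api_key, azure_endpoint=api_base, api_version=api_version
        )
    return openai.AzureOpenAI(
        api_key=api_key, azure_endpoint=api_base, api_version=api_version
    )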

litellm/llms/azure/common_utils.py

@@ -9,7 +9,7 @@ import litellm
 from litellm._logging import verbose_logger
 from litellm.caching.caching import DualCache
 from litellm.llms.base_llm.chat.transformation import BaseLLMException
-from litellm.llms.openai.openai import OpenAIChatCompletion
+from litellm.llms.openai.common_utils import BaseOpenAILLM
 from litellm.secret_managers.get_azure_ad_token_provider import (
     get_azure_ad_token_provider,
 )

@@ -245,7 +245,7 @@ def select_azure_base_url_or_endpoint(azure_client_params: dict):
     return azure_client_params


-class BaseAzureLLM(OpenAIChatCompletion):
+class BaseAzureLLM(BaseOpenAILLM):
     def get_azure_openai_client(
         self,
         api_key: Optional[str],
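For context on the re-parenting above, a minimal sketch (assumed wiring, not litellm's actual get_azure_openai_client body; the subclass and method names are hypothetical) of what BaseAzureLLM gains from BaseOpenAILLM: the inherited httpx helpers can be handed straight to the openai SDK's Azure client.

import openai

from litellm.llms.azure.common_utils import BaseAzureLLM


class ExampleAzureChat(BaseAzureLLM):  # hypothetical subclass for illustration
    def example_client(self, api_key: str, api_base: str, api_version: str):
        # _get_sync_http_client() is inherited from BaseOpenAILLM, so the
        # Azure and OpenAI paths now share one pooling/SSL configuration.
        return openai.AzureOpenAI(
            api_key=api_key,
            azure_endpoint=api_base,
            api_version=api_version,
            http_client=self._get_sync_http_client(),
        )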

litellm/llms/openai/common_utils.py

@@ -8,6 +8,7 @@ from typing import Any, Dict, List, Optional, Union
 import httpx
 import openai

+import litellm
 from litellm.llms.base_llm.chat.transformation import BaseLLMException

@@ -92,3 +93,31 @@ def drop_params_from_unprocessable_entity_error(
     new_data = {k: v for k, v in data.items() if k not in invalid_params}
     return new_data
+
+
+class BaseOpenAILLM:
+    """
+    Base class for OpenAI LLMs for getting their httpx clients and SSL verification settings
+    """
+
+    @staticmethod
+    def _get_async_http_client() -> Optional[httpx.AsyncClient]:
+        if litellm.ssl_verify:
+            return httpx.AsyncClient(
+                limits=httpx.Limits(
+                    max_connections=1000, max_keepalive_connections=100
+                ),
+                verify=litellm.ssl_verify,
+            )
+        return litellm.aclient_session
+
+    @staticmethod
+    def _get_sync_http_client() -> Optional[httpx.Client]:
+        if litellm.ssl_verify:
+            return httpx.Client(
+                limits=httpx.Limits(
+                    max_connections=1000, max_keepalive_connections=100
+                ),
+                verify=litellm.ssl_verify,
+            )
+        return litellm.client_session
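A quick usage sketch of the new helpers (hypothetical call site; the literal key is a placeholder). Note the fallback visible in the code above: when litellm.ssl_verify is truthy a fresh pooled client is built, otherwise the user-configured litellm.client_session / litellm.aclient_session (which may be None) is returned, and the openai SDK builds its own default transport when http_client is None.

import openai

from litellm.llms.openai.common_utils import BaseOpenAILLM

# Sync: attach the shared pooled client (or fall back to the SDK default).
sync_client = openai.OpenAI(
    api_key="sk-placeholder",
    http_client=BaseOpenAILLM._get_sync_http_client(),
)

# Async: same idea with the async helper.
async_client = openai.AsyncOpenAI(
    api_key="sk-placeholder",
    http_client=BaseOpenAILLM._get_async_http_client(),
)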

litellm/llms/openai/openai.py

@@ -50,7 +50,11 @@ from litellm.utils import (
 from ...types.llms.openai import *
 from ..base import BaseLLM
 from .chat.o_series_transformation import OpenAIOSeriesConfig
-from .common_utils import OpenAIError, drop_params_from_unprocessable_entity_error
+from .common_utils import (
+    BaseOpenAILLM,
+    OpenAIError,
+    drop_params_from_unprocessable_entity_error,
+)

 openaiOSeriesConfig = OpenAIOSeriesConfig()

@@ -317,7 +321,7 @@ class OpenAIChatCompletionResponseIterator(BaseModelResponseIterator):
         raise e


-class OpenAIChatCompletion(BaseLLM):
+class OpenAIChatCompletion(BaseLLM, BaseOpenAILLM):
     def __init__(self) -> None:
         super().__init__()

@@ -401,28 +405,6 @@ class OpenAIChatCompletion(BaseLLM):
         )
         return client

-    @staticmethod
-    def _get_async_http_client() -> Optional[httpx.AsyncClient]:
-        if litellm.ssl_verify:
-            return httpx.AsyncClient(
-                limits=httpx.Limits(
-                    max_connections=1000, max_keepalive_connections=100
-                ),
-                verify=litellm.ssl_verify,
-            )
-        return litellm.aclient_session
-
-    @staticmethod
-    def _get_sync_http_client() -> Optional[httpx.Client]:
-        if litellm.ssl_verify:
-            return httpx.Client(
-                limits=httpx.Limits(
-                    max_connections=1000, max_keepalive_connections=100
-                ),
-                verify=litellm.ssl_verify,
-            )
-        return litellm.client_session
-
     @track_llm_api_timing()
     async def make_openai_chat_completion_request(
         self,