Mirror of https://github.com/BerriAI/litellm.git

fix logic for initializing openai clients

commit 1ede6080ef (parent: 7bb90a15be)
4 changed files with 40 additions and 26 deletions
@@ -234,6 +234,7 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
             api_base=api_base,
             model_name=model,
             api_version=api_version,
+            is_async=False,
         )
         ### CHECK IF CLOUDFLARE AI GATEWAY ###
         ### if so - set the model as part of the base url
@@ -749,6 +750,7 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
             model_name=model,
             api_version=api_version,
             api_base=api_base,
+            is_async=False,
         )
         ## LOGGING
         logging_obj.pre_call(
@@ -1152,6 +1154,7 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
             model_name=model or "",
             api_version=api_version,
             api_base=api_base,
+            is_async=False,
         )
         if aimg_generation is True:
             return self.aimage_generation(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_key=api_key, client=client, azure_client_params=azure_client_params, timeout=timeout, headers=headers)  # type: ignore
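All three hunks above pass an explicit is_async=False when the synchronous Azure client is built (chat completion, the second completion call site, and image generation). A minimal sketch of the kind of client-initialization helper these keyword arguments presumably feed; the helper name and exact signature are illustrative assumptions, not copied from litellm:

    from typing import Optional, Union

    from openai import AsyncAzureOpenAI, AzureOpenAI

    def init_azure_client_sketch(
        api_key: Optional[str],
        api_base: Optional[str],
        api_version: Optional[str],
        is_async: bool = False,
    ) -> Union[AzureOpenAI, AsyncAzureOpenAI]:
        # Illustrative: choose the client class up front. Passing is_async
        # explicitly, as the hunks above now do, keeps a sync code path from
        # ever being handed an AsyncAzureOpenAI instance.
        client_cls = AsyncAzureOpenAI if is_async else AzureOpenAI
        return client_cls(
            api_key=api_key,
            azure_endpoint=api_base or "",
            api_version=api_version,
        )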
@@ -9,7 +9,7 @@ import litellm
 from litellm._logging import verbose_logger
 from litellm.caching.caching import DualCache
 from litellm.llms.base_llm.chat.transformation import BaseLLMException
-from litellm.llms.openai.openai import OpenAIChatCompletion
+from litellm.llms.openai.common_utils import BaseOpenAILLM
 from litellm.secret_managers.get_azure_ad_token_provider import (
     get_azure_ad_token_provider,
 )
@@ -245,7 +245,7 @@ def select_azure_base_url_or_endpoint(azure_client_params: dict):
     return azure_client_params


-class BaseAzureLLM(OpenAIChatCompletion):
+class BaseAzureLLM(BaseOpenAILLM):
     def get_azure_openai_client(
         self,
         api_key: Optional[str],
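The import swap and base-class change are the heart of this commit: BaseAzureLLM previously inherited the entire OpenAIChatCompletion, and now inherits only the slim BaseOpenAILLM mixin from litellm.llms.openai.common_utils (added in the next file). A docstring-only stub of the resulting hierarchy, for orientation; real signatures omitted:

    class BaseOpenAILLM:
        """Slim shared base: httpx client construction and SSL settings."""

    class BaseAzureLLM(BaseOpenAILLM):
        """Azure client setup; no longer drags in the whole OpenAI chat class."""

    class OpenAIChatCompletion(BaseOpenAILLM):  # in litellm: (BaseLLM, BaseOpenAILLM)
        """OpenAI chat handler; shares the httpx helpers via the mixin."""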
@@ -8,6 +8,7 @@ from typing import Any, Dict, List, Optional, Union
 import httpx
 import openai

+import litellm
 from litellm.llms.base_llm.chat.transformation import BaseLLMException


@@ -92,3 +93,31 @@ def drop_params_from_unprocessable_entity_error(
     new_data = {k: v for k, v in data.items() if k not in invalid_params}

     return new_data
+
+
+class BaseOpenAILLM:
+    """
+    Base class for OpenAI LLMs for getting their httpx clients and SSL verification settings
+    """
+
+    @staticmethod
+    def _get_async_http_client() -> Optional[httpx.AsyncClient]:
+        if litellm.ssl_verify:
+            return httpx.AsyncClient(
+                limits=httpx.Limits(
+                    max_connections=1000, max_keepalive_connections=100
+                ),
+                verify=litellm.ssl_verify,
+            )
+        return litellm.aclient_session
+
+    @staticmethod
+    def _get_sync_http_client() -> Optional[httpx.Client]:
+        if litellm.ssl_verify:
+            return httpx.Client(
+                limits=httpx.Limits(
+                    max_connections=1000, max_keepalive_connections=100
+                ),
+                verify=litellm.ssl_verify,
+            )
+        return litellm.client_session
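Note the guard in both helpers: when litellm.ssl_verify is truthy, a fresh pooled httpx client is built with that verify value (httpx accepts True or a CA-bundle path here); when it is falsy, the helpers fall back to the globally configured litellm.client_session / litellm.aclient_session, which may themselves be None. A hedged sketch of how a caller would typically consume the sync helper; this wiring is illustrative, not taken from the commit:

    import openai

    # Illustrative: reuse the shared pool and SSL settings when building a client.
    # _get_sync_http_client() may return None, in which case the openai SDK
    # constructs its own default httpx.Client.
    http_client = BaseOpenAILLM._get_sync_http_client()
    client = openai.OpenAI(
        api_key="sk-...",  # placeholder credential
        http_client=http_client,
    )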
@@ -50,7 +50,11 @@ from litellm.utils import (
 from ...types.llms.openai import *
 from ..base import BaseLLM
 from .chat.o_series_transformation import OpenAIOSeriesConfig
-from .common_utils import OpenAIError, drop_params_from_unprocessable_entity_error
+from .common_utils import (
+    BaseOpenAILLM,
+    OpenAIError,
+    drop_params_from_unprocessable_entity_error,
+)

 openaiOSeriesConfig = OpenAIOSeriesConfig()

@@ -317,7 +321,7 @@ class OpenAIChatCompletionResponseIterator(BaseModelResponseIterator):
         raise e


-class OpenAIChatCompletion(BaseLLM):
+class OpenAIChatCompletion(BaseLLM, BaseOpenAILLM):

     def __init__(self) -> None:
         super().__init__()
@@ -401,28 +405,6 @@ class OpenAIChatCompletion(BaseLLM):
         )
         return client

-    @staticmethod
-    def _get_async_http_client() -> Optional[httpx.AsyncClient]:
-        if litellm.ssl_verify:
-            return httpx.AsyncClient(
-                limits=httpx.Limits(
-                    max_connections=1000, max_keepalive_connections=100
-                ),
-                verify=litellm.ssl_verify,
-            )
-        return litellm.aclient_session
-
-    @staticmethod
-    def _get_sync_http_client() -> Optional[httpx.Client]:
-        if litellm.ssl_verify:
-            return httpx.Client(
-                limits=httpx.Limits(
-                    max_connections=1000, max_keepalive_connections=100
-                ),
-                verify=litellm.ssl_verify,
-            )
-        return litellm.client_session
-
     @track_llm_api_timing()
     async def make_openai_chat_completion_request(
         self,
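Net effect of the final hunk: the two static helpers are deleted from OpenAIChatCompletion and inherited from BaseOpenAILLM instead, so existing call sites inside the class keep working unchanged. A quick sanity check one could run against this revision, assuming the imports resolve as shown in the diff:

    from litellm.llms.openai.common_utils import BaseOpenAILLM
    from litellm.llms.openai.openai import OpenAIChatCompletion

    # The attributes should now be the inherited mixin functions, not redefinitions.
    assert OpenAIChatCompletion._get_sync_http_client is BaseOpenAILLM._get_sync_http_client
    assert OpenAIChatCompletion._get_async_http_client is BaseOpenAILLM._get_async_http_client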