forked from phoenix/litellm-mirror
(refactor) router - use static methods for client init utils (#6420)
* use InitalizeOpenAISDKClient * use InitalizeOpenAISDKClient static method * fix # noqa: PLR0915
This commit is contained in:
parent
cdda7c243f
commit
17e81d861c
2 changed files with 448 additions and 426 deletions
|
@ -63,10 +63,7 @@ from litellm.router_utils.batch_utils import (
|
|||
_get_router_metadata_variable_name,
|
||||
replace_model_in_jsonl,
|
||||
)
|
||||
from litellm.router_utils.client_initalization_utils import (
|
||||
set_client,
|
||||
should_initialize_sync_client,
|
||||
)
|
||||
from litellm.router_utils.client_initalization_utils import InitalizeOpenAISDKClient
|
||||
from litellm.router_utils.cooldown_cache import CooldownCache
|
||||
from litellm.router_utils.cooldown_callbacks import router_cooldown_event_callback
|
||||
from litellm.router_utils.cooldown_handlers import (
|
||||
|
@ -3951,7 +3948,7 @@ class Router:
|
|||
raise Exception(f"Unsupported provider - {custom_llm_provider}")
|
||||
|
||||
# init OpenAI, Azure clients
|
||||
set_client(
|
||||
InitalizeOpenAISDKClient.set_client(
|
||||
litellm_router_instance=self, model=deployment.to_json(exclude_none=True)
|
||||
)
|
||||
|
||||
|
@ -4661,7 +4658,9 @@ class Router:
|
|||
"""
|
||||
Re-initialize the client
|
||||
"""
|
||||
set_client(litellm_router_instance=self, model=deployment)
|
||||
InitalizeOpenAISDKClient.set_client(
|
||||
litellm_router_instance=self, model=deployment
|
||||
)
|
||||
client = self.cache.get_cache(key=cache_key, local_only=True)
|
||||
return client
|
||||
else:
|
||||
|
@ -4671,7 +4670,9 @@ class Router:
|
|||
"""
|
||||
Re-initialize the client
|
||||
"""
|
||||
set_client(litellm_router_instance=self, model=deployment)
|
||||
InitalizeOpenAISDKClient.set_client(
|
||||
litellm_router_instance=self, model=deployment
|
||||
)
|
||||
client = self.cache.get_cache(key=cache_key, local_only=True)
|
||||
return client
|
||||
else:
|
||||
|
@ -4682,7 +4683,9 @@ class Router:
|
|||
"""
|
||||
Re-initialize the client
|
||||
"""
|
||||
set_client(litellm_router_instance=self, model=deployment)
|
||||
InitalizeOpenAISDKClient.set_client(
|
||||
litellm_router_instance=self, model=deployment
|
||||
)
|
||||
client = self.cache.get_cache(key=cache_key)
|
||||
return client
|
||||
else:
|
||||
|
@ -4692,7 +4695,9 @@ class Router:
|
|||
"""
|
||||
Re-initialize the client
|
||||
"""
|
||||
set_client(litellm_router_instance=self, model=deployment)
|
||||
InitalizeOpenAISDKClient.set_client(
|
||||
litellm_router_instance=self, model=deployment
|
||||
)
|
||||
client = self.cache.get_cache(key=cache_key)
|
||||
return client
|
||||
|
||||
|
|
|
@ -23,6 +23,8 @@ else:
|
|||
LitellmRouter = Any
|
||||
|
||||
|
||||
class InitalizeOpenAISDKClient:
|
||||
@staticmethod
|
||||
def should_initialize_sync_client(
|
||||
litellm_router_instance: LitellmRouter,
|
||||
) -> bool:
|
||||
|
@ -41,14 +43,17 @@ def should_initialize_sync_client(
|
|||
and hasattr(
|
||||
litellm_router_instance.router_general_settings, "async_only_mode"
|
||||
)
|
||||
and litellm_router_instance.router_general_settings.async_only_mode is True
|
||||
and litellm_router_instance.router_general_settings.async_only_mode
|
||||
is True
|
||||
):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def set_client(litellm_router_instance: LitellmRouter, model: dict): # noqa: PLR0915
|
||||
@staticmethod
|
||||
def set_client( # noqa: PLR0915
|
||||
litellm_router_instance: LitellmRouter, model: dict
|
||||
):
|
||||
"""
|
||||
- Initializes Azure/OpenAI clients. Stores them in cache, b/c of this - https://github.com/BerriAI/litellm/issues/1278
|
||||
- Initializes Semaphore for client w/ rpm. Stores them in cache. b/c of this - https://github.com/BerriAI/litellm/issues/2994
|
||||
|
@ -110,7 +115,11 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict): # noqa: PL
|
|||
# user can pass vars directly or they can pas os.environ/AZURE_API_KEY, in which case we will read the env
|
||||
# we do this here because we init clients for Azure, OpenAI and we need to set the right key
|
||||
api_key = litellm_params.get("api_key") or default_api_key
|
||||
if api_key and isinstance(api_key, str) and api_key.startswith("os.environ/"):
|
||||
if (
|
||||
api_key
|
||||
and isinstance(api_key, str)
|
||||
and api_key.startswith("os.environ/")
|
||||
):
|
||||
api_key_env_name = api_key.replace("os.environ/", "")
|
||||
api_key = get_secret_str(api_key_env_name)
|
||||
litellm_params["api_key"] = api_key
|
||||
|
@ -162,7 +171,9 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict): # noqa: PL
|
|||
stream_timeout: Optional[float] = litellm_params.pop(
|
||||
"stream_timeout", timeout
|
||||
) # if no stream_timeout is set, default to timeout
|
||||
if isinstance(stream_timeout, str) and stream_timeout.startswith("os.environ/"):
|
||||
if isinstance(stream_timeout, str) and stream_timeout.startswith(
|
||||
"os.environ/"
|
||||
):
|
||||
stream_timeout_env_name = stream_timeout.replace("os.environ/", "")
|
||||
stream_timeout = get_secret(stream_timeout_env_name) # type: ignore
|
||||
litellm_params["stream_timeout"] = stream_timeout
|
||||
|
@ -182,17 +193,23 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict): # noqa: PL
|
|||
litellm_params["organization"] = organization
|
||||
azure_ad_token_provider: Optional[Callable[[], str]] = None
|
||||
if litellm_params.get("tenant_id"):
|
||||
verbose_router_logger.debug("Using Azure AD Token Provider for Azure Auth")
|
||||
azure_ad_token_provider = get_azure_ad_token_from_entrata_id(
|
||||
verbose_router_logger.debug(
|
||||
"Using Azure AD Token Provider for Azure Auth"
|
||||
)
|
||||
azure_ad_token_provider = (
|
||||
InitalizeOpenAISDKClient.get_azure_ad_token_from_entrata_id(
|
||||
tenant_id=litellm_params.get("tenant_id"),
|
||||
client_id=litellm_params.get("client_id"),
|
||||
client_secret=litellm_params.get("client_secret"),
|
||||
)
|
||||
)
|
||||
|
||||
if custom_llm_provider == "azure" or custom_llm_provider == "azure_text":
|
||||
if api_base is None or not isinstance(api_base, str):
|
||||
filtered_litellm_params = {
|
||||
k: v for k, v in model["litellm_params"].items() if k != "api_key"
|
||||
k: v
|
||||
for k, v in model["litellm_params"].items()
|
||||
if k != "api_key"
|
||||
}
|
||||
_filtered_model = {
|
||||
"model_name": model["model_name"],
|
||||
|
@ -248,7 +265,7 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict): # noqa: PL
|
|||
local_only=True,
|
||||
) # cache for 1 hr
|
||||
|
||||
if should_initialize_sync_client(
|
||||
if InitalizeOpenAISDKClient.should_initialize_sync_client(
|
||||
litellm_router_instance=litellm_router_instance
|
||||
):
|
||||
cache_key = f"{model_id}_client"
|
||||
|
@ -297,7 +314,7 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict): # noqa: PL
|
|||
local_only=True,
|
||||
) # cache for 1 hr
|
||||
|
||||
if should_initialize_sync_client(
|
||||
if InitalizeOpenAISDKClient.should_initialize_sync_client(
|
||||
litellm_router_instance=litellm_router_instance
|
||||
):
|
||||
cache_key = f"{model_id}_stream_client"
|
||||
|
@ -370,7 +387,7 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict): # noqa: PL
|
|||
ttl=client_ttl,
|
||||
local_only=True,
|
||||
) # cache for 1 hr
|
||||
if should_initialize_sync_client(
|
||||
if InitalizeOpenAISDKClient.should_initialize_sync_client(
|
||||
litellm_router_instance=litellm_router_instance
|
||||
):
|
||||
cache_key = f"{model_id}_client"
|
||||
|
@ -412,7 +429,7 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict): # noqa: PL
|
|||
local_only=True,
|
||||
) # cache for 1 hr
|
||||
|
||||
if should_initialize_sync_client(
|
||||
if InitalizeOpenAISDKClient.should_initialize_sync_client(
|
||||
litellm_router_instance=litellm_router_instance
|
||||
):
|
||||
cache_key = f"{model_id}_stream_client"
|
||||
|
@ -463,7 +480,7 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict): # noqa: PL
|
|||
local_only=True,
|
||||
) # cache for 1 hr
|
||||
|
||||
if should_initialize_sync_client(
|
||||
if InitalizeOpenAISDKClient.should_initialize_sync_client(
|
||||
litellm_router_instance=litellm_router_instance
|
||||
):
|
||||
cache_key = f"{model_id}_client"
|
||||
|
@ -509,7 +526,7 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict): # noqa: PL
|
|||
local_only=True,
|
||||
) # cache for 1 hr
|
||||
|
||||
if should_initialize_sync_client(
|
||||
if InitalizeOpenAISDKClient.should_initialize_sync_client(
|
||||
litellm_router_instance=litellm_router_instance
|
||||
):
|
||||
# streaming clients should have diff timeouts
|
||||
|
@ -534,7 +551,7 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict): # noqa: PL
|
|||
local_only=True,
|
||||
) # cache for 1 hr
|
||||
|
||||
|
||||
@staticmethod
|
||||
def get_azure_ad_token_from_entrata_id(
|
||||
tenant_id: str, client_id: str, client_secret: str
|
||||
) -> Callable[[], str]:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue