(refactor) router - use static methods for client init utils (#6420)

* use InitalizeOpenAISDKClient

* use InitalizeOpenAISDKClient static method

* fix placement of `# noqa: PLR0915`
This commit is contained in:
Ishaan Jaff 2024-10-24 19:26:46 +04:00 committed by GitHub
parent cdda7c243f
commit 17e81d861c
No known key found for this signature in the database.
GPG key ID: B5690EEEBB952194
2 changed files with 448 additions and 426 deletions

View file

@ -63,10 +63,7 @@ from litellm.router_utils.batch_utils import (
_get_router_metadata_variable_name,
replace_model_in_jsonl,
)
from litellm.router_utils.client_initalization_utils import (
set_client,
should_initialize_sync_client,
)
from litellm.router_utils.client_initalization_utils import InitalizeOpenAISDKClient
from litellm.router_utils.cooldown_cache import CooldownCache
from litellm.router_utils.cooldown_callbacks import router_cooldown_event_callback
from litellm.router_utils.cooldown_handlers import (
@ -3951,7 +3948,7 @@ class Router:
raise Exception(f"Unsupported provider - {custom_llm_provider}")
# init OpenAI, Azure clients
set_client(
InitalizeOpenAISDKClient.set_client(
litellm_router_instance=self, model=deployment.to_json(exclude_none=True)
)
@ -4661,7 +4658,9 @@ class Router:
"""
Re-initialize the client
"""
set_client(litellm_router_instance=self, model=deployment)
InitalizeOpenAISDKClient.set_client(
litellm_router_instance=self, model=deployment
)
client = self.cache.get_cache(key=cache_key, local_only=True)
return client
else:
@ -4671,7 +4670,9 @@ class Router:
"""
Re-initialize the client
"""
set_client(litellm_router_instance=self, model=deployment)
InitalizeOpenAISDKClient.set_client(
litellm_router_instance=self, model=deployment
)
client = self.cache.get_cache(key=cache_key, local_only=True)
return client
else:
@ -4682,7 +4683,9 @@ class Router:
"""
Re-initialize the client
"""
set_client(litellm_router_instance=self, model=deployment)
InitalizeOpenAISDKClient.set_client(
litellm_router_instance=self, model=deployment
)
client = self.cache.get_cache(key=cache_key)
return client
else:
@ -4692,7 +4695,9 @@ class Router:
"""
Re-initialize the client
"""
set_client(litellm_router_instance=self, model=deployment)
InitalizeOpenAISDKClient.set_client(
litellm_router_instance=self, model=deployment
)
client = self.cache.get_cache(key=cache_key)
return client

View file

@ -23,6 +23,8 @@ else:
LitellmRouter = Any
class InitalizeOpenAISDKClient:
@staticmethod
def should_initialize_sync_client(
litellm_router_instance: LitellmRouter,
) -> bool:
@ -41,14 +43,17 @@ def should_initialize_sync_client(
and hasattr(
litellm_router_instance.router_general_settings, "async_only_mode"
)
and litellm_router_instance.router_general_settings.async_only_mode is True
and litellm_router_instance.router_general_settings.async_only_mode
is True
):
return False
return True
def set_client(litellm_router_instance: LitellmRouter, model: dict): # noqa: PLR0915
@staticmethod
def set_client( # noqa: PLR0915
litellm_router_instance: LitellmRouter, model: dict
):
"""
- Initializes Azure/OpenAI clients. Stores them in cache, b/c of this - https://github.com/BerriAI/litellm/issues/1278
- Initializes Semaphore for client w/ rpm. Stores them in cache. b/c of this - https://github.com/BerriAI/litellm/issues/2994
@ -110,7 +115,11 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict): # noqa: PL
# user can pass vars directly or they can pas os.environ/AZURE_API_KEY, in which case we will read the env
# we do this here because we init clients for Azure, OpenAI and we need to set the right key
api_key = litellm_params.get("api_key") or default_api_key
if api_key and isinstance(api_key, str) and api_key.startswith("os.environ/"):
if (
api_key
and isinstance(api_key, str)
and api_key.startswith("os.environ/")
):
api_key_env_name = api_key.replace("os.environ/", "")
api_key = get_secret_str(api_key_env_name)
litellm_params["api_key"] = api_key
@ -162,7 +171,9 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict): # noqa: PL
stream_timeout: Optional[float] = litellm_params.pop(
"stream_timeout", timeout
) # if no stream_timeout is set, default to timeout
if isinstance(stream_timeout, str) and stream_timeout.startswith("os.environ/"):
if isinstance(stream_timeout, str) and stream_timeout.startswith(
"os.environ/"
):
stream_timeout_env_name = stream_timeout.replace("os.environ/", "")
stream_timeout = get_secret(stream_timeout_env_name) # type: ignore
litellm_params["stream_timeout"] = stream_timeout
@ -182,17 +193,23 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict): # noqa: PL
litellm_params["organization"] = organization
azure_ad_token_provider: Optional[Callable[[], str]] = None
if litellm_params.get("tenant_id"):
verbose_router_logger.debug("Using Azure AD Token Provider for Azure Auth")
azure_ad_token_provider = get_azure_ad_token_from_entrata_id(
verbose_router_logger.debug(
"Using Azure AD Token Provider for Azure Auth"
)
azure_ad_token_provider = (
InitalizeOpenAISDKClient.get_azure_ad_token_from_entrata_id(
tenant_id=litellm_params.get("tenant_id"),
client_id=litellm_params.get("client_id"),
client_secret=litellm_params.get("client_secret"),
)
)
if custom_llm_provider == "azure" or custom_llm_provider == "azure_text":
if api_base is None or not isinstance(api_base, str):
filtered_litellm_params = {
k: v for k, v in model["litellm_params"].items() if k != "api_key"
k: v
for k, v in model["litellm_params"].items()
if k != "api_key"
}
_filtered_model = {
"model_name": model["model_name"],
@ -248,7 +265,7 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict): # noqa: PL
local_only=True,
) # cache for 1 hr
if should_initialize_sync_client(
if InitalizeOpenAISDKClient.should_initialize_sync_client(
litellm_router_instance=litellm_router_instance
):
cache_key = f"{model_id}_client"
@ -297,7 +314,7 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict): # noqa: PL
local_only=True,
) # cache for 1 hr
if should_initialize_sync_client(
if InitalizeOpenAISDKClient.should_initialize_sync_client(
litellm_router_instance=litellm_router_instance
):
cache_key = f"{model_id}_stream_client"
@ -370,7 +387,7 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict): # noqa: PL
ttl=client_ttl,
local_only=True,
) # cache for 1 hr
if should_initialize_sync_client(
if InitalizeOpenAISDKClient.should_initialize_sync_client(
litellm_router_instance=litellm_router_instance
):
cache_key = f"{model_id}_client"
@ -412,7 +429,7 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict): # noqa: PL
local_only=True,
) # cache for 1 hr
if should_initialize_sync_client(
if InitalizeOpenAISDKClient.should_initialize_sync_client(
litellm_router_instance=litellm_router_instance
):
cache_key = f"{model_id}_stream_client"
@ -463,7 +480,7 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict): # noqa: PL
local_only=True,
) # cache for 1 hr
if should_initialize_sync_client(
if InitalizeOpenAISDKClient.should_initialize_sync_client(
litellm_router_instance=litellm_router_instance
):
cache_key = f"{model_id}_client"
@ -509,7 +526,7 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict): # noqa: PL
local_only=True,
) # cache for 1 hr
if should_initialize_sync_client(
if InitalizeOpenAISDKClient.should_initialize_sync_client(
litellm_router_instance=litellm_router_instance
):
# streaming clients should have diff timeouts
@ -534,7 +551,7 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict): # noqa: PL
local_only=True,
) # cache for 1 hr
@staticmethod
def get_azure_ad_token_from_entrata_id(
tenant_id: str, client_id: str, client_secret: str
) -> Callable[[], str]: