(refactor) router - use static methods for client init utils (#6420)

* use InitalizeOpenAISDKClient

* use InitalizeOpenAISDKClient static method

* fix # noqa: PLR0915
Commit authored by:
Ishaan Jaff 2024-10-24 19:26:46 +04:00, committed by GitHub
parent cdda7c243f
commit 17e81d861c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 448 additions and 426 deletions

View file

@ -63,10 +63,7 @@ from litellm.router_utils.batch_utils import (
_get_router_metadata_variable_name, _get_router_metadata_variable_name,
replace_model_in_jsonl, replace_model_in_jsonl,
) )
from litellm.router_utils.client_initalization_utils import ( from litellm.router_utils.client_initalization_utils import InitalizeOpenAISDKClient
set_client,
should_initialize_sync_client,
)
from litellm.router_utils.cooldown_cache import CooldownCache from litellm.router_utils.cooldown_cache import CooldownCache
from litellm.router_utils.cooldown_callbacks import router_cooldown_event_callback from litellm.router_utils.cooldown_callbacks import router_cooldown_event_callback
from litellm.router_utils.cooldown_handlers import ( from litellm.router_utils.cooldown_handlers import (
@ -3951,7 +3948,7 @@ class Router:
raise Exception(f"Unsupported provider - {custom_llm_provider}") raise Exception(f"Unsupported provider - {custom_llm_provider}")
# init OpenAI, Azure clients # init OpenAI, Azure clients
set_client( InitalizeOpenAISDKClient.set_client(
litellm_router_instance=self, model=deployment.to_json(exclude_none=True) litellm_router_instance=self, model=deployment.to_json(exclude_none=True)
) )
@ -4661,7 +4658,9 @@ class Router:
""" """
Re-initialize the client Re-initialize the client
""" """
set_client(litellm_router_instance=self, model=deployment) InitalizeOpenAISDKClient.set_client(
litellm_router_instance=self, model=deployment
)
client = self.cache.get_cache(key=cache_key, local_only=True) client = self.cache.get_cache(key=cache_key, local_only=True)
return client return client
else: else:
@ -4671,7 +4670,9 @@ class Router:
""" """
Re-initialize the client Re-initialize the client
""" """
set_client(litellm_router_instance=self, model=deployment) InitalizeOpenAISDKClient.set_client(
litellm_router_instance=self, model=deployment
)
client = self.cache.get_cache(key=cache_key, local_only=True) client = self.cache.get_cache(key=cache_key, local_only=True)
return client return client
else: else:
@ -4682,7 +4683,9 @@ class Router:
""" """
Re-initialize the client Re-initialize the client
""" """
set_client(litellm_router_instance=self, model=deployment) InitalizeOpenAISDKClient.set_client(
litellm_router_instance=self, model=deployment
)
client = self.cache.get_cache(key=cache_key) client = self.cache.get_cache(key=cache_key)
return client return client
else: else:
@ -4692,7 +4695,9 @@ class Router:
""" """
Re-initialize the client Re-initialize the client
""" """
set_client(litellm_router_instance=self, model=deployment) InitalizeOpenAISDKClient.set_client(
litellm_router_instance=self, model=deployment
)
client = self.cache.get_cache(key=cache_key) client = self.cache.get_cache(key=cache_key)
return client return client

View file

@ -23,6 +23,8 @@ else:
LitellmRouter = Any LitellmRouter = Any
class InitalizeOpenAISDKClient:
@staticmethod
def should_initialize_sync_client( def should_initialize_sync_client(
litellm_router_instance: LitellmRouter, litellm_router_instance: LitellmRouter,
) -> bool: ) -> bool:
@ -41,14 +43,17 @@ def should_initialize_sync_client(
and hasattr( and hasattr(
litellm_router_instance.router_general_settings, "async_only_mode" litellm_router_instance.router_general_settings, "async_only_mode"
) )
and litellm_router_instance.router_general_settings.async_only_mode is True and litellm_router_instance.router_general_settings.async_only_mode
is True
): ):
return False return False
return True return True
@staticmethod
def set_client(litellm_router_instance: LitellmRouter, model: dict): # noqa: PLR0915 def set_client( # noqa: PLR0915
litellm_router_instance: LitellmRouter, model: dict
):
""" """
- Initializes Azure/OpenAI clients. Stores them in cache, b/c of this - https://github.com/BerriAI/litellm/issues/1278 - Initializes Azure/OpenAI clients. Stores them in cache, b/c of this - https://github.com/BerriAI/litellm/issues/1278
- Initializes Semaphore for client w/ rpm. Stores them in cache. b/c of this - https://github.com/BerriAI/litellm/issues/2994 - Initializes Semaphore for client w/ rpm. Stores them in cache. b/c of this - https://github.com/BerriAI/litellm/issues/2994
@ -110,7 +115,11 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict): # noqa: PL
# user can pass vars directly or they can pas os.environ/AZURE_API_KEY, in which case we will read the env # user can pass vars directly or they can pas os.environ/AZURE_API_KEY, in which case we will read the env
# we do this here because we init clients for Azure, OpenAI and we need to set the right key # we do this here because we init clients for Azure, OpenAI and we need to set the right key
api_key = litellm_params.get("api_key") or default_api_key api_key = litellm_params.get("api_key") or default_api_key
if api_key and isinstance(api_key, str) and api_key.startswith("os.environ/"): if (
api_key
and isinstance(api_key, str)
and api_key.startswith("os.environ/")
):
api_key_env_name = api_key.replace("os.environ/", "") api_key_env_name = api_key.replace("os.environ/", "")
api_key = get_secret_str(api_key_env_name) api_key = get_secret_str(api_key_env_name)
litellm_params["api_key"] = api_key litellm_params["api_key"] = api_key
@ -162,7 +171,9 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict): # noqa: PL
stream_timeout: Optional[float] = litellm_params.pop( stream_timeout: Optional[float] = litellm_params.pop(
"stream_timeout", timeout "stream_timeout", timeout
) # if no stream_timeout is set, default to timeout ) # if no stream_timeout is set, default to timeout
if isinstance(stream_timeout, str) and stream_timeout.startswith("os.environ/"): if isinstance(stream_timeout, str) and stream_timeout.startswith(
"os.environ/"
):
stream_timeout_env_name = stream_timeout.replace("os.environ/", "") stream_timeout_env_name = stream_timeout.replace("os.environ/", "")
stream_timeout = get_secret(stream_timeout_env_name) # type: ignore stream_timeout = get_secret(stream_timeout_env_name) # type: ignore
litellm_params["stream_timeout"] = stream_timeout litellm_params["stream_timeout"] = stream_timeout
@ -182,17 +193,23 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict): # noqa: PL
litellm_params["organization"] = organization litellm_params["organization"] = organization
azure_ad_token_provider: Optional[Callable[[], str]] = None azure_ad_token_provider: Optional[Callable[[], str]] = None
if litellm_params.get("tenant_id"): if litellm_params.get("tenant_id"):
verbose_router_logger.debug("Using Azure AD Token Provider for Azure Auth") verbose_router_logger.debug(
azure_ad_token_provider = get_azure_ad_token_from_entrata_id( "Using Azure AD Token Provider for Azure Auth"
)
azure_ad_token_provider = (
InitalizeOpenAISDKClient.get_azure_ad_token_from_entrata_id(
tenant_id=litellm_params.get("tenant_id"), tenant_id=litellm_params.get("tenant_id"),
client_id=litellm_params.get("client_id"), client_id=litellm_params.get("client_id"),
client_secret=litellm_params.get("client_secret"), client_secret=litellm_params.get("client_secret"),
) )
)
if custom_llm_provider == "azure" or custom_llm_provider == "azure_text": if custom_llm_provider == "azure" or custom_llm_provider == "azure_text":
if api_base is None or not isinstance(api_base, str): if api_base is None or not isinstance(api_base, str):
filtered_litellm_params = { filtered_litellm_params = {
k: v for k, v in model["litellm_params"].items() if k != "api_key" k: v
for k, v in model["litellm_params"].items()
if k != "api_key"
} }
_filtered_model = { _filtered_model = {
"model_name": model["model_name"], "model_name": model["model_name"],
@ -248,7 +265,7 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict): # noqa: PL
local_only=True, local_only=True,
) # cache for 1 hr ) # cache for 1 hr
if should_initialize_sync_client( if InitalizeOpenAISDKClient.should_initialize_sync_client(
litellm_router_instance=litellm_router_instance litellm_router_instance=litellm_router_instance
): ):
cache_key = f"{model_id}_client" cache_key = f"{model_id}_client"
@ -297,7 +314,7 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict): # noqa: PL
local_only=True, local_only=True,
) # cache for 1 hr ) # cache for 1 hr
if should_initialize_sync_client( if InitalizeOpenAISDKClient.should_initialize_sync_client(
litellm_router_instance=litellm_router_instance litellm_router_instance=litellm_router_instance
): ):
cache_key = f"{model_id}_stream_client" cache_key = f"{model_id}_stream_client"
@ -370,7 +387,7 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict): # noqa: PL
ttl=client_ttl, ttl=client_ttl,
local_only=True, local_only=True,
) # cache for 1 hr ) # cache for 1 hr
if should_initialize_sync_client( if InitalizeOpenAISDKClient.should_initialize_sync_client(
litellm_router_instance=litellm_router_instance litellm_router_instance=litellm_router_instance
): ):
cache_key = f"{model_id}_client" cache_key = f"{model_id}_client"
@ -412,7 +429,7 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict): # noqa: PL
local_only=True, local_only=True,
) # cache for 1 hr ) # cache for 1 hr
if should_initialize_sync_client( if InitalizeOpenAISDKClient.should_initialize_sync_client(
litellm_router_instance=litellm_router_instance litellm_router_instance=litellm_router_instance
): ):
cache_key = f"{model_id}_stream_client" cache_key = f"{model_id}_stream_client"
@ -463,7 +480,7 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict): # noqa: PL
local_only=True, local_only=True,
) # cache for 1 hr ) # cache for 1 hr
if should_initialize_sync_client( if InitalizeOpenAISDKClient.should_initialize_sync_client(
litellm_router_instance=litellm_router_instance litellm_router_instance=litellm_router_instance
): ):
cache_key = f"{model_id}_client" cache_key = f"{model_id}_client"
@ -509,7 +526,7 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict): # noqa: PL
local_only=True, local_only=True,
) # cache for 1 hr ) # cache for 1 hr
if should_initialize_sync_client( if InitalizeOpenAISDKClient.should_initialize_sync_client(
litellm_router_instance=litellm_router_instance litellm_router_instance=litellm_router_instance
): ):
# streaming clients should have diff timeouts # streaming clients should have diff timeouts
@ -534,7 +551,7 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict): # noqa: PL
local_only=True, local_only=True,
) # cache for 1 hr ) # cache for 1 hr
@staticmethod
def get_azure_ad_token_from_entrata_id( def get_azure_ad_token_from_entrata_id(
tenant_id: str, client_id: str, client_secret: str tenant_id: str, client_id: str, client_secret: str
) -> Callable[[], str]: ) -> Callable[[], str]: