forked from phoenix/litellm-mirror
(refactor) router - use static methods for client init utils (#6420)
* use InitalizeOpenAISDKClient
* use InitalizeOpenAISDKClient static method
* fix # noqa: PLR0915
parent cdda7c243f
commit 17e81d861c
2 changed files with 448 additions and 426 deletions
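The refactor collapses the free functions from client_initalization_utils (set_client, should_initialize_sync_client, get_azure_ad_token_from_entrata_id) into static methods on a single InitalizeOpenAISDKClient class, and every call site in Router is updated to go through the class. A minimal sketch of the pattern, with simplified bodies and signatures rather than the real litellm ones:

    # before: module-level helpers, imported one by one
    # from ...client_initalization_utils import set_client, should_initialize_sync_client

    class InitalizeOpenAISDKClient:
        @staticmethod
        def should_initialize_sync_client(litellm_router_instance) -> bool:
            # Sketch: skip sync clients when the router is configured async-only.
            settings = getattr(litellm_router_instance, "router_general_settings", None)
            return not bool(getattr(settings, "async_only_mode", False))

        @staticmethod
        def set_client(litellm_router_instance, model: dict) -> None:
            # Sketch: build OpenAI/Azure SDK clients for one deployment and cache them.
            ...

    # after: call sites reference the class, e.g.
    # InitalizeOpenAISDKClient.set_client(litellm_router_instance=self, model=...)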
@@ -63,10 +63,7 @@ from litellm.router_utils.batch_utils import (
     _get_router_metadata_variable_name,
     replace_model_in_jsonl,
 )
-from litellm.router_utils.client_initalization_utils import (
-    set_client,
-    should_initialize_sync_client,
-)
+from litellm.router_utils.client_initalization_utils import InitalizeOpenAISDKClient
 from litellm.router_utils.cooldown_cache import CooldownCache
 from litellm.router_utils.cooldown_callbacks import router_cooldown_event_callback
 from litellm.router_utils.cooldown_handlers import (
@@ -3951,7 +3948,7 @@ class Router:
             raise Exception(f"Unsupported provider - {custom_llm_provider}")
 
         # init OpenAI, Azure clients
-        set_client(
+        InitalizeOpenAISDKClient.set_client(
             litellm_router_instance=self, model=deployment.to_json(exclude_none=True)
         )
 
@@ -4661,7 +4658,9 @@ class Router:
                    """
                    Re-initialize the client
                    """
-                    set_client(litellm_router_instance=self, model=deployment)
+                    InitalizeOpenAISDKClient.set_client(
+                        litellm_router_instance=self, model=deployment
+                    )
                    client = self.cache.get_cache(key=cache_key, local_only=True)
                return client
            else:
@@ -4671,7 +4670,9 @@ class Router:
                    """
                    Re-initialize the client
                    """
-                    set_client(litellm_router_instance=self, model=deployment)
+                    InitalizeOpenAISDKClient.set_client(
+                        litellm_router_instance=self, model=deployment
+                    )
                    client = self.cache.get_cache(key=cache_key, local_only=True)
                return client
            else:
@@ -4682,7 +4683,9 @@ class Router:
                    """
                    Re-initialize the client
                    """
-                    set_client(litellm_router_instance=self, model=deployment)
+                    InitalizeOpenAISDKClient.set_client(
+                        litellm_router_instance=self, model=deployment
+                    )
                    client = self.cache.get_cache(key=cache_key)
                return client
            else:
@@ -4692,7 +4695,9 @@ class Router:
                    """
                    Re-initialize the client
                    """
-                    set_client(litellm_router_instance=self, model=deployment)
+                    InitalizeOpenAISDKClient.set_client(
+                        litellm_router_instance=self, model=deployment
+                    )
                    client = self.cache.get_cache(key=cache_key)
                return client
 
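Each of the four Router hunks above follows the same cache-or-rebuild flow: clients are stored in the router cache under a per-model key with a TTL, and when the lookup comes back empty the deployment's clients are re-initialized via InitalizeOpenAISDKClient.set_client before a second lookup. The hunks that follow are in the client_initalization_utils module itself, where the helpers become methods of the new class. A simplified sketch of the lookup flow (the helper name and exact keyword arguments are illustrative):

    def get_or_rebuild_client(router, deployment: dict, cache_key: str):
        # Illustrative helper mirroring the branch structure shown above.
        client = router.cache.get_cache(key=cache_key, local_only=True)
        if client is None:
            # Cached client expired (clients are cached for ~1 hr), so rebuild it.
            InitalizeOpenAISDKClient.set_client(
                litellm_router_instance=router, model=deployment
            )
            client = router.cache.get_cache(key=cache_key, local_only=True)
        return client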
@@ -23,9 +23,11 @@ else:
     LitellmRouter = Any
 
 
-def should_initialize_sync_client(
+class InitalizeOpenAISDKClient:
+    @staticmethod
+    def should_initialize_sync_client(
         litellm_router_instance: LitellmRouter,
     ) -> bool:
         """
         Returns if Sync OpenAI, Azure Clients should be initialized.
 
@@ -41,14 +43,17 @@ def should_initialize_sync_client(
             and hasattr(
                 litellm_router_instance.router_general_settings, "async_only_mode"
             )
-            and litellm_router_instance.router_general_settings.async_only_mode is True
+            and litellm_router_instance.router_general_settings.async_only_mode
+            is True
         ):
             return False
 
         return True
 
-
-def set_client(litellm_router_instance: LitellmRouter, model: dict): # noqa: PLR0915
+    @staticmethod
+    def set_client(  # noqa: PLR0915
+        litellm_router_instance: LitellmRouter, model: dict
+    ):
         """
         - Initializes Azure/OpenAI clients. Stores them in cache, b/c of this - https://github.com/BerriAI/litellm/issues/1278
         - Initializes Semaphore for client w/ rpm. Stores them in cache. b/c of this - https://github.com/BerriAI/litellm/issues/2994
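should_initialize_sync_client returns False only when the router's general settings have async_only_mode enabled, in which case the sync OpenAI/Azure clients are never built and only the async clients are cached. Assuming litellm's RouterGeneralSettings type (the import path below is an assumption based on the attribute access shown in this hunk), enabling it looks roughly like:

    from litellm import Router
    from litellm.types.router import RouterGeneralSettings  # assumed location

    router = Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {"model": "gpt-3.5-turbo"},
            }
        ],
        # With async_only_mode, the branches guarded by
        # should_initialize_sync_client() are skipped entirely.
        router_general_settings=RouterGeneralSettings(async_only_mode=True),
    )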
@@ -110,7 +115,11 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict): # noqa: PL
         # user can pass vars directly or they can pas os.environ/AZURE_API_KEY, in which case we will read the env
         # we do this here because we init clients for Azure, OpenAI and we need to set the right key
         api_key = litellm_params.get("api_key") or default_api_key
-        if api_key and isinstance(api_key, str) and api_key.startswith("os.environ/"):
+        if (
+            api_key
+            and isinstance(api_key, str)
+            and api_key.startswith("os.environ/")
+        ):
             api_key_env_name = api_key.replace("os.environ/", "")
             api_key = get_secret_str(api_key_env_name)
             litellm_params["api_key"] = api_key
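The api_key handling above is litellm's "os.environ/" indirection: a value such as "os.environ/AZURE_API_KEY" is treated as a pointer to an environment variable and resolved through the secret helper before the client is constructed; the next hunk applies the same treatment to stream_timeout. A standalone sketch of the idea, using os.environ directly rather than litellm's get_secret_str:

    import os
    from typing import Optional

    def resolve_env_reference(value: Optional[str]) -> Optional[str]:
        # "os.environ/AZURE_API_KEY" -> contents of the AZURE_API_KEY env var;
        # plain values pass through unchanged.
        if value and isinstance(value, str) and value.startswith("os.environ/"):
            env_name = value.replace("os.environ/", "")
            return os.environ.get(env_name)
        return value

    # resolve_env_reference("os.environ/AZURE_API_KEY") -> os.environ["AZURE_API_KEY"]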
@@ -162,7 +171,9 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict): # noqa: PL
         stream_timeout: Optional[float] = litellm_params.pop(
             "stream_timeout", timeout
         ) # if no stream_timeout is set, default to timeout
-        if isinstance(stream_timeout, str) and stream_timeout.startswith("os.environ/"):
+        if isinstance(stream_timeout, str) and stream_timeout.startswith(
+            "os.environ/"
+        ):
             stream_timeout_env_name = stream_timeout.replace("os.environ/", "")
             stream_timeout = get_secret(stream_timeout_env_name) # type: ignore
             litellm_params["stream_timeout"] = stream_timeout
@@ -182,17 +193,23 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict): # noqa: PL
             litellm_params["organization"] = organization
         azure_ad_token_provider: Optional[Callable[[], str]] = None
         if litellm_params.get("tenant_id"):
-            verbose_router_logger.debug("Using Azure AD Token Provider for Azure Auth")
-            azure_ad_token_provider = get_azure_ad_token_from_entrata_id(
-                tenant_id=litellm_params.get("tenant_id"),
-                client_id=litellm_params.get("client_id"),
-                client_secret=litellm_params.get("client_secret"),
-            )
+            verbose_router_logger.debug(
+                "Using Azure AD Token Provider for Azure Auth"
+            )
+            azure_ad_token_provider = (
+                InitalizeOpenAISDKClient.get_azure_ad_token_from_entrata_id(
+                    tenant_id=litellm_params.get("tenant_id"),
+                    client_id=litellm_params.get("client_id"),
+                    client_secret=litellm_params.get("client_secret"),
+                )
+            )
 
         if custom_llm_provider == "azure" or custom_llm_provider == "azure_text":
             if api_base is None or not isinstance(api_base, str):
                 filtered_litellm_params = {
-                    k: v for k, v in model["litellm_params"].items() if k != "api_key"
+                    k: v
+                    for k, v in model["litellm_params"].items()
+                    if k != "api_key"
                 }
                 _filtered_model = {
                     "model_name": model["model_name"],
@@ -248,7 +265,7 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict): # noqa: PL
                     local_only=True,
                 ) # cache for 1 hr
 
-                if should_initialize_sync_client(
+                if InitalizeOpenAISDKClient.should_initialize_sync_client(
                     litellm_router_instance=litellm_router_instance
                 ):
                     cache_key = f"{model_id}_client"
@@ -297,7 +314,7 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict): # noqa: PL
                     local_only=True,
                 ) # cache for 1 hr
 
-                if should_initialize_sync_client(
+                if InitalizeOpenAISDKClient.should_initialize_sync_client(
                     litellm_router_instance=litellm_router_instance
                 ):
                     cache_key = f"{model_id}_stream_client"
@@ -370,7 +387,7 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict): # noqa: PL
                     ttl=client_ttl,
                     local_only=True,
                 ) # cache for 1 hr
-                if should_initialize_sync_client(
+                if InitalizeOpenAISDKClient.should_initialize_sync_client(
                     litellm_router_instance=litellm_router_instance
                 ):
                     cache_key = f"{model_id}_client"
@@ -412,7 +429,7 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict): # noqa: PL
                     local_only=True,
                 ) # cache for 1 hr
 
-                if should_initialize_sync_client(
+                if InitalizeOpenAISDKClient.should_initialize_sync_client(
                     litellm_router_instance=litellm_router_instance
                 ):
                     cache_key = f"{model_id}_stream_client"
@@ -463,7 +480,7 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict): # noqa: PL
                     local_only=True,
                 ) # cache for 1 hr
 
-                if should_initialize_sync_client(
+                if InitalizeOpenAISDKClient.should_initialize_sync_client(
                     litellm_router_instance=litellm_router_instance
                 ):
                     cache_key = f"{model_id}_client"
@@ -509,7 +526,7 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict): # noqa: PL
                     local_only=True,
                 ) # cache for 1 hr
 
-                if should_initialize_sync_client(
+                if InitalizeOpenAISDKClient.should_initialize_sync_client(
                     litellm_router_instance=litellm_router_instance
                 ):
                     # streaming clients should have diff timeouts
@@ -534,10 +551,10 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict): # noqa: PL
                     local_only=True,
                 ) # cache for 1 hr
 
-
+    @staticmethod
     def get_azure_ad_token_from_entrata_id(
         tenant_id: str, client_id: str, client_secret: str
     ) -> Callable[[], str]:
         from azure.identity import (
             ClientSecretCredential,
             DefaultAzureCredential,
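get_azure_ad_token_from_entrata_id, now also a static method, builds the azure_ad_token_provider callable assigned earlier in set_client: given a tenant id, client id, and client secret it returns a zero-argument function that fetches a fresh Entra ID (Azure AD) bearer token. The hunk only shows the azure.identity imports, so the sketch below is an assumption about the body; in particular the Cognitive Services token scope is assumed, not taken from the diff:

    from typing import Callable
    from azure.identity import ClientSecretCredential

    def make_azure_ad_token_provider(
        tenant_id: str, client_id: str, client_secret: str
    ) -> Callable[[], str]:
        # Client-credentials flow against Entra ID.
        credential = ClientSecretCredential(
            tenant_id=tenant_id, client_id=client_id, client_secret=client_secret
        )

        def token_provider() -> str:
            # Assumed scope for Azure OpenAI / Cognitive Services endpoints.
            token = credential.get_token("https://cognitiveservices.azure.com/.default")
            return token.token

        return token_provider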