fix: fix max parallel requests client

Krrish Dholakia 2025-03-11 18:25:48 -07:00
parent 3ba683be88
commit e4fc6422e2
2 changed files with 23 additions and 9 deletions

```diff
@@ -5346,6 +5346,13 @@ class Router:
             client = self.cache.get_cache(
                 key=cache_key, local_only=True, parent_otel_span=parent_otel_span
             )
+            if client is None:
+                InitalizeOpenAISDKClient.set_max_parallel_requests_client(
+                    litellm_router_instance=self, model=deployment
+                )
+                client = self.cache.get_cache(
+                    key=cache_key, local_only=True, parent_otel_span=parent_otel_span
+                )
             return client
         elif client_type == "async":
             if kwargs.get("stream") is True:
```
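Before this fix, the getter assumed `set_client()` had already populated the cache entry, so a missing max-parallel-requests client silently came back as `None`. The fix applies a lazy-initialization pattern: on a cache miss, initialize the client, then re-read it from the cache. A minimal standalone sketch of that pattern; the `LocalCache` class and `get_or_initialize` helper below are illustrative stand-ins, not litellm internals:

```python
from typing import Any, Callable, Dict, Optional


class LocalCache:
    """Toy in-process cache standing in for the router's cache layer."""

    def __init__(self) -> None:
        self._store: Dict[str, Any] = {}

    def get_cache(self, key: str) -> Optional[Any]:
        return self._store.get(key)

    def set_cache(self, key: str, value: Any) -> None:
        self._store[key] = value


def get_or_initialize(
    cache: LocalCache, key: str, initialize: Callable[[], Any]
) -> Any:
    """On a cache miss, build the value, store it, and read it back;
    the same shape as the router fix above."""
    client = cache.get_cache(key=key)
    if client is None:
        cache.set_cache(key=key, value=initialize())
        client = cache.get_cache(key=key)
    return client
```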

```diff
@@ -54,18 +54,11 @@ class InitalizeOpenAISDKClient:
         return True

     @staticmethod
-    def set_client(  # noqa: PLR0915
+    def set_max_parallel_requests_client(
         litellm_router_instance: LitellmRouter, model: dict
     ):
-        """
-        - Initializes Azure/OpenAI clients. Stores them in cache, b/c of this - https://github.com/BerriAI/litellm/issues/1278
-        - Initializes Semaphore for client w/ rpm. Stores them in cache. b/c of this - https://github.com/BerriAI/litellm/issues/2994
-        """
-        client_ttl = litellm_router_instance.client_ttl
         litellm_params = model.get("litellm_params", {})
-        model_name = litellm_params.get("model")
         model_id = model["model_info"]["id"]
-        # ### IF RPM SET - initialize a semaphore ###
         rpm = litellm_params.get("rpm", None)
         tpm = litellm_params.get("tpm", None)
         max_parallel_requests = litellm_params.get("max_parallel_requests", None)
@@ -84,6 +77,19 @@ class InitalizeOpenAISDKClient:
                 local_only=True,
             )

+    @staticmethod
+    def set_client(  # noqa: PLR0915
+        litellm_router_instance: LitellmRouter, model: dict
+    ):
+        """
+        - Initializes Azure/OpenAI clients. Stores them in cache, b/c of this - https://github.com/BerriAI/litellm/issues/1278
+        - Initializes Semaphore for client w/ rpm. Stores them in cache. b/c of this - https://github.com/BerriAI/litellm/issues/2994
+        """
+        client_ttl = litellm_router_instance.client_ttl
+        litellm_params = model.get("litellm_params", {})
+        model_name = litellm_params.get("model")
+        model_id = model["model_info"]["id"]
+
         #### for OpenAI / Azure we need to initalize the Client for High Traffic ########
         custom_llm_provider = litellm_params.get("custom_llm_provider")
         custom_llm_provider = custom_llm_provider or model_name.split("/", 1)[0] or ""
```
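These two hunks extract the semaphore setup out of the monolithic `set_client` into the standalone `set_max_parallel_requests_client`, so the router's getter can call it on demand. Per the docstring, the semaphore caps concurrent requests per deployment (litellm issue #2994). A hedged sketch of what that initialization plausibly looks like; the cache-key format, the `asyncio.Semaphore` choice, and the shape of the `cache` parameter are assumptions, not taken from this diff:

```python
import asyncio
from typing import Any, Optional


def set_max_parallel_requests_sketch(
    model_id: str,
    max_parallel_requests: Optional[int],
    cache: Any,  # assumed to expose set_cache(key=..., value=..., local_only=...)
) -> None:
    """Store one asyncio.Semaphore per deployment so the router can cap
    in-flight requests. The key format below is hypothetical."""
    if max_parallel_requests is None:
        return
    semaphore = asyncio.Semaphore(max_parallel_requests)
    cache.set_cache(
        key=f"{model_id}_max_parallel_requests_client",  # hypothetical key
        value=semaphore,
        local_only=True,
    )
```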
```diff
@@ -233,7 +239,8 @@ class InitalizeOpenAISDKClient:
             if azure_ad_token.startswith("oidc/"):
                 azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
             elif (
-                not api_key and azure_ad_token_provider is None
+                not api_key
+                and azure_ad_token_provider is None
                 and litellm.enable_azure_ad_token_refresh is True
             ):
                 try:
```
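For context on when the new code path runs: the semaphore client only matters for deployments that set `max_parallel_requests` (or `rpm`/`tpm`) in their `litellm_params`, as read in the hunk above. A configuration sketch using the Router's `model_list` API, with placeholder model names and credentials:

```python
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",  # placeholder alias
            "litellm_params": {
                "model": "azure/my-deployment",            # placeholder
                "api_key": "sk-...",                       # placeholder
                "api_base": "https://example.azure.com",   # placeholder
                "max_parallel_requests": 10,  # caps concurrent requests to this deployment
                "rpm": 60,
            },
        }
    ]
)
```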