fix: fix merge conflicts

Krrish Dholakia 2025-03-11 18:41:41 -07:00
parent dd8e50527e
commit 9a98942e87
3 changed files with 5 additions and 252 deletions

litellm/router.py

@@ -71,6 +71,7 @@ from litellm.router_utils.batch_utils import (
     _get_router_metadata_variable_name,
     replace_model_in_jsonl,
 )
+from litellm.router_utils.client_initalization_utils import InitalizeCachedClient
 from litellm.router_utils.clientside_credential_handler import (
     get_dynamic_litellm_params,
     is_clientside_credential,
@@ -5346,7 +5347,7 @@ class Router:
                 key=cache_key, local_only=True, parent_otel_span=parent_otel_span
             )
             if client is None:
-                InitalizeOpenAISDKClient.set_max_parallel_requests_client(
+                InitalizeCachedClient.set_max_parallel_requests_client(
                     litellm_router_instance=self, model=deployment
                 )
                 client = self.cache.get_cache(
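
The router-side change above only swaps the class name; the surrounding pattern is a lazy get-or-initialize against the router's local cache. Below is a minimal sketch of that pattern, assuming a dict-backed in-process cache; the cache-key format and helper names are illustrative stand-ins, not litellm's exact API.

    import asyncio
    from typing import Any, Optional


    class LocalCache:
        """Dict-backed stand-in for the router's local cache."""

        def __init__(self) -> None:
            self._store: dict[str, Any] = {}

        def get_cache(self, key: str) -> Optional[Any]:
            return self._store.get(key)

        def set_cache(self, key: str, value: Any) -> None:
            self._store[key] = value


    def get_max_parallel_requests_client(
        cache: LocalCache, deployment_id: str, max_parallel_requests: int
    ) -> asyncio.Semaphore:
        # Same shape as the Router code above: read the cache first, and only
        # build + store the semaphore when no client exists yet.
        cache_key = f"{deployment_id}_max_parallel_requests_client"  # illustrative key
        client = cache.get_cache(key=cache_key)
        if client is None:
            cache.set_cache(key=cache_key, value=asyncio.Semaphore(max_parallel_requests))
            client = cache.get_cache(key=cache_key)
        return client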

litellm/router_utils/client_initalization_utils.py

@@ -17,33 +17,7 @@ else:
     LitellmRouter = Any
 
 
-class InitalizeOpenAISDKClient:
-    @staticmethod
-    def should_initialize_sync_client(
-        litellm_router_instance: LitellmRouter,
-    ) -> bool:
-        """
-        Returns if Sync OpenAI, Azure Clients should be initialized.
-
-        Do not init sync clients when router.router_general_settings.async_only_mode is True
-        """
-        if litellm_router_instance is None:
-            return False
-
-        if litellm_router_instance.router_general_settings is not None:
-            if (
-                hasattr(litellm_router_instance, "router_general_settings")
-                and hasattr(
-                    litellm_router_instance.router_general_settings, "async_only_mode"
-                )
-                and litellm_router_instance.router_general_settings.async_only_mode
-                is True
-            ):
-                return False
-
-        return True
-
+class InitalizeCachedClient:
     @staticmethod
     def set_max_parallel_requests_client(
         litellm_router_instance: LitellmRouter, model: dict
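
For reference, the deleted should_initialize_sync_client guard reduces to the check below. This is a condensed sketch of the removed logic, not a litellm API; `router` stands for any object shaped like the litellm Router.

    def should_initialize_sync_client(router) -> bool:
        # Sync OpenAI/Azure clients are skipped entirely when the router is
        # configured for async-only operation.
        if router is None:
            return False
        settings = getattr(router, "router_general_settings", None)
        if settings is not None and getattr(settings, "async_only_mode", None) is True:
            return False
        return True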
@@ -67,226 +41,3 @@ class InitalizeOpenAISDKClient:
             value=semaphore,
             local_only=True,
         )
-
-    @staticmethod
-    def set_client(  # noqa: PLR0915
-        litellm_router_instance: LitellmRouter, model: dict
-    ):
-        """
-        - Initializes Azure/OpenAI clients. Stores them in cache, b/c of this - https://github.com/BerriAI/litellm/issues/1278
-        - Initializes Semaphore for client w/ rpm. Stores them in cache. b/c of this - https://github.com/BerriAI/litellm/issues/2994
-        """
-        client_ttl = litellm_router_instance.client_ttl
-        litellm_params = model.get("litellm_params", {})
-        model_name = litellm_params.get("model")
-        model_id = model["model_info"]["id"]
-        #### for OpenAI / Azure we need to initalize the Client for High Traffic ########
-        custom_llm_provider = litellm_params.get("custom_llm_provider")
-        custom_llm_provider = custom_llm_provider or model_name.split("/", 1)[0] or ""
-        default_api_base = None
-        default_api_key = None
-
-        if custom_llm_provider in litellm.openai_compatible_providers:
-            _, custom_llm_provider, api_key, api_base = litellm.get_llm_provider(
-                model=model_name
-            )
-            default_api_base = api_base
-            default_api_key = api_key
-
-        if (
-            model_name in litellm.open_ai_chat_completion_models
-            or custom_llm_provider in litellm.openai_compatible_providers
-            or custom_llm_provider == "azure"
-            or custom_llm_provider == "azure_text"
-            or custom_llm_provider == "custom_openai"
-            or custom_llm_provider == "openai"
-            or custom_llm_provider == "text-completion-openai"
-            or "ft:gpt-3.5-turbo" in model_name
-            or model_name in litellm.open_ai_embedding_models
-        ):
-            is_azure_ai_studio_model: bool = False
-            if custom_llm_provider == "azure":
-                if litellm.utils._is_non_openai_azure_model(model_name):
-                    is_azure_ai_studio_model = True
-                    custom_llm_provider = "openai"
-                    # remove azure prefx from model_name
-                    model_name = model_name.replace("azure/", "")
-
-            # glorified / complicated reading of configs
-            # user can pass vars directly or they can pas os.environ/AZURE_API_KEY, in which case we will read the env
-            # we do this here because we init clients for Azure, OpenAI and we need to set the right key
-            api_key = litellm_params.get("api_key") or default_api_key
-            if (
-                api_key
-                and isinstance(api_key, str)
-                and api_key.startswith("os.environ/")
-            ):
-                api_key_env_name = api_key.replace("os.environ/", "")
-                api_key = get_secret_str(api_key_env_name)
-                litellm_params["api_key"] = api_key
-
-            api_base = litellm_params.get("api_base")
-            base_url: Optional[str] = litellm_params.get("base_url")
-            api_base = (
-                api_base or base_url or default_api_base
-            )  # allow users to pass in `api_base` or `base_url` for azure
-            if api_base and api_base.startswith("os.environ/"):
-                api_base_env_name = api_base.replace("os.environ/", "")
-                api_base = get_secret_str(api_base_env_name)
-                litellm_params["api_base"] = api_base
-
-            ## AZURE AI STUDIO MISTRAL CHECK ##
-            """
-            Make sure api base ends in /v1/
-
-            if not, add it - https://github.com/BerriAI/litellm/issues/2279
-            """
-            if (
-                is_azure_ai_studio_model is True
-                and api_base is not None
-                and isinstance(api_base, str)
-                and not api_base.endswith("/v1/")
-            ):
-                # check if it ends with a trailing slash
-                if api_base.endswith("/"):
-                    api_base += "v1/"
-                elif api_base.endswith("/v1"):
-                    api_base += "/"
-                else:
-                    api_base += "/v1/"
-
-            api_version = litellm_params.get("api_version")
-            if api_version and api_version.startswith("os.environ/"):
-                api_version_env_name = api_version.replace("os.environ/", "")
-                api_version = get_secret_str(api_version_env_name)
-                litellm_params["api_version"] = api_version
-
-            timeout: Optional[float] = (
-                litellm_params.pop("timeout", None) or litellm.request_timeout
-            )
-            if isinstance(timeout, str) and timeout.startswith("os.environ/"):
-                timeout_env_name = timeout.replace("os.environ/", "")
-                timeout = get_secret(timeout_env_name)  # type: ignore
-                litellm_params["timeout"] = timeout
-
-            stream_timeout: Optional[float] = litellm_params.pop(
-                "stream_timeout", timeout
-            )  # if no stream_timeout is set, default to timeout
-            if isinstance(stream_timeout, str) and stream_timeout.startswith(
-                "os.environ/"
-            ):
-                stream_timeout_env_name = stream_timeout.replace("os.environ/", "")
-                stream_timeout = get_secret(stream_timeout_env_name)  # type: ignore
-                litellm_params["stream_timeout"] = stream_timeout
-
-            max_retries: Optional[int] = litellm_params.pop(
-                "max_retries", 0
-            )  # router handles retry logic
-            if isinstance(max_retries, str) and max_retries.startswith("os.environ/"):
-                max_retries_env_name = max_retries.replace("os.environ/", "")
-                max_retries = get_secret(max_retries_env_name)  # type: ignore
-                litellm_params["max_retries"] = max_retries
-
-            organization = litellm_params.get("organization", None)
-            if isinstance(organization, str) and organization.startswith("os.environ/"):
-                organization_env_name = organization.replace("os.environ/", "")
-                organization = get_secret_str(organization_env_name)
-                litellm_params["organization"] = organization
-            else:
-                _api_key = api_key  # type: ignore
-                if _api_key is not None and isinstance(_api_key, str):
-                    # only show first 5 chars of api_key
-                    _api_key = _api_key[:8] + "*" * 15
-                verbose_router_logger.debug(
-                    f"Initializing OpenAI Client for {model_name}, Api Base:{str(api_base)}, Api Key:{_api_key}"
-                )
-                cache_key = f"{model_id}_async_client"
-                _client = openai.AsyncOpenAI(  # type: ignore
-                    api_key=api_key,
-                    base_url=api_base,
-                    timeout=timeout,  # type: ignore
-                    max_retries=max_retries,  # type: ignore
-                    organization=organization,
-                    http_client=httpx.AsyncClient(
-                        limits=httpx.Limits(
-                            max_connections=1000, max_keepalive_connections=100
-                        ),
-                        verify=litellm.ssl_verify,
-                    ),  # type: ignore
-                )
-                litellm_router_instance.cache.set_cache(
-                    key=cache_key,
-                    value=_client,
-                    ttl=client_ttl,
-                    local_only=True,
-                )  # cache for 1 hr
-
-                if InitalizeOpenAISDKClient.should_initialize_sync_client(
-                    litellm_router_instance=litellm_router_instance
-                ):
-                    cache_key = f"{model_id}_client"
-                    _client = openai.OpenAI(  # type: ignore
-                        api_key=api_key,
-                        base_url=api_base,
-                        timeout=timeout,  # type: ignore
-                        max_retries=max_retries,  # type: ignore
-                        organization=organization,
-                        http_client=httpx.Client(
-                            limits=httpx.Limits(
-                                max_connections=1000, max_keepalive_connections=100
-                            ),
-                            verify=litellm.ssl_verify,
-                        ),  # type: ignore
-                    )
-                    litellm_router_instance.cache.set_cache(
-                        key=cache_key,
-                        value=_client,
-                        ttl=client_ttl,
-                        local_only=True,
-                    )  # cache for 1 hr
-
-                # streaming clients should have diff timeouts
-                cache_key = f"{model_id}_stream_async_client"
-                _client = openai.AsyncOpenAI(  # type: ignore
-                    api_key=api_key,
-                    base_url=api_base,
-                    timeout=stream_timeout,  # type: ignore
-                    max_retries=max_retries,  # type: ignore
-                    organization=organization,
-                    http_client=httpx.AsyncClient(
-                        limits=httpx.Limits(
-                            max_connections=1000, max_keepalive_connections=100
-                        ),
-                        verify=litellm.ssl_verify,
-                    ),  # type: ignore
-                )
-                litellm_router_instance.cache.set_cache(
-                    key=cache_key,
-                    value=_client,
-                    ttl=client_ttl,
-                    local_only=True,
-                )  # cache for 1 hr
-
-                if InitalizeOpenAISDKClient.should_initialize_sync_client(
-                    litellm_router_instance=litellm_router_instance
-                ):
-                    # streaming clients should have diff timeouts
-                    cache_key = f"{model_id}_stream_client"
-                    _client = openai.OpenAI(  # type: ignore
-                        api_key=api_key,
-                        base_url=api_base,
-                        timeout=stream_timeout,  # type: ignore
-                        max_retries=max_retries,  # type: ignore
-                        organization=organization,
-                        http_client=httpx.Client(
-                            limits=httpx.Limits(
-                                max_connections=1000, max_keepalive_connections=100
-                            ),
-                            verify=litellm.ssl_verify,
-                        ),  # type: ignore
-                    )
-                    litellm_router_instance.cache.set_cache(
-                        key=cache_key,
-                        value=_client,
-                        ttl=client_ttl,
-                        local_only=True,
-                    )  # cache for 1 hr
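
Most of the deleted set_client body repeats one pattern per parameter: a config value is either a literal or an `os.environ/VAR_NAME` pointer that is resolved from the environment at client-init time. A minimal sketch of that indirection, using `os.environ.get` in place of litellm's get_secret_str:

    import os
    from typing import Optional


    def resolve_config_value(value: Optional[str]) -> Optional[str]:
        """Resolve `os.environ/VAR_NAME` pointers; pass literals through unchanged."""
        if isinstance(value, str) and value.startswith("os.environ/"):
            env_name = value.replace("os.environ/", "")
            return os.environ.get(env_name)
        return value


    # e.g. api_key may be "sk-..." or a pointer like "os.environ/AZURE_API_KEY"
    api_key = resolve_config_value("os.environ/AZURE_API_KEY")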


@@ -1,6 +1,7 @@
 import json
 import os
 import sys
+import traceback
 from typing import Callable, Optional
 from unittest.mock import MagicMock, patch
@@ -346,7 +347,7 @@ async def test_ensure_initialize_azure_sdk_client_always_used(call_type):
                 azure_ad_token="oidc/test-token",
             )
         except Exception as e:
-            print(e)
+            traceback.print_exc()
 
         # Verify initialize_azure_sdk_client was called
         mock_init_azure.assert_called_once()
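
The test tweak swaps print(e) for traceback.print_exc(), which preserves the full stack trace instead of only the exception message. A quick illustration of the difference:

    import traceback

    try:
        {}["missing"]
    except Exception as e:
        print(e)               # prints only: 'missing'
        traceback.print_exc()  # prints the full KeyError traceback with file/line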