mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 10:44:24 +00:00
This reverts commit b70147f63b
.
This commit is contained in:
parent
72a91ea9dd
commit
d063086bbf
3 changed files with 478 additions and 884 deletions
|
@ -63,7 +63,10 @@ from litellm.router_utils.batch_utils import (
|
|||
_get_router_metadata_variable_name,
|
||||
replace_model_in_jsonl,
|
||||
)
|
||||
from litellm.router_utils.client_initalization_utils import InitalizeOpenAISDKClient
|
||||
from litellm.router_utils.client_initalization_utils import (
|
||||
set_client,
|
||||
should_initialize_sync_client,
|
||||
)
|
||||
from litellm.router_utils.cooldown_cache import CooldownCache
|
||||
from litellm.router_utils.cooldown_callbacks import router_cooldown_event_callback
|
||||
from litellm.router_utils.cooldown_handlers import (
|
||||
|
@ -3948,7 +3951,7 @@ class Router:
|
|||
raise Exception(f"Unsupported provider - {custom_llm_provider}")
|
||||
|
||||
# init OpenAI, Azure clients
|
||||
InitalizeOpenAISDKClient.set_client(
|
||||
set_client(
|
||||
litellm_router_instance=self, model=deployment.to_json(exclude_none=True)
|
||||
)
|
||||
|
||||
|
@ -4658,9 +4661,7 @@ class Router:
|
|||
"""
|
||||
Re-initialize the client
|
||||
"""
|
||||
InitalizeOpenAISDKClient.set_client(
|
||||
litellm_router_instance=self, model=deployment
|
||||
)
|
||||
set_client(litellm_router_instance=self, model=deployment)
|
||||
client = self.cache.get_cache(key=cache_key, local_only=True)
|
||||
return client
|
||||
else:
|
||||
|
@ -4670,9 +4671,7 @@ class Router:
|
|||
"""
|
||||
Re-initialize the client
|
||||
"""
|
||||
InitalizeOpenAISDKClient.set_client(
|
||||
litellm_router_instance=self, model=deployment
|
||||
)
|
||||
set_client(litellm_router_instance=self, model=deployment)
|
||||
client = self.cache.get_cache(key=cache_key, local_only=True)
|
||||
return client
|
||||
else:
|
||||
|
@ -4683,9 +4682,7 @@ class Router:
|
|||
"""
|
||||
Re-initialize the client
|
||||
"""
|
||||
InitalizeOpenAISDKClient.set_client(
|
||||
litellm_router_instance=self, model=deployment
|
||||
)
|
||||
set_client(litellm_router_instance=self, model=deployment)
|
||||
client = self.cache.get_cache(key=cache_key)
|
||||
return client
|
||||
else:
|
||||
|
@ -4695,9 +4692,7 @@ class Router:
|
|||
"""
|
||||
Re-initialize the client
|
||||
"""
|
||||
InitalizeOpenAISDKClient.set_client(
|
||||
litellm_router_instance=self, model=deployment
|
||||
)
|
||||
set_client(litellm_router_instance=self, model=deployment)
|
||||
client = self.cache.get_cache(key=cache_key)
|
||||
return client
|
||||
|
||||
|
|
|
@ -1,11 +1,10 @@
|
|||
import asyncio
|
||||
import os
|
||||
import traceback
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional
|
||||
|
||||
import httpx
|
||||
import openai
|
||||
from pydantic import BaseModel
|
||||
|
||||
import litellm
|
||||
from litellm import get_secret, get_secret_str
|
||||
|
@ -17,38 +16,13 @@ from litellm.secret_managers.get_azure_ad_token_provider import (
|
|||
from litellm.utils import calculate_max_parallel_requests
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from httpx import Timeout as httpxTimeout
|
||||
|
||||
from litellm.router import Router as _Router
|
||||
|
||||
LitellmRouter = _Router
|
||||
else:
|
||||
LitellmRouter = Any
|
||||
httpxTimeout = Any
|
||||
|
||||
|
||||
class OpenAISDKClientInitializationParams(BaseModel):
|
||||
api_key: Optional[str]
|
||||
api_base: Optional[str]
|
||||
api_version: Optional[str]
|
||||
azure_ad_token_provider: Optional[Callable[[], str]]
|
||||
timeout: Optional[Union[float, httpxTimeout]]
|
||||
stream_timeout: Optional[Union[float, httpxTimeout]]
|
||||
max_retries: int
|
||||
organization: Optional[str]
|
||||
|
||||
# Internal LiteLLM specific params
|
||||
custom_llm_provider: Optional[str]
|
||||
model_name: str
|
||||
|
||||
|
||||
class InitalizeOpenAISDKClient:
|
||||
"""
|
||||
OpenAI Python SDK requires creating a OpenAI/AzureOpenAI client
|
||||
this class is responsible for creating that client
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def should_initialize_sync_client(
|
||||
litellm_router_instance: LitellmRouter,
|
||||
) -> bool:
|
||||
|
@ -67,17 +41,14 @@ class InitalizeOpenAISDKClient:
|
|||
and hasattr(
|
||||
litellm_router_instance.router_general_settings, "async_only_mode"
|
||||
)
|
||||
and litellm_router_instance.router_general_settings.async_only_mode
|
||||
is True
|
||||
and litellm_router_instance.router_general_settings.async_only_mode is True
|
||||
):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def set_client( # noqa: PLR0915
|
||||
litellm_router_instance: LitellmRouter, model: dict
|
||||
):
|
||||
|
||||
def set_client(litellm_router_instance: LitellmRouter, model: dict): # noqa: PLR0915
|
||||
"""
|
||||
- Initializes Azure/OpenAI clients. Stores them in cache, b/c of this - https://github.com/BerriAI/litellm/issues/1278
|
||||
- Initializes Semaphore for client w/ rpm. Stores them in cache. b/c of this - https://github.com/BerriAI/litellm/issues/2994
|
||||
|
@ -117,46 +88,111 @@ class InitalizeOpenAISDKClient:
|
|||
default_api_base = api_base
|
||||
default_api_key = api_key
|
||||
|
||||
if InitalizeOpenAISDKClient._should_create_openai_sdk_client_for_model(
|
||||
model_name=model_name,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
if (
|
||||
model_name in litellm.open_ai_chat_completion_models
|
||||
or custom_llm_provider in litellm.openai_compatible_providers
|
||||
or custom_llm_provider == "azure"
|
||||
or custom_llm_provider == "azure_text"
|
||||
or custom_llm_provider == "custom_openai"
|
||||
or custom_llm_provider == "openai"
|
||||
or custom_llm_provider == "text-completion-openai"
|
||||
or "ft:gpt-3.5-turbo" in model_name
|
||||
or model_name in litellm.open_ai_embedding_models
|
||||
):
|
||||
client_initialization_params = (
|
||||
InitalizeOpenAISDKClient._get_client_initialization_params(
|
||||
model=model,
|
||||
model_name=model_name,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
litellm_params=litellm_params,
|
||||
default_api_key=default_api_key,
|
||||
default_api_base=default_api_base,
|
||||
)
|
||||
)
|
||||
is_azure_ai_studio_model: bool = False
|
||||
if custom_llm_provider == "azure":
|
||||
if litellm.utils._is_non_openai_azure_model(model_name):
|
||||
is_azure_ai_studio_model = True
|
||||
custom_llm_provider = "openai"
|
||||
# remove azure prefx from model_name
|
||||
model_name = model_name.replace("azure/", "")
|
||||
# glorified / complicated reading of configs
|
||||
# user can pass vars directly or they can pas os.environ/AZURE_API_KEY, in which case we will read the env
|
||||
# we do this here because we init clients for Azure, OpenAI and we need to set the right key
|
||||
api_key = litellm_params.get("api_key") or default_api_key
|
||||
if api_key and isinstance(api_key, str) and api_key.startswith("os.environ/"):
|
||||
api_key_env_name = api_key.replace("os.environ/", "")
|
||||
api_key = get_secret_str(api_key_env_name)
|
||||
litellm_params["api_key"] = api_key
|
||||
|
||||
############### Unpack client initialization params #######################
|
||||
api_key = client_initialization_params.api_key
|
||||
api_base = client_initialization_params.api_base
|
||||
api_version: Optional[str] = client_initialization_params.api_version
|
||||
timeout: Optional[Union[float, httpxTimeout]] = (
|
||||
client_initialization_params.timeout
|
||||
api_base = litellm_params.get("api_base")
|
||||
base_url: Optional[str] = litellm_params.get("base_url")
|
||||
api_base = (
|
||||
api_base or base_url or default_api_base
|
||||
) # allow users to pass in `api_base` or `base_url` for azure
|
||||
if api_base and api_base.startswith("os.environ/"):
|
||||
api_base_env_name = api_base.replace("os.environ/", "")
|
||||
api_base = get_secret_str(api_base_env_name)
|
||||
litellm_params["api_base"] = api_base
|
||||
|
||||
## AZURE AI STUDIO MISTRAL CHECK ##
|
||||
"""
|
||||
Make sure api base ends in /v1/
|
||||
|
||||
if not, add it - https://github.com/BerriAI/litellm/issues/2279
|
||||
"""
|
||||
if (
|
||||
is_azure_ai_studio_model is True
|
||||
and api_base is not None
|
||||
and isinstance(api_base, str)
|
||||
and not api_base.endswith("/v1/")
|
||||
):
|
||||
# check if it ends with a trailing slash
|
||||
if api_base.endswith("/"):
|
||||
api_base += "v1/"
|
||||
elif api_base.endswith("/v1"):
|
||||
api_base += "/"
|
||||
else:
|
||||
api_base += "/v1/"
|
||||
|
||||
api_version = litellm_params.get("api_version")
|
||||
if api_version and api_version.startswith("os.environ/"):
|
||||
api_version_env_name = api_version.replace("os.environ/", "")
|
||||
api_version = get_secret_str(api_version_env_name)
|
||||
litellm_params["api_version"] = api_version
|
||||
|
||||
timeout: Optional[float] = (
|
||||
litellm_params.pop("timeout", None) or litellm.request_timeout
|
||||
)
|
||||
stream_timeout: Optional[Union[float, httpxTimeout]] = (
|
||||
client_initialization_params.stream_timeout
|
||||
if isinstance(timeout, str) and timeout.startswith("os.environ/"):
|
||||
timeout_env_name = timeout.replace("os.environ/", "")
|
||||
timeout = get_secret(timeout_env_name) # type: ignore
|
||||
litellm_params["timeout"] = timeout
|
||||
|
||||
stream_timeout: Optional[float] = litellm_params.pop(
|
||||
"stream_timeout", timeout
|
||||
) # if no stream_timeout is set, default to timeout
|
||||
if isinstance(stream_timeout, str) and stream_timeout.startswith("os.environ/"):
|
||||
stream_timeout_env_name = stream_timeout.replace("os.environ/", "")
|
||||
stream_timeout = get_secret(stream_timeout_env_name) # type: ignore
|
||||
litellm_params["stream_timeout"] = stream_timeout
|
||||
|
||||
max_retries: Optional[int] = litellm_params.pop(
|
||||
"max_retries", 0
|
||||
) # router handles retry logic
|
||||
if isinstance(max_retries, str) and max_retries.startswith("os.environ/"):
|
||||
max_retries_env_name = max_retries.replace("os.environ/", "")
|
||||
max_retries = get_secret(max_retries_env_name) # type: ignore
|
||||
litellm_params["max_retries"] = max_retries
|
||||
|
||||
organization = litellm_params.get("organization", None)
|
||||
if isinstance(organization, str) and organization.startswith("os.environ/"):
|
||||
organization_env_name = organization.replace("os.environ/", "")
|
||||
organization = get_secret_str(organization_env_name)
|
||||
litellm_params["organization"] = organization
|
||||
azure_ad_token_provider: Optional[Callable[[], str]] = None
|
||||
if litellm_params.get("tenant_id"):
|
||||
verbose_router_logger.debug("Using Azure AD Token Provider for Azure Auth")
|
||||
azure_ad_token_provider = get_azure_ad_token_from_entrata_id(
|
||||
tenant_id=litellm_params.get("tenant_id"),
|
||||
client_id=litellm_params.get("client_id"),
|
||||
client_secret=litellm_params.get("client_secret"),
|
||||
)
|
||||
max_retries: int = client_initialization_params.max_retries
|
||||
organization: Optional[str] = client_initialization_params.organization
|
||||
azure_ad_token_provider: Optional[Callable[[], str]] = (
|
||||
client_initialization_params.azure_ad_token_provider
|
||||
)
|
||||
custom_llm_provider = client_initialization_params.custom_llm_provider
|
||||
model_name = client_initialization_params.model_name
|
||||
##########################################################################
|
||||
|
||||
if custom_llm_provider == "azure" or custom_llm_provider == "azure_text":
|
||||
if api_base is None or not isinstance(api_base, str):
|
||||
filtered_litellm_params = {
|
||||
k: v
|
||||
for k, v in model["litellm_params"].items()
|
||||
if k != "api_key"
|
||||
k: v for k, v in model["litellm_params"].items() if k != "api_key"
|
||||
}
|
||||
_filtered_model = {
|
||||
"model_name": model["model_name"],
|
||||
|
@ -196,8 +232,8 @@ class InitalizeOpenAISDKClient:
|
|||
azure_ad_token_provider=azure_ad_token_provider,
|
||||
base_url=api_base,
|
||||
api_version=api_version,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
timeout=timeout, # type: ignore
|
||||
max_retries=max_retries, # type: ignore
|
||||
http_client=httpx.AsyncClient(
|
||||
limits=httpx.Limits(
|
||||
max_connections=1000, max_keepalive_connections=100
|
||||
|
@ -212,7 +248,7 @@ class InitalizeOpenAISDKClient:
|
|||
local_only=True,
|
||||
) # cache for 1 hr
|
||||
|
||||
if InitalizeOpenAISDKClient.should_initialize_sync_client(
|
||||
if should_initialize_sync_client(
|
||||
litellm_router_instance=litellm_router_instance
|
||||
):
|
||||
cache_key = f"{model_id}_client"
|
||||
|
@ -222,8 +258,8 @@ class InitalizeOpenAISDKClient:
|
|||
azure_ad_token_provider=azure_ad_token_provider,
|
||||
base_url=api_base,
|
||||
api_version=api_version,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
timeout=timeout, # type: ignore
|
||||
max_retries=max_retries, # type: ignore
|
||||
http_client=httpx.Client(
|
||||
limits=httpx.Limits(
|
||||
max_connections=1000, max_keepalive_connections=100
|
||||
|
@ -245,8 +281,8 @@ class InitalizeOpenAISDKClient:
|
|||
azure_ad_token_provider=azure_ad_token_provider,
|
||||
base_url=api_base,
|
||||
api_version=api_version,
|
||||
timeout=stream_timeout,
|
||||
max_retries=max_retries,
|
||||
timeout=stream_timeout, # type: ignore
|
||||
max_retries=max_retries, # type: ignore
|
||||
http_client=httpx.AsyncClient(
|
||||
limits=httpx.Limits(
|
||||
max_connections=1000, max_keepalive_connections=100
|
||||
|
@ -261,7 +297,7 @@ class InitalizeOpenAISDKClient:
|
|||
local_only=True,
|
||||
) # cache for 1 hr
|
||||
|
||||
if InitalizeOpenAISDKClient.should_initialize_sync_client(
|
||||
if should_initialize_sync_client(
|
||||
litellm_router_instance=litellm_router_instance
|
||||
):
|
||||
cache_key = f"{model_id}_stream_client"
|
||||
|
@ -271,8 +307,8 @@ class InitalizeOpenAISDKClient:
|
|||
azure_ad_token_provider=azure_ad_token_provider,
|
||||
base_url=api_base,
|
||||
api_version=api_version,
|
||||
timeout=stream_timeout,
|
||||
max_retries=max_retries,
|
||||
timeout=stream_timeout, # type: ignore
|
||||
max_retries=max_retries, # type: ignore
|
||||
http_client=httpx.Client(
|
||||
limits=httpx.Limits(
|
||||
max_connections=1000, max_keepalive_connections=100
|
||||
|
@ -319,8 +355,8 @@ class InitalizeOpenAISDKClient:
|
|||
cache_key = f"{model_id}_async_client"
|
||||
_client = openai.AsyncAzureOpenAI( # type: ignore
|
||||
**azure_client_params,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
timeout=timeout, # type: ignore
|
||||
max_retries=max_retries, # type: ignore
|
||||
http_client=httpx.AsyncClient(
|
||||
limits=httpx.Limits(
|
||||
max_connections=1000, max_keepalive_connections=100
|
||||
|
@ -334,14 +370,14 @@ class InitalizeOpenAISDKClient:
|
|||
ttl=client_ttl,
|
||||
local_only=True,
|
||||
) # cache for 1 hr
|
||||
if InitalizeOpenAISDKClient.should_initialize_sync_client(
|
||||
if should_initialize_sync_client(
|
||||
litellm_router_instance=litellm_router_instance
|
||||
):
|
||||
cache_key = f"{model_id}_client"
|
||||
_client = openai.AzureOpenAI( # type: ignore
|
||||
**azure_client_params,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
timeout=timeout, # type: ignore
|
||||
max_retries=max_retries, # type: ignore
|
||||
http_client=httpx.Client(
|
||||
limits=httpx.Limits(
|
||||
max_connections=1000, max_keepalive_connections=100
|
||||
|
@ -360,8 +396,8 @@ class InitalizeOpenAISDKClient:
|
|||
cache_key = f"{model_id}_stream_async_client"
|
||||
_client = openai.AsyncAzureOpenAI( # type: ignore
|
||||
**azure_client_params,
|
||||
timeout=stream_timeout,
|
||||
max_retries=max_retries,
|
||||
timeout=stream_timeout, # type: ignore
|
||||
max_retries=max_retries, # type: ignore
|
||||
http_client=httpx.AsyncClient(
|
||||
limits=httpx.Limits(
|
||||
max_connections=1000, max_keepalive_connections=100
|
||||
|
@ -376,14 +412,14 @@ class InitalizeOpenAISDKClient:
|
|||
local_only=True,
|
||||
) # cache for 1 hr
|
||||
|
||||
if InitalizeOpenAISDKClient.should_initialize_sync_client(
|
||||
if should_initialize_sync_client(
|
||||
litellm_router_instance=litellm_router_instance
|
||||
):
|
||||
cache_key = f"{model_id}_stream_client"
|
||||
_client = openai.AzureOpenAI( # type: ignore
|
||||
**azure_client_params,
|
||||
timeout=stream_timeout,
|
||||
max_retries=max_retries,
|
||||
timeout=stream_timeout, # type: ignore
|
||||
max_retries=max_retries, # type: ignore
|
||||
http_client=httpx.Client(
|
||||
limits=httpx.Limits(
|
||||
max_connections=1000, max_keepalive_connections=100
|
||||
|
@ -410,8 +446,8 @@ class InitalizeOpenAISDKClient:
|
|||
_client = openai.AsyncOpenAI( # type: ignore
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
timeout=timeout, # type: ignore
|
||||
max_retries=max_retries, # type: ignore
|
||||
organization=organization,
|
||||
http_client=httpx.AsyncClient(
|
||||
limits=httpx.Limits(
|
||||
|
@ -427,15 +463,15 @@ class InitalizeOpenAISDKClient:
|
|||
local_only=True,
|
||||
) # cache for 1 hr
|
||||
|
||||
if InitalizeOpenAISDKClient.should_initialize_sync_client(
|
||||
if should_initialize_sync_client(
|
||||
litellm_router_instance=litellm_router_instance
|
||||
):
|
||||
cache_key = f"{model_id}_client"
|
||||
_client = openai.OpenAI( # type: ignore
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
timeout=timeout, # type: ignore
|
||||
max_retries=max_retries, # type: ignore
|
||||
organization=organization,
|
||||
http_client=httpx.Client(
|
||||
limits=httpx.Limits(
|
||||
|
@ -456,8 +492,8 @@ class InitalizeOpenAISDKClient:
|
|||
_client = openai.AsyncOpenAI( # type: ignore
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
timeout=stream_timeout,
|
||||
max_retries=max_retries,
|
||||
timeout=stream_timeout, # type: ignore
|
||||
max_retries=max_retries, # type: ignore
|
||||
organization=organization,
|
||||
http_client=httpx.AsyncClient(
|
||||
limits=httpx.Limits(
|
||||
|
@ -473,7 +509,7 @@ class InitalizeOpenAISDKClient:
|
|||
local_only=True,
|
||||
) # cache for 1 hr
|
||||
|
||||
if InitalizeOpenAISDKClient.should_initialize_sync_client(
|
||||
if should_initialize_sync_client(
|
||||
litellm_router_instance=litellm_router_instance
|
||||
):
|
||||
# streaming clients should have diff timeouts
|
||||
|
@ -481,8 +517,8 @@ class InitalizeOpenAISDKClient:
|
|||
_client = openai.OpenAI( # type: ignore
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
timeout=stream_timeout,
|
||||
max_retries=max_retries,
|
||||
timeout=stream_timeout, # type: ignore
|
||||
max_retries=max_retries, # type: ignore
|
||||
organization=organization,
|
||||
http_client=httpx.Client(
|
||||
limits=httpx.Limits(
|
||||
|
@ -498,157 +534,9 @@ class InitalizeOpenAISDKClient:
|
|||
local_only=True,
|
||||
) # cache for 1 hr
|
||||
|
||||
@staticmethod
|
||||
def _get_client_initialization_params(
|
||||
model: dict,
|
||||
model_name: str,
|
||||
custom_llm_provider: Optional[str],
|
||||
litellm_params: dict,
|
||||
default_api_key: Optional[str],
|
||||
default_api_base: Optional[str],
|
||||
) -> OpenAISDKClientInitializationParams:
|
||||
"""
|
||||
Returns params to use for initializing OpenAI SDK clients (for OpenAI, Azure OpenAI, OpenAI Compatible Providers)
|
||||
|
||||
Args:
|
||||
model: model dict
|
||||
model_name: model name
|
||||
custom_llm_provider: custom llm provider
|
||||
litellm_params: litellm params
|
||||
default_api_key: default api key
|
||||
default_api_base: default api base
|
||||
|
||||
Returns:
|
||||
OpenAISDKClientInitializationParams
|
||||
"""
|
||||
|
||||
is_azure_ai_studio_model: bool = False
|
||||
if custom_llm_provider == "azure":
|
||||
if litellm.utils._is_non_openai_azure_model(model_name):
|
||||
is_azure_ai_studio_model = True
|
||||
custom_llm_provider = "openai"
|
||||
# remove azure prefx from model_name
|
||||
model_name = model_name.replace("azure/", "")
|
||||
# glorified / complicated reading of configs
|
||||
# user can pass vars directly or they can pas os.environ/AZURE_API_KEY, in which case we will read the env
|
||||
# we do this here because we init clients for Azure, OpenAI and we need to set the right key
|
||||
api_key = litellm_params.get("api_key") or default_api_key
|
||||
if api_key and isinstance(api_key, str) and api_key.startswith("os.environ/"):
|
||||
api_key = get_secret_str(api_key)
|
||||
litellm_params["api_key"] = api_key
|
||||
|
||||
api_base = litellm_params.get("api_base")
|
||||
base_url: Optional[str] = litellm_params.get("base_url")
|
||||
api_base = (
|
||||
api_base or base_url or default_api_base
|
||||
) # allow users to pass in `api_base` or `base_url` for azure
|
||||
if api_base and api_base.startswith("os.environ/"):
|
||||
api_base = get_secret_str(api_base)
|
||||
litellm_params["api_base"] = api_base
|
||||
|
||||
## AZURE AI STUDIO MISTRAL CHECK ##
|
||||
"""
|
||||
Make sure api base ends in /v1/
|
||||
|
||||
if not, add it - https://github.com/BerriAI/litellm/issues/2279
|
||||
"""
|
||||
if (
|
||||
is_azure_ai_studio_model is True
|
||||
and api_base is not None
|
||||
and isinstance(api_base, str)
|
||||
and not api_base.endswith("/v1/")
|
||||
):
|
||||
# check if it ends with a trailing slash
|
||||
if api_base.endswith("/"):
|
||||
api_base += "v1/"
|
||||
elif api_base.endswith("/v1"):
|
||||
api_base += "/"
|
||||
else:
|
||||
api_base += "/v1/"
|
||||
|
||||
api_version = litellm_params.get("api_version")
|
||||
if api_version and api_version.startswith("os.environ/"):
|
||||
api_version = get_secret_str(api_version)
|
||||
litellm_params["api_version"] = api_version
|
||||
|
||||
timeout: Optional[Union[float, str, httpxTimeout]] = (
|
||||
litellm_params.pop("timeout", None) or litellm.request_timeout
|
||||
)
|
||||
if isinstance(timeout, str) and timeout.startswith("os.environ/"):
|
||||
timeout = float(get_secret(timeout)) # type: ignore
|
||||
litellm_params["timeout"] = timeout
|
||||
|
||||
stream_timeout: Optional[Union[float, str, httpxTimeout]] = litellm_params.pop(
|
||||
"stream_timeout", timeout
|
||||
) # if no stream_timeout is set, default to timeout
|
||||
if isinstance(stream_timeout, str) and stream_timeout.startswith("os.environ/"):
|
||||
stream_timeout = float(get_secret(stream_timeout)) # type: ignore
|
||||
litellm_params["stream_timeout"] = stream_timeout
|
||||
|
||||
max_retries: Optional[int] = litellm_params.pop(
|
||||
"max_retries", 0
|
||||
) # router handles retry logic
|
||||
if isinstance(max_retries, str) and max_retries.startswith("os.environ/"):
|
||||
max_retries = get_secret(max_retries) # type: ignore
|
||||
litellm_params["max_retries"] = max_retries
|
||||
|
||||
organization = litellm_params.get("organization", None)
|
||||
if isinstance(organization, str) and organization.startswith("os.environ/"):
|
||||
organization = get_secret_str(organization)
|
||||
litellm_params["organization"] = organization
|
||||
azure_ad_token_provider: Optional[Callable[[], str]] = None
|
||||
tenant_id = litellm_params.get("tenant_id")
|
||||
if tenant_id is not None:
|
||||
verbose_router_logger.debug("Using Azure AD Token Provider for Azure Auth")
|
||||
azure_ad_token_provider = (
|
||||
InitalizeOpenAISDKClient.get_azure_ad_token_from_entrata_id(
|
||||
tenant_id=tenant_id,
|
||||
client_id=litellm_params.get("client_id"),
|
||||
client_secret=litellm_params.get("client_secret"),
|
||||
)
|
||||
)
|
||||
|
||||
return OpenAISDKClientInitializationParams(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
azure_ad_token_provider=azure_ad_token_provider,
|
||||
timeout=timeout, # type: ignore
|
||||
stream_timeout=stream_timeout, # type: ignore
|
||||
max_retries=max_retries, # type: ignore
|
||||
organization=organization,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
model_name=model_name,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _should_create_openai_sdk_client_for_model(
|
||||
model_name: str,
|
||||
custom_llm_provider: str,
|
||||
) -> bool:
|
||||
"""
|
||||
Returns True if a OpenAI SDK client should be created for a given model
|
||||
|
||||
We need a OpenAI SDK client for models that are callsed using OpenAI Python SDK
|
||||
Azure OpenAI, OpenAI, OpenAI Compatible Providers, OpenAI Embedding Models
|
||||
"""
|
||||
if (
|
||||
model_name in litellm.open_ai_chat_completion_models
|
||||
or custom_llm_provider in litellm.openai_compatible_providers
|
||||
or custom_llm_provider == "azure"
|
||||
or custom_llm_provider == "azure_text"
|
||||
or custom_llm_provider == "custom_openai"
|
||||
or custom_llm_provider == "openai"
|
||||
or custom_llm_provider == "text-completion-openai"
|
||||
or "ft:gpt-3.5-turbo" in model_name
|
||||
or model_name in litellm.open_ai_embedding_models
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def get_azure_ad_token_from_entrata_id(
|
||||
tenant_id: str, client_id: Optional[str], client_secret: Optional[str]
|
||||
tenant_id: str, client_id: str, client_secret: str
|
||||
) -> Callable[[], str]:
|
||||
from azure.identity import (
|
||||
ClientSecretCredential,
|
||||
|
@ -656,11 +544,6 @@ class InitalizeOpenAISDKClient:
|
|||
get_bearer_token_provider,
|
||||
)
|
||||
|
||||
if client_id is None or client_secret is None:
|
||||
raise ValueError(
|
||||
"client_id and client_secret must be provided when using `tenant_id`"
|
||||
)
|
||||
|
||||
verbose_router_logger.debug("Getting Azure AD Token from Entrata ID")
|
||||
|
||||
if tenant_id.startswith("os.environ/"):
|
||||
|
|
|
@ -17,10 +17,6 @@ from dotenv import load_dotenv
|
|||
|
||||
import litellm
|
||||
from litellm import Router
|
||||
from litellm.router_utils.client_initalization_utils import (
|
||||
InitalizeOpenAISDKClient,
|
||||
OpenAISDKClientInitializationParams,
|
||||
)
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
@ -700,283 +696,3 @@ def test_init_router_with_supported_environments(environment, expected_models):
|
|||
assert set(_model_list) == set(expected_models)
|
||||
|
||||
os.environ.pop("LITELLM_ENVIRONMENT")
|
||||
|
||||
|
||||
def test_should_initialize_sync_client():
|
||||
from litellm.types.router import RouterGeneralSettings
|
||||
|
||||
# Test case 1: Router instance is None
|
||||
assert InitalizeOpenAISDKClient.should_initialize_sync_client(None) is False
|
||||
|
||||
# Test case 2: Router instance without router_general_settings
|
||||
router = Router(model_list=[])
|
||||
assert InitalizeOpenAISDKClient.should_initialize_sync_client(router) is True
|
||||
|
||||
# Test case 3: Router instance with async_only_mode = False
|
||||
router = Router(
|
||||
model_list=[],
|
||||
router_general_settings=RouterGeneralSettings(async_only_mode=False),
|
||||
)
|
||||
assert InitalizeOpenAISDKClient.should_initialize_sync_client(router) is True
|
||||
|
||||
# Test case 4: Router instance with async_only_mode = True
|
||||
router = Router(
|
||||
model_list=[],
|
||||
router_general_settings=RouterGeneralSettings(async_only_mode=True),
|
||||
)
|
||||
assert InitalizeOpenAISDKClient.should_initialize_sync_client(router) is False
|
||||
|
||||
# Test case 5: Router instance with router_general_settings but without async_only_mode
|
||||
router = Router(model_list=[], router_general_settings=RouterGeneralSettings())
|
||||
assert InitalizeOpenAISDKClient.should_initialize_sync_client(router) is True
|
||||
|
||||
print("All test cases passed!")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model_name, custom_llm_provider, expected_result",
|
||||
[
|
||||
("gpt-3.5-turbo", None, True), # OpenAI chat completion model
|
||||
("text-embedding-ada-002", None, True), # OpenAI embedding model
|
||||
("claude-2", None, False), # Non-OpenAI model
|
||||
("gpt-3.5-turbo", "azure", True), # Azure OpenAI
|
||||
("text-davinci-003", "azure_text", True), # Azure OpenAI
|
||||
("gpt-3.5-turbo", "openai", True), # OpenAI
|
||||
("custom-model", "custom_openai", True), # Custom OpenAI compatible
|
||||
("text-davinci-003", "text-completion-openai", True), # OpenAI text completion
|
||||
(
|
||||
"ft:gpt-3.5-turbo-0613:my-org:custom-model:7p4lURel",
|
||||
None,
|
||||
True,
|
||||
), # Fine-tuned GPT model
|
||||
("mistral-7b", "huggingface", False), # Non-OpenAI provider
|
||||
("custom-model", "anthropic", False), # Non-OpenAI compatible provider
|
||||
],
|
||||
)
|
||||
def test_should_create_openai_sdk_client_for_model(
|
||||
model_name, custom_llm_provider, expected_result
|
||||
):
|
||||
result = InitalizeOpenAISDKClient._should_create_openai_sdk_client_for_model(
|
||||
model_name, custom_llm_provider
|
||||
)
|
||||
assert (
|
||||
result == expected_result
|
||||
), f"Failed for model: {model_name}, provider: {custom_llm_provider}"
|
||||
|
||||
|
||||
def test_should_create_openai_sdk_client_for_model_openai_compatible_providers():
|
||||
# Test with a known OpenAI compatible provider
|
||||
assert InitalizeOpenAISDKClient._should_create_openai_sdk_client_for_model(
|
||||
"custom-model", "groq"
|
||||
), "Should return True for OpenAI compatible provider"
|
||||
|
||||
# Add a new compatible provider and test
|
||||
litellm.openai_compatible_providers.append("new_provider")
|
||||
assert InitalizeOpenAISDKClient._should_create_openai_sdk_client_for_model(
|
||||
"custom-model", "new_provider"
|
||||
), "Should return True for newly added OpenAI compatible provider"
|
||||
|
||||
# Clean up
|
||||
litellm.openai_compatible_providers.remove("new_provider")
|
||||
|
||||
|
||||
def test_get_client_initialization_params_openai():
|
||||
"""Test basic OpenAI configuration with direct parameter passing."""
|
||||
model = {}
|
||||
model_name = "gpt-3.5-turbo"
|
||||
custom_llm_provider = None
|
||||
litellm_params = {"api_key": "sk-openai-key", "timeout": 30, "max_retries": 3}
|
||||
default_api_key = None
|
||||
default_api_base = None
|
||||
|
||||
result = InitalizeOpenAISDKClient._get_client_initialization_params(
|
||||
model=model,
|
||||
model_name=model_name,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
litellm_params=litellm_params,
|
||||
default_api_key=default_api_key,
|
||||
default_api_base=default_api_base,
|
||||
)
|
||||
|
||||
assert isinstance(result, OpenAISDKClientInitializationParams)
|
||||
assert result.api_key == "sk-openai-key"
|
||||
assert result.timeout == 30
|
||||
assert result.max_retries == 3
|
||||
assert result.model_name == "gpt-3.5-turbo"
|
||||
|
||||
|
||||
def test_get_client_initialization_params_azure():
|
||||
"""Test Azure OpenAI configuration with specific Azure parameters."""
|
||||
model = {}
|
||||
model_name = "azure/gpt-4"
|
||||
custom_llm_provider = "azure"
|
||||
litellm_params = {
|
||||
"api_key": "azure-key",
|
||||
"api_base": "https://example.azure.openai.com",
|
||||
"api_version": "2023-05-15",
|
||||
}
|
||||
default_api_key = None
|
||||
default_api_base = None
|
||||
|
||||
result = InitalizeOpenAISDKClient._get_client_initialization_params(
|
||||
model=model,
|
||||
model_name=model_name,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
litellm_params=litellm_params,
|
||||
default_api_key=default_api_key,
|
||||
default_api_base=default_api_base,
|
||||
)
|
||||
|
||||
assert result.api_key == "azure-key"
|
||||
assert result.api_base == "https://example.azure.openai.com"
|
||||
assert result.api_version == "2023-05-15"
|
||||
assert result.custom_llm_provider == "azure"
|
||||
|
||||
|
||||
def test_get_client_initialization_params_environment_variable_parsing():
|
||||
"""Test parsing of environment variables for configuration."""
|
||||
os.environ["UNIQUE_OPENAI_API_KEY"] = "env-openai-key"
|
||||
os.environ["UNIQUE_TIMEOUT"] = "45"
|
||||
|
||||
model = {}
|
||||
model_name = "gpt-4"
|
||||
custom_llm_provider = None
|
||||
litellm_params = {
|
||||
"api_key": "os.environ/UNIQUE_OPENAI_API_KEY",
|
||||
"timeout": "os.environ/UNIQUE_TIMEOUT",
|
||||
"organization": "os.environ/UNIQUE_ORG_ID",
|
||||
}
|
||||
default_api_key = None
|
||||
default_api_base = None
|
||||
|
||||
result = InitalizeOpenAISDKClient._get_client_initialization_params(
|
||||
model=model,
|
||||
model_name=model_name,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
litellm_params=litellm_params,
|
||||
default_api_key=default_api_key,
|
||||
default_api_base=default_api_base,
|
||||
)
|
||||
|
||||
assert result.api_key == "env-openai-key"
|
||||
assert result.timeout == 45.0
|
||||
assert result.organization is None # Since ORG_ID is not set in the environment
|
||||
|
||||
|
||||
def test_get_client_initialization_params_azure_ai_studio_mistral():
|
||||
"""
|
||||
Test configuration for Azure AI Studio Mistral model.
|
||||
|
||||
- /v1/ is added to the api_base if it is not present
|
||||
- custom_llm_provider is set to openai (Azure AI Studio Mistral models need to use OpenAI route)
|
||||
"""
|
||||
|
||||
model = {}
|
||||
model_name = "azure/mistral-large-latest"
|
||||
custom_llm_provider = "azure"
|
||||
litellm_params = {
|
||||
"api_key": "azure-key",
|
||||
"api_base": "https://example.azure.openai.com",
|
||||
}
|
||||
default_api_key = None
|
||||
default_api_base = None
|
||||
|
||||
result = InitalizeOpenAISDKClient._get_client_initialization_params(
|
||||
model,
|
||||
model_name,
|
||||
custom_llm_provider,
|
||||
litellm_params,
|
||||
default_api_key,
|
||||
default_api_base,
|
||||
)
|
||||
|
||||
assert result.custom_llm_provider == "openai"
|
||||
assert result.model_name == "mistral-large-latest"
|
||||
assert result.api_base == "https://example.azure.openai.com/v1/"
|
||||
|
||||
|
||||
def test_get_client_initialization_params_default_values():
|
||||
"""
|
||||
Test use of default values when specific parameters are not provided.
|
||||
|
||||
This is used typically for OpenAI compatible providers - example Together AI
|
||||
|
||||
"""
|
||||
model = {}
|
||||
model_name = "together/meta-llama-3.1-8b-instruct"
|
||||
custom_llm_provider = None
|
||||
litellm_params = {}
|
||||
default_api_key = "together-api-key"
|
||||
default_api_base = "https://together.xyz/api.openai.com"
|
||||
|
||||
result = InitalizeOpenAISDKClient._get_client_initialization_params(
|
||||
model=model,
|
||||
model_name=model_name,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
litellm_params=litellm_params,
|
||||
default_api_key=default_api_key,
|
||||
default_api_base=default_api_base,
|
||||
)
|
||||
|
||||
assert result.api_key == "together-api-key"
|
||||
assert result.api_base == "https://together.xyz/api.openai.com"
|
||||
assert result.timeout == litellm.request_timeout
|
||||
assert result.max_retries == 0
|
||||
|
||||
|
||||
def test_get_client_initialization_params_all_env_vars():
|
||||
# Set up environment variables
|
||||
os.environ["TEST_API_KEY"] = "test-api-key"
|
||||
os.environ["TEST_API_BASE"] = "https://test.openai.com"
|
||||
os.environ["TEST_API_VERSION"] = "2023-05-15"
|
||||
os.environ["TEST_TIMEOUT"] = "30"
|
||||
os.environ["TEST_STREAM_TIMEOUT"] = "60"
|
||||
os.environ["TEST_MAX_RETRIES"] = "3"
|
||||
os.environ["TEST_ORGANIZATION"] = "test-org"
|
||||
|
||||
model = {}
|
||||
model_name = "gpt-4"
|
||||
custom_llm_provider = None
|
||||
litellm_params = {
|
||||
"api_key": "os.environ/TEST_API_KEY",
|
||||
"api_base": "os.environ/TEST_API_BASE",
|
||||
"api_version": "os.environ/TEST_API_VERSION",
|
||||
"timeout": "os.environ/TEST_TIMEOUT",
|
||||
"stream_timeout": "os.environ/TEST_STREAM_TIMEOUT",
|
||||
"max_retries": "os.environ/TEST_MAX_RETRIES",
|
||||
"organization": "os.environ/TEST_ORGANIZATION",
|
||||
}
|
||||
default_api_key = None
|
||||
default_api_base = None
|
||||
|
||||
result = InitalizeOpenAISDKClient._get_client_initialization_params(
|
||||
model=model,
|
||||
model_name=model_name,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
litellm_params=litellm_params,
|
||||
default_api_key=default_api_key,
|
||||
default_api_base=default_api_base,
|
||||
)
|
||||
|
||||
assert isinstance(result, OpenAISDKClientInitializationParams)
|
||||
assert result.api_key == "test-api-key"
|
||||
assert result.api_base == "https://test.openai.com"
|
||||
assert result.api_version == "2023-05-15"
|
||||
assert result.timeout == 30.0
|
||||
assert result.stream_timeout == 60.0
|
||||
assert result.max_retries == 3
|
||||
assert result.organization == "test-org"
|
||||
assert result.model_name == "gpt-4"
|
||||
assert result.custom_llm_provider is None
|
||||
|
||||
# Clean up environment variables
|
||||
for key in [
|
||||
"TEST_API_KEY",
|
||||
"TEST_API_BASE",
|
||||
"TEST_API_VERSION",
|
||||
"TEST_TIMEOUT",
|
||||
"TEST_STREAM_TIMEOUT",
|
||||
"TEST_MAX_RETRIES",
|
||||
"TEST_ORGANIZATION",
|
||||
]:
|
||||
os.environ.pop(key)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue