Revert "(refactor) litellm.Router client initialization utils (#6394)" (#6403)

This reverts commit b70147f63b.
Ishaan Jaff authored on 2024-10-23 20:31:57 +05:30 (committed by GitHub)
parent 72a91ea9dd
commit d063086bbf
3 changed files with 478 additions and 884 deletions

litellm/router.py

@@ -63,7 +63,10 @@ from litellm.router_utils.batch_utils import (
_get_router_metadata_variable_name,
replace_model_in_jsonl,
)
from litellm.router_utils.client_initalization_utils import InitalizeOpenAISDKClient
from litellm.router_utils.client_initalization_utils import (
set_client,
should_initialize_sync_client,
)
from litellm.router_utils.cooldown_cache import CooldownCache
from litellm.router_utils.cooldown_callbacks import router_cooldown_event_callback
from litellm.router_utils.cooldown_handlers import (
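The import hunk above is the whole API change: the refactor's InitalizeOpenAISDKClient static methods go back to being module-level functions. A minimal before/after sketch, with router and deployment as placeholder names:

# API being reverted (#6394): static methods on a class
from litellm.router_utils.client_initalization_utils import InitalizeOpenAISDKClient
InitalizeOpenAISDKClient.set_client(litellm_router_instance=router, model=deployment)

# API restored by this revert: plain module-level functions
from litellm.router_utils.client_initalization_utils import set_client
set_client(litellm_router_instance=router, model=deployment)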
@@ -3948,7 +3951,7 @@ class Router:
raise Exception(f"Unsupported provider - {custom_llm_provider}")
# init OpenAI, Azure clients
InitalizeOpenAISDKClient.set_client(
set_client(
litellm_router_instance=self, model=deployment.to_json(exclude_none=True)
)
@@ -4658,9 +4661,7 @@ class Router:
"""
Re-initialize the client
"""
InitalizeOpenAISDKClient.set_client(
litellm_router_instance=self, model=deployment
)
set_client(litellm_router_instance=self, model=deployment)
client = self.cache.get_cache(key=cache_key, local_only=True)
return client
else:
@@ -4670,9 +4671,7 @@ class Router:
"""
Re-initialize the client
"""
InitalizeOpenAISDKClient.set_client(
litellm_router_instance=self, model=deployment
)
set_client(litellm_router_instance=self, model=deployment)
client = self.cache.get_cache(key=cache_key, local_only=True)
return client
else:
@@ -4683,9 +4682,7 @@ class Router:
"""
Re-initialize the client
"""
InitalizeOpenAISDKClient.set_client(
litellm_router_instance=self, model=deployment
)
set_client(litellm_router_instance=self, model=deployment)
client = self.cache.get_cache(key=cache_key)
return client
else:
@@ -4695,9 +4692,7 @@ class Router:
"""
Re-initialize the client
"""
InitalizeOpenAISDKClient.set_client(
litellm_router_instance=self, model=deployment
)
set_client(litellm_router_instance=self, model=deployment)
client = self.cache.get_cache(key=cache_key)
return client

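All four router.py hunks above restore the same lazy re-initialization pattern: on a cache miss, rebuild the SDK client via set_client and read it back from the router cache. Condensed from the diff (cache_key is e.g. f"{model_id}_async_client"; some variants omit local_only):

client = self.cache.get_cache(key=cache_key, local_only=True)
if client is None:
    # Re-initialize the client, then fetch the freshly cached instance
    set_client(litellm_router_instance=self, model=deployment)
    client = self.cache.get_cache(key=cache_key, local_only=True)
return client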
litellm/router_utils/client_initalization_utils.py

@@ -1,11 +1,10 @@
import asyncio
import os
import traceback
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
from typing import TYPE_CHECKING, Any, Callable, Optional
import httpx
import openai
from pydantic import BaseModel
import litellm
from litellm import get_secret, get_secret_str
@@ -17,38 +16,13 @@ from litellm.secret_managers.get_azure_ad_token_provider import (
from litellm.utils import calculate_max_parallel_requests
if TYPE_CHECKING:
from httpx import Timeout as httpxTimeout
from litellm.router import Router as _Router
LitellmRouter = _Router
else:
LitellmRouter = Any
httpxTimeout = Any
class OpenAISDKClientInitializationParams(BaseModel):
api_key: Optional[str]
api_base: Optional[str]
api_version: Optional[str]
azure_ad_token_provider: Optional[Callable[[], str]]
timeout: Optional[Union[float, httpxTimeout]]
stream_timeout: Optional[Union[float, httpxTimeout]]
max_retries: int
organization: Optional[str]
# Internal LiteLLM specific params
custom_llm_provider: Optional[str]
model_name: str
class InitalizeOpenAISDKClient:
"""
OpenAI Python SDK requires creating an OpenAI/AzureOpenAI client
this class is responsible for creating that client
"""
@staticmethod
def should_initialize_sync_client(
litellm_router_instance: LitellmRouter,
) -> bool:
@@ -67,17 +41,14 @@ class InitalizeOpenAISDKClient:
and hasattr(
litellm_router_instance.router_general_settings, "async_only_mode"
)
and litellm_router_instance.router_general_settings.async_only_mode
is True
and litellm_router_instance.router_general_settings.async_only_mode is True
):
return False
return True
@staticmethod
def set_client( # noqa: PLR0915
litellm_router_instance: LitellmRouter, model: dict
):
def set_client(litellm_router_instance: LitellmRouter, model: dict): # noqa: PLR0915
"""
- Initializes Azure/OpenAI clients. Stores them in cache, b/c of this - https://github.com/BerriAI/litellm/issues/1278
- Initializes Semaphore for client w/ rpm. Stores them in cache. b/c of this - https://github.com/BerriAI/litellm/issues/2994
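The second docstring point is the per-deployment concurrency cap: when a deployment sets rpm/tpm, set_client also caches an asyncio.Semaphore for it. A sketch of the idea, assuming the calculate_max_parallel_requests helper imported above; the cache-key format is an assumption, not shown in this diff:

import asyncio

from litellm.utils import calculate_max_parallel_requests

max_parallel = calculate_max_parallel_requests(
    max_parallel_requests=litellm_params.get("max_parallel_requests"),
    rpm=litellm_params.get("rpm"),
    tpm=litellm_params.get("tpm"),
    default_max_parallel_requests=None,
)
if max_parallel is not None:
    litellm_router_instance.cache.set_cache(
        key=f"{model_id}_max_parallel_requests_client",  # key format: assumption
        value=asyncio.Semaphore(max_parallel),
        local_only=True,
    )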
@@ -117,46 +88,111 @@ class InitalizeOpenAISDKClient:
default_api_base = api_base
default_api_key = api_key
if InitalizeOpenAISDKClient._should_create_openai_sdk_client_for_model(
model_name=model_name,
custom_llm_provider=custom_llm_provider,
if (
model_name in litellm.open_ai_chat_completion_models
or custom_llm_provider in litellm.openai_compatible_providers
or custom_llm_provider == "azure"
or custom_llm_provider == "azure_text"
or custom_llm_provider == "custom_openai"
or custom_llm_provider == "openai"
or custom_llm_provider == "text-completion-openai"
or "ft:gpt-3.5-turbo" in model_name
or model_name in litellm.open_ai_embedding_models
):
client_initialization_params = (
InitalizeOpenAISDKClient._get_client_initialization_params(
model=model,
model_name=model_name,
custom_llm_provider=custom_llm_provider,
litellm_params=litellm_params,
default_api_key=default_api_key,
default_api_base=default_api_base,
)
)
is_azure_ai_studio_model: bool = False
if custom_llm_provider == "azure":
if litellm.utils._is_non_openai_azure_model(model_name):
is_azure_ai_studio_model = True
custom_llm_provider = "openai"
# remove azure prefix from model_name
model_name = model_name.replace("azure/", "")
# glorified / complicated reading of configs
# user can pass vars directly or they can pass os.environ/AZURE_API_KEY, in which case we will read the env
# we do this here because we init clients for Azure, OpenAI and we need to set the right key
api_key = litellm_params.get("api_key") or default_api_key
if api_key and isinstance(api_key, str) and api_key.startswith("os.environ/"):
api_key_env_name = api_key.replace("os.environ/", "")
api_key = get_secret_str(api_key_env_name)
litellm_params["api_key"] = api_key
############### Unpack client initialization params #######################
api_key = client_initialization_params.api_key
api_base = client_initialization_params.api_base
api_version: Optional[str] = client_initialization_params.api_version
timeout: Optional[Union[float, httpxTimeout]] = (
client_initialization_params.timeout
api_base = litellm_params.get("api_base")
base_url: Optional[str] = litellm_params.get("base_url")
api_base = (
api_base or base_url or default_api_base
) # allow users to pass in `api_base` or `base_url` for azure
if api_base and api_base.startswith("os.environ/"):
api_base_env_name = api_base.replace("os.environ/", "")
api_base = get_secret_str(api_base_env_name)
litellm_params["api_base"] = api_base
## AZURE AI STUDIO MISTRAL CHECK ##
"""
Make sure api base ends in /v1/
if not, add it - https://github.com/BerriAI/litellm/issues/2279
"""
if (
is_azure_ai_studio_model is True
and api_base is not None
and isinstance(api_base, str)
and not api_base.endswith("/v1/")
):
# check if it ends with a trailing slash
if api_base.endswith("/"):
api_base += "v1/"
elif api_base.endswith("/v1"):
api_base += "/"
else:
api_base += "/v1/"
api_version = litellm_params.get("api_version")
if api_version and api_version.startswith("os.environ/"):
api_version_env_name = api_version.replace("os.environ/", "")
api_version = get_secret_str(api_version_env_name)
litellm_params["api_version"] = api_version
timeout: Optional[float] = (
litellm_params.pop("timeout", None) or litellm.request_timeout
)
stream_timeout: Optional[Union[float, httpxTimeout]] = (
client_initialization_params.stream_timeout
if isinstance(timeout, str) and timeout.startswith("os.environ/"):
timeout_env_name = timeout.replace("os.environ/", "")
timeout = get_secret(timeout_env_name) # type: ignore
litellm_params["timeout"] = timeout
stream_timeout: Optional[float] = litellm_params.pop(
"stream_timeout", timeout
) # if no stream_timeout is set, default to timeout
if isinstance(stream_timeout, str) and stream_timeout.startswith("os.environ/"):
stream_timeout_env_name = stream_timeout.replace("os.environ/", "")
stream_timeout = get_secret(stream_timeout_env_name) # type: ignore
litellm_params["stream_timeout"] = stream_timeout
max_retries: Optional[int] = litellm_params.pop(
"max_retries", 0
) # router handles retry logic
if isinstance(max_retries, str) and max_retries.startswith("os.environ/"):
max_retries_env_name = max_retries.replace("os.environ/", "")
max_retries = get_secret(max_retries_env_name) # type: ignore
litellm_params["max_retries"] = max_retries
organization = litellm_params.get("organization", None)
if isinstance(organization, str) and organization.startswith("os.environ/"):
organization_env_name = organization.replace("os.environ/", "")
organization = get_secret_str(organization_env_name)
litellm_params["organization"] = organization
azure_ad_token_provider: Optional[Callable[[], str]] = None
if litellm_params.get("tenant_id"):
verbose_router_logger.debug("Using Azure AD Token Provider for Azure Auth")
azure_ad_token_provider = get_azure_ad_token_from_entrata_id(
tenant_id=litellm_params.get("tenant_id"),
client_id=litellm_params.get("client_id"),
client_secret=litellm_params.get("client_secret"),
)
max_retries: int = client_initialization_params.max_retries
organization: Optional[str] = client_initialization_params.organization
azure_ad_token_provider: Optional[Callable[[], str]] = (
client_initialization_params.azure_ad_token_provider
)
custom_llm_provider = client_initialization_params.custom_llm_provider
model_name = client_initialization_params.model_name
##########################################################################
if custom_llm_provider == "azure" or custom_llm_provider == "azure_text":
if api_base is None or not isinstance(api_base, str):
filtered_litellm_params = {
k: v
for k, v in model["litellm_params"].items()
if k != "api_key"
k: v for k, v in model["litellm_params"].items() if k != "api_key"
}
_filtered_model = {
"model_name": model["model_name"],
@@ -196,8 +232,8 @@ class InitalizeOpenAISDKClient:
azure_ad_token_provider=azure_ad_token_provider,
base_url=api_base,
api_version=api_version,
timeout=timeout,
max_retries=max_retries,
timeout=timeout, # type: ignore
max_retries=max_retries, # type: ignore
http_client=httpx.AsyncClient(
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
@@ -212,7 +248,7 @@ class InitalizeOpenAISDKClient:
local_only=True,
) # cache for 1 hr
if InitalizeOpenAISDKClient.should_initialize_sync_client(
if should_initialize_sync_client(
litellm_router_instance=litellm_router_instance
):
cache_key = f"{model_id}_client"
@@ -222,8 +258,8 @@ class InitalizeOpenAISDKClient:
azure_ad_token_provider=azure_ad_token_provider,
base_url=api_base,
api_version=api_version,
timeout=timeout,
max_retries=max_retries,
timeout=timeout, # type: ignore
max_retries=max_retries, # type: ignore
http_client=httpx.Client(
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
@@ -245,8 +281,8 @@ class InitalizeOpenAISDKClient:
azure_ad_token_provider=azure_ad_token_provider,
base_url=api_base,
api_version=api_version,
timeout=stream_timeout,
max_retries=max_retries,
timeout=stream_timeout, # type: ignore
max_retries=max_retries, # type: ignore
http_client=httpx.AsyncClient(
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
@@ -261,7 +297,7 @@ class InitalizeOpenAISDKClient:
local_only=True,
) # cache for 1 hr
if InitalizeOpenAISDKClient.should_initialize_sync_client(
if should_initialize_sync_client(
litellm_router_instance=litellm_router_instance
):
cache_key = f"{model_id}_stream_client"
@@ -271,8 +307,8 @@ class InitalizeOpenAISDKClient:
azure_ad_token_provider=azure_ad_token_provider,
base_url=api_base,
api_version=api_version,
timeout=stream_timeout,
max_retries=max_retries,
timeout=stream_timeout, # type: ignore
max_retries=max_retries, # type: ignore
http_client=httpx.Client(
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
@@ -319,8 +355,8 @@ class InitalizeOpenAISDKClient:
cache_key = f"{model_id}_async_client"
_client = openai.AsyncAzureOpenAI( # type: ignore
**azure_client_params,
timeout=timeout,
max_retries=max_retries,
timeout=timeout, # type: ignore
max_retries=max_retries, # type: ignore
http_client=httpx.AsyncClient(
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
@@ -334,14 +370,14 @@ class InitalizeOpenAISDKClient:
ttl=client_ttl,
local_only=True,
) # cache for 1 hr
if InitalizeOpenAISDKClient.should_initialize_sync_client(
if should_initialize_sync_client(
litellm_router_instance=litellm_router_instance
):
cache_key = f"{model_id}_client"
_client = openai.AzureOpenAI( # type: ignore
**azure_client_params,
timeout=timeout,
max_retries=max_retries,
timeout=timeout, # type: ignore
max_retries=max_retries, # type: ignore
http_client=httpx.Client(
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
@@ -360,8 +396,8 @@ class InitalizeOpenAISDKClient:
cache_key = f"{model_id}_stream_async_client"
_client = openai.AsyncAzureOpenAI( # type: ignore
**azure_client_params,
timeout=stream_timeout,
max_retries=max_retries,
timeout=stream_timeout, # type: ignore
max_retries=max_retries, # type: ignore
http_client=httpx.AsyncClient(
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
@@ -376,14 +412,14 @@ class InitalizeOpenAISDKClient:
local_only=True,
) # cache for 1 hr
if InitalizeOpenAISDKClient.should_initialize_sync_client(
if should_initialize_sync_client(
litellm_router_instance=litellm_router_instance
):
cache_key = f"{model_id}_stream_client"
_client = openai.AzureOpenAI( # type: ignore
**azure_client_params,
timeout=stream_timeout,
max_retries=max_retries,
timeout=stream_timeout, # type: ignore
max_retries=max_retries, # type: ignore
http_client=httpx.Client(
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
@@ -410,8 +446,8 @@ class InitalizeOpenAISDKClient:
_client = openai.AsyncOpenAI( # type: ignore
api_key=api_key,
base_url=api_base,
timeout=timeout,
max_retries=max_retries,
timeout=timeout, # type: ignore
max_retries=max_retries, # type: ignore
organization=organization,
http_client=httpx.AsyncClient(
limits=httpx.Limits(
@@ -427,15 +463,15 @@ class InitalizeOpenAISDKClient:
local_only=True,
) # cache for 1 hr
if InitalizeOpenAISDKClient.should_initialize_sync_client(
if should_initialize_sync_client(
litellm_router_instance=litellm_router_instance
):
cache_key = f"{model_id}_client"
_client = openai.OpenAI( # type: ignore
api_key=api_key,
base_url=api_base,
timeout=timeout,
max_retries=max_retries,
timeout=timeout, # type: ignore
max_retries=max_retries, # type: ignore
organization=organization,
http_client=httpx.Client(
limits=httpx.Limits(
@@ -456,8 +492,8 @@ class InitalizeOpenAISDKClient:
_client = openai.AsyncOpenAI( # type: ignore
api_key=api_key,
base_url=api_base,
timeout=stream_timeout,
max_retries=max_retries,
timeout=stream_timeout, # type: ignore
max_retries=max_retries, # type: ignore
organization=organization,
http_client=httpx.AsyncClient(
limits=httpx.Limits(
@@ -473,7 +509,7 @@ class InitalizeOpenAISDKClient:
local_only=True,
) # cache for 1 hr
if InitalizeOpenAISDKClient.should_initialize_sync_client(
if should_initialize_sync_client(
litellm_router_instance=litellm_router_instance
):
# streaming clients should have diff timeouts
@@ -481,8 +517,8 @@ class InitalizeOpenAISDKClient:
_client = openai.OpenAI( # type: ignore
api_key=api_key,
base_url=api_base,
timeout=stream_timeout,
max_retries=max_retries,
timeout=stream_timeout, # type: ignore
max_retries=max_retries, # type: ignore
organization=organization,
http_client=httpx.Client(
limits=httpx.Limits(
@@ -498,157 +534,9 @@ class InitalizeOpenAISDKClient:
local_only=True,
) # cache for 1 hr
@staticmethod
def _get_client_initialization_params(
model: dict,
model_name: str,
custom_llm_provider: Optional[str],
litellm_params: dict,
default_api_key: Optional[str],
default_api_base: Optional[str],
) -> OpenAISDKClientInitializationParams:
"""
Returns params to use for initializing OpenAI SDK clients (for OpenAI, Azure OpenAI, OpenAI Compatible Providers)
Args:
model: model dict
model_name: model name
custom_llm_provider: custom llm provider
litellm_params: litellm params
default_api_key: default api key
default_api_base: default api base
Returns:
OpenAISDKClientInitializationParams
"""
is_azure_ai_studio_model: bool = False
if custom_llm_provider == "azure":
if litellm.utils._is_non_openai_azure_model(model_name):
is_azure_ai_studio_model = True
custom_llm_provider = "openai"
# remove azure prefix from model_name
model_name = model_name.replace("azure/", "")
# glorified / complicated reading of configs
# user can pass vars directly or they can pass os.environ/AZURE_API_KEY, in which case we will read the env
# we do this here because we init clients for Azure, OpenAI and we need to set the right key
api_key = litellm_params.get("api_key") or default_api_key
if api_key and isinstance(api_key, str) and api_key.startswith("os.environ/"):
api_key = get_secret_str(api_key)
litellm_params["api_key"] = api_key
api_base = litellm_params.get("api_base")
base_url: Optional[str] = litellm_params.get("base_url")
api_base = (
api_base or base_url or default_api_base
) # allow users to pass in `api_base` or `base_url` for azure
if api_base and api_base.startswith("os.environ/"):
api_base = get_secret_str(api_base)
litellm_params["api_base"] = api_base
## AZURE AI STUDIO MISTRAL CHECK ##
"""
Make sure api base ends in /v1/
if not, add it - https://github.com/BerriAI/litellm/issues/2279
"""
if (
is_azure_ai_studio_model is True
and api_base is not None
and isinstance(api_base, str)
and not api_base.endswith("/v1/")
):
# check if it ends with a trailing slash
if api_base.endswith("/"):
api_base += "v1/"
elif api_base.endswith("/v1"):
api_base += "/"
else:
api_base += "/v1/"
api_version = litellm_params.get("api_version")
if api_version and api_version.startswith("os.environ/"):
api_version = get_secret_str(api_version)
litellm_params["api_version"] = api_version
timeout: Optional[Union[float, str, httpxTimeout]] = (
litellm_params.pop("timeout", None) or litellm.request_timeout
)
if isinstance(timeout, str) and timeout.startswith("os.environ/"):
timeout = float(get_secret(timeout)) # type: ignore
litellm_params["timeout"] = timeout
stream_timeout: Optional[Union[float, str, httpxTimeout]] = litellm_params.pop(
"stream_timeout", timeout
) # if no stream_timeout is set, default to timeout
if isinstance(stream_timeout, str) and stream_timeout.startswith("os.environ/"):
stream_timeout = float(get_secret(stream_timeout)) # type: ignore
litellm_params["stream_timeout"] = stream_timeout
max_retries: Optional[int] = litellm_params.pop(
"max_retries", 0
) # router handles retry logic
if isinstance(max_retries, str) and max_retries.startswith("os.environ/"):
max_retries = get_secret(max_retries) # type: ignore
litellm_params["max_retries"] = max_retries
organization = litellm_params.get("organization", None)
if isinstance(organization, str) and organization.startswith("os.environ/"):
organization = get_secret_str(organization)
litellm_params["organization"] = organization
azure_ad_token_provider: Optional[Callable[[], str]] = None
tenant_id = litellm_params.get("tenant_id")
if tenant_id is not None:
verbose_router_logger.debug("Using Azure AD Token Provider for Azure Auth")
azure_ad_token_provider = (
InitalizeOpenAISDKClient.get_azure_ad_token_from_entrata_id(
tenant_id=tenant_id,
client_id=litellm_params.get("client_id"),
client_secret=litellm_params.get("client_secret"),
)
)
return OpenAISDKClientInitializationParams(
api_key=api_key,
api_base=api_base,
api_version=api_version,
azure_ad_token_provider=azure_ad_token_provider,
timeout=timeout, # type: ignore
stream_timeout=stream_timeout, # type: ignore
max_retries=max_retries, # type: ignore
organization=organization,
custom_llm_provider=custom_llm_provider,
model_name=model_name,
)
@staticmethod
def _should_create_openai_sdk_client_for_model(
model_name: str,
custom_llm_provider: str,
) -> bool:
"""
Returns True if an OpenAI SDK client should be created for a given model
We need an OpenAI SDK client for models that are called using the OpenAI Python SDK
Azure OpenAI, OpenAI, OpenAI Compatible Providers, OpenAI Embedding Models
"""
if (
model_name in litellm.open_ai_chat_completion_models
or custom_llm_provider in litellm.openai_compatible_providers
or custom_llm_provider == "azure"
or custom_llm_provider == "azure_text"
or custom_llm_provider == "custom_openai"
or custom_llm_provider == "openai"
or custom_llm_provider == "text-completion-openai"
or "ft:gpt-3.5-turbo" in model_name
or model_name in litellm.open_ai_embedding_models
):
return True
return False
@staticmethod
def get_azure_ad_token_from_entrata_id(
tenant_id: str, client_id: Optional[str], client_secret: Optional[str]
tenant_id: str, client_id: str, client_secret: str
) -> Callable[[], str]:
from azure.identity import (
ClientSecretCredential,
@@ -656,11 +544,6 @@ class InitalizeOpenAISDKClient:
get_bearer_token_provider,
)
if client_id is None or client_secret is None:
raise ValueError(
"client_id and client_secret must be provided when using `tenant_id`"
)
verbose_router_logger.debug("Getting Azure AD Token from Entrata ID")
if tenant_id.startswith("os.environ/"):

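For context, get_azure_ad_token_from_entrata_id (spelling as in the repo) wraps the standard azure-identity pattern. A sketch; the Cognitive Services scope is an assumption, since the diff truncates before the token acquisition:

from azure.identity import ClientSecretCredential, get_bearer_token_provider

credential = ClientSecretCredential(
    tenant_id=tenant_id, client_id=client_id, client_secret=client_secret
)
# Zero-argument callable returning a fresh bearer token on each call; the
# OpenAI SDK accepts it as azure_ad_token_provider.
azure_ad_token_provider = get_bearer_token_provider(
    credential, "https://cognitiveservices.azure.com/.default"
)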
tests/local_testing/test_router_init.py

@@ -17,10 +17,6 @@ from dotenv import load_dotenv
import litellm
from litellm import Router
from litellm.router_utils.client_initalization_utils import (
InitalizeOpenAISDKClient,
OpenAISDKClientInitializationParams,
)
load_dotenv()
@@ -700,283 +696,3 @@ def test_init_router_with_supported_environments(environment, expected_models):
assert set(_model_list) == set(expected_models)
os.environ.pop("LITELLM_ENVIRONMENT")
def test_should_initialize_sync_client():
from litellm.types.router import RouterGeneralSettings
# Test case 1: Router instance is None
assert InitalizeOpenAISDKClient.should_initialize_sync_client(None) is False
# Test case 2: Router instance without router_general_settings
router = Router(model_list=[])
assert InitalizeOpenAISDKClient.should_initialize_sync_client(router) is True
# Test case 3: Router instance with async_only_mode = False
router = Router(
model_list=[],
router_general_settings=RouterGeneralSettings(async_only_mode=False),
)
assert InitalizeOpenAISDKClient.should_initialize_sync_client(router) is True
# Test case 4: Router instance with async_only_mode = True
router = Router(
model_list=[],
router_general_settings=RouterGeneralSettings(async_only_mode=True),
)
assert InitalizeOpenAISDKClient.should_initialize_sync_client(router) is False
# Test case 5: Router instance with router_general_settings but without async_only_mode
router = Router(model_list=[], router_general_settings=RouterGeneralSettings())
assert InitalizeOpenAISDKClient.should_initialize_sync_client(router) is True
print("All test cases passed!")
@pytest.mark.parametrize(
"model_name, custom_llm_provider, expected_result",
[
("gpt-3.5-turbo", None, True), # OpenAI chat completion model
("text-embedding-ada-002", None, True), # OpenAI embedding model
("claude-2", None, False), # Non-OpenAI model
("gpt-3.5-turbo", "azure", True), # Azure OpenAI
("text-davinci-003", "azure_text", True), # Azure OpenAI
("gpt-3.5-turbo", "openai", True), # OpenAI
("custom-model", "custom_openai", True), # Custom OpenAI compatible
("text-davinci-003", "text-completion-openai", True), # OpenAI text completion
(
"ft:gpt-3.5-turbo-0613:my-org:custom-model:7p4lURel",
None,
True,
), # Fine-tuned GPT model
("mistral-7b", "huggingface", False), # Non-OpenAI provider
("custom-model", "anthropic", False), # Non-OpenAI compatible provider
],
)
def test_should_create_openai_sdk_client_for_model(
model_name, custom_llm_provider, expected_result
):
result = InitalizeOpenAISDKClient._should_create_openai_sdk_client_for_model(
model_name, custom_llm_provider
)
assert (
result == expected_result
), f"Failed for model: {model_name}, provider: {custom_llm_provider}"
def test_should_create_openai_sdk_client_for_model_openai_compatible_providers():
# Test with a known OpenAI compatible provider
assert InitalizeOpenAISDKClient._should_create_openai_sdk_client_for_model(
"custom-model", "groq"
), "Should return True for OpenAI compatible provider"
# Add a new compatible provider and test
litellm.openai_compatible_providers.append("new_provider")
assert InitalizeOpenAISDKClient._should_create_openai_sdk_client_for_model(
"custom-model", "new_provider"
), "Should return True for newly added OpenAI compatible provider"
# Clean up
litellm.openai_compatible_providers.remove("new_provider")
def test_get_client_initialization_params_openai():
"""Test basic OpenAI configuration with direct parameter passing."""
model = {}
model_name = "gpt-3.5-turbo"
custom_llm_provider = None
litellm_params = {"api_key": "sk-openai-key", "timeout": 30, "max_retries": 3}
default_api_key = None
default_api_base = None
result = InitalizeOpenAISDKClient._get_client_initialization_params(
model=model,
model_name=model_name,
custom_llm_provider=custom_llm_provider,
litellm_params=litellm_params,
default_api_key=default_api_key,
default_api_base=default_api_base,
)
assert isinstance(result, OpenAISDKClientInitializationParams)
assert result.api_key == "sk-openai-key"
assert result.timeout == 30
assert result.max_retries == 3
assert result.model_name == "gpt-3.5-turbo"
def test_get_client_initialization_params_azure():
"""Test Azure OpenAI configuration with specific Azure parameters."""
model = {}
model_name = "azure/gpt-4"
custom_llm_provider = "azure"
litellm_params = {
"api_key": "azure-key",
"api_base": "https://example.azure.openai.com",
"api_version": "2023-05-15",
}
default_api_key = None
default_api_base = None
result = InitalizeOpenAISDKClient._get_client_initialization_params(
model=model,
model_name=model_name,
custom_llm_provider=custom_llm_provider,
litellm_params=litellm_params,
default_api_key=default_api_key,
default_api_base=default_api_base,
)
assert result.api_key == "azure-key"
assert result.api_base == "https://example.azure.openai.com"
assert result.api_version == "2023-05-15"
assert result.custom_llm_provider == "azure"
def test_get_client_initialization_params_environment_variable_parsing():
"""Test parsing of environment variables for configuration."""
os.environ["UNIQUE_OPENAI_API_KEY"] = "env-openai-key"
os.environ["UNIQUE_TIMEOUT"] = "45"
model = {}
model_name = "gpt-4"
custom_llm_provider = None
litellm_params = {
"api_key": "os.environ/UNIQUE_OPENAI_API_KEY",
"timeout": "os.environ/UNIQUE_TIMEOUT",
"organization": "os.environ/UNIQUE_ORG_ID",
}
default_api_key = None
default_api_base = None
result = InitalizeOpenAISDKClient._get_client_initialization_params(
model=model,
model_name=model_name,
custom_llm_provider=custom_llm_provider,
litellm_params=litellm_params,
default_api_key=default_api_key,
default_api_base=default_api_base,
)
assert result.api_key == "env-openai-key"
assert result.timeout == 45.0
assert result.organization is None # Since ORG_ID is not set in the environment
def test_get_client_initialization_params_azure_ai_studio_mistral():
"""
Test configuration for Azure AI Studio Mistral model.
- /v1/ is added to the api_base if it is not present
- custom_llm_provider is set to openai (Azure AI Studio Mistral models need to use OpenAI route)
"""
model = {}
model_name = "azure/mistral-large-latest"
custom_llm_provider = "azure"
litellm_params = {
"api_key": "azure-key",
"api_base": "https://example.azure.openai.com",
}
default_api_key = None
default_api_base = None
result = InitalizeOpenAISDKClient._get_client_initialization_params(
model,
model_name,
custom_llm_provider,
litellm_params,
default_api_key,
default_api_base,
)
assert result.custom_llm_provider == "openai"
assert result.model_name == "mistral-large-latest"
assert result.api_base == "https://example.azure.openai.com/v1/"
def test_get_client_initialization_params_default_values():
"""
Test use of default values when specific parameters are not provided.
This is typically used for OpenAI-compatible providers, e.g. Together AI
"""
model = {}
model_name = "together/meta-llama-3.1-8b-instruct"
custom_llm_provider = None
litellm_params = {}
default_api_key = "together-api-key"
default_api_base = "https://together.xyz/api.openai.com"
result = InitalizeOpenAISDKClient._get_client_initialization_params(
model=model,
model_name=model_name,
custom_llm_provider=custom_llm_provider,
litellm_params=litellm_params,
default_api_key=default_api_key,
default_api_base=default_api_base,
)
assert result.api_key == "together-api-key"
assert result.api_base == "https://together.xyz/api.openai.com"
assert result.timeout == litellm.request_timeout
assert result.max_retries == 0
def test_get_client_initialization_params_all_env_vars():
# Set up environment variables
os.environ["TEST_API_KEY"] = "test-api-key"
os.environ["TEST_API_BASE"] = "https://test.openai.com"
os.environ["TEST_API_VERSION"] = "2023-05-15"
os.environ["TEST_TIMEOUT"] = "30"
os.environ["TEST_STREAM_TIMEOUT"] = "60"
os.environ["TEST_MAX_RETRIES"] = "3"
os.environ["TEST_ORGANIZATION"] = "test-org"
model = {}
model_name = "gpt-4"
custom_llm_provider = None
litellm_params = {
"api_key": "os.environ/TEST_API_KEY",
"api_base": "os.environ/TEST_API_BASE",
"api_version": "os.environ/TEST_API_VERSION",
"timeout": "os.environ/TEST_TIMEOUT",
"stream_timeout": "os.environ/TEST_STREAM_TIMEOUT",
"max_retries": "os.environ/TEST_MAX_RETRIES",
"organization": "os.environ/TEST_ORGANIZATION",
}
default_api_key = None
default_api_base = None
result = InitalizeOpenAISDKClient._get_client_initialization_params(
model=model,
model_name=model_name,
custom_llm_provider=custom_llm_provider,
litellm_params=litellm_params,
default_api_key=default_api_key,
default_api_base=default_api_base,
)
assert isinstance(result, OpenAISDKClientInitializationParams)
assert result.api_key == "test-api-key"
assert result.api_base == "https://test.openai.com"
assert result.api_version == "2023-05-15"
assert result.timeout == 30.0
assert result.stream_timeout == 60.0
assert result.max_retries == 3
assert result.organization == "test-org"
assert result.model_name == "gpt-4"
assert result.custom_llm_provider is None
# Clean up environment variables
for key in [
"TEST_API_KEY",
"TEST_API_BASE",
"TEST_API_VERSION",
"TEST_TIMEOUT",
"TEST_STREAM_TIMEOUT",
"TEST_MAX_RETRIES",
"TEST_ORGANIZATION",
]:
os.environ.pop(key)
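The deleted tests above targeted the refactor's class-based API. An equivalent smoke test against the restored module-level function might look like this (a sketch, not part of this commit):

from litellm import Router
from litellm.router_utils.client_initalization_utils import (
    should_initialize_sync_client,
)
from litellm.types.router import RouterGeneralSettings


def test_should_initialize_sync_client_module_level():
    # Mirrors the removed class-method test cases above
    assert should_initialize_sync_client(None) is False
    assert should_initialize_sync_client(Router(model_list=[])) is True
    router = Router(
        model_list=[],
        router_general_settings=RouterGeneralSettings(async_only_mode=True),
    )
    assert should_initialize_sync_client(router) is False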