forked from phoenix/litellm-mirror

This reverts commit b70147f63b.

parent 72a91ea9dd
commit d063086bbf

3 changed files with 478 additions and 884 deletions
litellm/router.py

@@ -63,7 +63,10 @@ from litellm.router_utils.batch_utils import (
     _get_router_metadata_variable_name,
     replace_model_in_jsonl,
 )
-from litellm.router_utils.client_initalization_utils import InitalizeOpenAISDKClient
+from litellm.router_utils.client_initalization_utils import (
+    set_client,
+    should_initialize_sync_client,
+)
 from litellm.router_utils.cooldown_cache import CooldownCache
 from litellm.router_utils.cooldown_callbacks import router_cooldown_event_callback
 from litellm.router_utils.cooldown_handlers import (
@@ -3948,7 +3951,7 @@ class Router:
             raise Exception(f"Unsupported provider - {custom_llm_provider}")

         # init OpenAI, Azure clients
-        InitalizeOpenAISDKClient.set_client(
+        set_client(
             litellm_router_instance=self, model=deployment.to_json(exclude_none=True)
         )

@@ -4658,9 +4661,7 @@ class Router:
             """
             Re-initialize the client
             """
-            InitalizeOpenAISDKClient.set_client(
-                litellm_router_instance=self, model=deployment
-            )
+            set_client(litellm_router_instance=self, model=deployment)
             client = self.cache.get_cache(key=cache_key, local_only=True)
             return client
         else:
@@ -4670,9 +4671,7 @@ class Router:
             """
             Re-initialize the client
             """
-            InitalizeOpenAISDKClient.set_client(
-                litellm_router_instance=self, model=deployment
-            )
+            set_client(litellm_router_instance=self, model=deployment)
             client = self.cache.get_cache(key=cache_key, local_only=True)
             return client
         else:
@@ -4683,9 +4682,7 @@ class Router:
             """
             Re-initialize the client
             """
-            InitalizeOpenAISDKClient.set_client(
-                litellm_router_instance=self, model=deployment
-            )
+            set_client(litellm_router_instance=self, model=deployment)
             client = self.cache.get_cache(key=cache_key)
             return client
         else:
@@ -4695,9 +4692,7 @@ class Router:
             """
             Re-initialize the client
             """
-            InitalizeOpenAISDKClient.set_client(
-                litellm_router_instance=self, model=deployment
-            )
+            set_client(litellm_router_instance=self, model=deployment)
             client = self.cache.get_cache(key=cache_key)
             return client

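Note: the four `_get_client` hunks above share one shape: look up an SDK client in the router's local cache under a `{model_id}`-suffixed key and, on a miss (entries carry a 1 hr TTL), rebuild it through the restored module-level `set_client`. A minimal sketch of that retrieval pattern, assuming a `router` object with the `cache` attribute used in the diff; the `suffix` selection here is illustrative:

    from litellm.router_utils.client_initalization_utils import set_client

    def get_cached_client(router, deployment: dict, model_id: str,
                          streaming: bool, is_async: bool):
        # Cache keys mirror the diff: _client, _async_client,
        # _stream_client, _stream_async_client.
        if is_async:
            suffix = "_stream_async_client" if streaming else "_async_client"
        else:
            suffix = "_stream_client" if streaming else "_client"
        cache_key = f"{model_id}{suffix}"
        client = router.cache.get_cache(key=cache_key, local_only=True)
        if client is None:
            # Re-initialize the client, then read it back from the cache.
            set_client(litellm_router_instance=router, model=deployment)
            client = router.cache.get_cache(key=cache_key, local_only=True)
        return client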
litellm/router_utils/client_initalization_utils.py

@@ -1,11 +1,10 @@
 import asyncio
 import os
 import traceback
-from typing import TYPE_CHECKING, Any, Callable, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, Optional

 import httpx
 import openai
-from pydantic import BaseModel

 import litellm
 from litellm import get_secret, get_secret_str
@@ -17,38 +16,13 @@ from litellm.secret_managers.get_azure_ad_token_provider import (
 from litellm.utils import calculate_max_parallel_requests

 if TYPE_CHECKING:
-    from httpx import Timeout as httpxTimeout
-
     from litellm.router import Router as _Router

     LitellmRouter = _Router
 else:
     LitellmRouter = Any
-    httpxTimeout = Any
-
-
-class OpenAISDKClientInitializationParams(BaseModel):
-    api_key: Optional[str]
-    api_base: Optional[str]
-    api_version: Optional[str]
-    azure_ad_token_provider: Optional[Callable[[], str]]
-    timeout: Optional[Union[float, httpxTimeout]]
-    stream_timeout: Optional[Union[float, httpxTimeout]]
-    max_retries: int
-    organization: Optional[str]
-
-    # Internal LiteLLM specific params
-    custom_llm_provider: Optional[str]
-    model_name: str
-
-
-class InitalizeOpenAISDKClient:
-    """
-    OpenAI Python SDK requires creating a OpenAI/AzureOpenAI client
-    this class is responsible for creating that client
-    """
-
-    @staticmethod
-    def should_initialize_sync_client(
-        litellm_router_instance: LitellmRouter,
-    ) -> bool:
+
+
+def should_initialize_sync_client(
+    litellm_router_instance: LitellmRouter,
+) -> bool:
@@ -67,17 +41,14 @@ class InitalizeOpenAISDKClient:
         and hasattr(
             litellm_router_instance.router_general_settings, "async_only_mode"
         )
-        and litellm_router_instance.router_general_settings.async_only_mode
-        is True
+        and litellm_router_instance.router_general_settings.async_only_mode is True
     ):
         return False

     return True

-    @staticmethod
-    def set_client(  # noqa: PLR0915
-        litellm_router_instance: LitellmRouter, model: dict
-    ):
+
+def set_client(litellm_router_instance: LitellmRouter, model: dict):  # noqa: PLR0915
     """
     - Initializes Azure/OpenAI clients. Stores them in cache, b/c of this - https://github.com/BerriAI/litellm/issues/1278
     - Initializes Semaphore for client w/ rpm. Stores them in cache. b/c of this - https://github.com/BerriAI/litellm/issues/2994
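Note: `should_initialize_sync_client` gates every sync-client branch in the hunks that follow; a router created with `async_only_mode=True` skips sync OpenAI/AzureOpenAI clients entirely. A quick usage sketch against the restored function-based API (it mirrors the assertions in the tests removed at the end of this diff):

    from litellm import Router
    from litellm.router_utils.client_initalization_utils import (
        should_initialize_sync_client,
    )
    from litellm.types.router import RouterGeneralSettings

    # async-only routers never build sync clients
    router = Router(
        model_list=[],
        router_general_settings=RouterGeneralSettings(async_only_mode=True),
    )
    assert should_initialize_sync_client(litellm_router_instance=router) is False

    # by default, sync clients are initialized
    assert (
        should_initialize_sync_client(litellm_router_instance=Router(model_list=[]))
        is True
    )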
@@ -117,46 +88,111 @@ class InitalizeOpenAISDKClient:
     default_api_base = api_base
     default_api_key = api_key

-        if InitalizeOpenAISDKClient._should_create_openai_sdk_client_for_model(
-            model_name=model_name,
-            custom_llm_provider=custom_llm_provider,
-        ):
-            client_initialization_params = (
-                InitalizeOpenAISDKClient._get_client_initialization_params(
-                    model=model,
-                    model_name=model_name,
-                    custom_llm_provider=custom_llm_provider,
-                    litellm_params=litellm_params,
-                    default_api_key=default_api_key,
-                    default_api_base=default_api_base,
-                )
-            )
-
-            ############### Unpack client initialization params #######################
-            api_key = client_initialization_params.api_key
-            api_base = client_initialization_params.api_base
-            api_version: Optional[str] = client_initialization_params.api_version
-            timeout: Optional[Union[float, httpxTimeout]] = (
-                client_initialization_params.timeout
-            )
-            stream_timeout: Optional[Union[float, httpxTimeout]] = (
-                client_initialization_params.stream_timeout
-            )
-            max_retries: int = client_initialization_params.max_retries
-            organization: Optional[str] = client_initialization_params.organization
-            azure_ad_token_provider: Optional[Callable[[], str]] = (
-                client_initialization_params.azure_ad_token_provider
-            )
-            custom_llm_provider = client_initialization_params.custom_llm_provider
-            model_name = client_initialization_params.model_name
-            ##########################################################################
+    if (
+        model_name in litellm.open_ai_chat_completion_models
+        or custom_llm_provider in litellm.openai_compatible_providers
+        or custom_llm_provider == "azure"
+        or custom_llm_provider == "azure_text"
+        or custom_llm_provider == "custom_openai"
+        or custom_llm_provider == "openai"
+        or custom_llm_provider == "text-completion-openai"
+        or "ft:gpt-3.5-turbo" in model_name
+        or model_name in litellm.open_ai_embedding_models
+    ):
+        is_azure_ai_studio_model: bool = False
+        if custom_llm_provider == "azure":
+            if litellm.utils._is_non_openai_azure_model(model_name):
+                is_azure_ai_studio_model = True
+                custom_llm_provider = "openai"
+                # remove azure prefx from model_name
+                model_name = model_name.replace("azure/", "")
+        # glorified / complicated reading of configs
+        # user can pass vars directly or they can pas os.environ/AZURE_API_KEY, in which case we will read the env
+        # we do this here because we init clients for Azure, OpenAI and we need to set the right key
+        api_key = litellm_params.get("api_key") or default_api_key
+        if api_key and isinstance(api_key, str) and api_key.startswith("os.environ/"):
+            api_key_env_name = api_key.replace("os.environ/", "")
+            api_key = get_secret_str(api_key_env_name)
+            litellm_params["api_key"] = api_key
+
+        api_base = litellm_params.get("api_base")
+        base_url: Optional[str] = litellm_params.get("base_url")
+        api_base = (
+            api_base or base_url or default_api_base
+        )  # allow users to pass in `api_base` or `base_url` for azure
+        if api_base and api_base.startswith("os.environ/"):
+            api_base_env_name = api_base.replace("os.environ/", "")
+            api_base = get_secret_str(api_base_env_name)
+            litellm_params["api_base"] = api_base
+
+        ## AZURE AI STUDIO MISTRAL CHECK ##
+        """
+        Make sure api base ends in /v1/
+
+        if not, add it - https://github.com/BerriAI/litellm/issues/2279
+        """
+        if (
+            is_azure_ai_studio_model is True
+            and api_base is not None
+            and isinstance(api_base, str)
+            and not api_base.endswith("/v1/")
+        ):
+            # check if it ends with a trailing slash
+            if api_base.endswith("/"):
+                api_base += "v1/"
+            elif api_base.endswith("/v1"):
+                api_base += "/"
+            else:
+                api_base += "/v1/"
+
+        api_version = litellm_params.get("api_version")
+        if api_version and api_version.startswith("os.environ/"):
+            api_version_env_name = api_version.replace("os.environ/", "")
+            api_version = get_secret_str(api_version_env_name)
+            litellm_params["api_version"] = api_version
+
+        timeout: Optional[float] = (
+            litellm_params.pop("timeout", None) or litellm.request_timeout
+        )
+        if isinstance(timeout, str) and timeout.startswith("os.environ/"):
+            timeout_env_name = timeout.replace("os.environ/", "")
+            timeout = get_secret(timeout_env_name)  # type: ignore
+            litellm_params["timeout"] = timeout
+
+        stream_timeout: Optional[float] = litellm_params.pop(
+            "stream_timeout", timeout
+        )  # if no stream_timeout is set, default to timeout
+        if isinstance(stream_timeout, str) and stream_timeout.startswith("os.environ/"):
+            stream_timeout_env_name = stream_timeout.replace("os.environ/", "")
+            stream_timeout = get_secret(stream_timeout_env_name)  # type: ignore
+            litellm_params["stream_timeout"] = stream_timeout
+
+        max_retries: Optional[int] = litellm_params.pop(
+            "max_retries", 0
+        )  # router handles retry logic
+        if isinstance(max_retries, str) and max_retries.startswith("os.environ/"):
+            max_retries_env_name = max_retries.replace("os.environ/", "")
+            max_retries = get_secret(max_retries_env_name)  # type: ignore
+            litellm_params["max_retries"] = max_retries
+
+        organization = litellm_params.get("organization", None)
+        if isinstance(organization, str) and organization.startswith("os.environ/"):
+            organization_env_name = organization.replace("os.environ/", "")
+            organization = get_secret_str(organization_env_name)
+            litellm_params["organization"] = organization
+        azure_ad_token_provider: Optional[Callable[[], str]] = None
+        if litellm_params.get("tenant_id"):
+            verbose_router_logger.debug("Using Azure AD Token Provider for Azure Auth")
+            azure_ad_token_provider = get_azure_ad_token_from_entrata_id(
+                tenant_id=litellm_params.get("tenant_id"),
+                client_id=litellm_params.get("client_id"),
+                client_secret=litellm_params.get("client_secret"),
+            )

         if custom_llm_provider == "azure" or custom_llm_provider == "azure_text":
             if api_base is None or not isinstance(api_base, str):
                 filtered_litellm_params = {
-                    k: v
-                    for k, v in model["litellm_params"].items()
-                    if k != "api_key"
+                    k: v for k, v in model["litellm_params"].items() if k != "api_key"
                 }
                 _filtered_model = {
                     "model_name": model["model_name"],
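Note: two conventions restored inline above are easy to miss. First, any `litellm_params` value may be written as `os.environ/VAR_NAME`; `set_client` resolves it through `get_secret_str`/`get_secret` at client-init time. Second, for Azure AI Studio (Mistral) models the `api_base` is normalized to end in `/v1/`. A self-contained sketch of both behaviors; the variable name is hypothetical and plain `os.environ.get` stands in for litellm's secret lookup:

    import os

    os.environ["AZURE_API_KEY_EU"] = "my-azure-key"  # hypothetical env var

    litellm_params = {"api_key": "os.environ/AZURE_API_KEY_EU"}

    # resolve the os.environ/ indirection the way set_client does
    api_key = litellm_params["api_key"]
    if api_key.startswith("os.environ/"):
        api_key = os.environ.get(api_key.replace("os.environ/", ""))
    assert api_key == "my-azure-key"

    # /v1/ normalization, same branch structure as the diff
    for api_base, expected in [
        ("https://host.example", "https://host.example/v1/"),
        ("https://host.example/", "https://host.example/v1/"),
        ("https://host.example/v1", "https://host.example/v1/"),
    ]:
        if not api_base.endswith("/v1/"):
            if api_base.endswith("/"):
                api_base += "v1/"
            elif api_base.endswith("/v1"):
                api_base += "/"
            else:
                api_base += "/v1/"
        assert api_base == expected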
@@ -196,8 +232,8 @@ class InitalizeOpenAISDKClient:
                 azure_ad_token_provider=azure_ad_token_provider,
                 base_url=api_base,
                 api_version=api_version,
-                timeout=timeout,
-                max_retries=max_retries,
+                timeout=timeout,  # type: ignore
+                max_retries=max_retries,  # type: ignore
                 http_client=httpx.AsyncClient(
                     limits=httpx.Limits(
                         max_connections=1000, max_keepalive_connections=100
@@ -212,7 +248,7 @@ class InitalizeOpenAISDKClient:
                 local_only=True,
             )  # cache for 1 hr

-            if InitalizeOpenAISDKClient.should_initialize_sync_client(
+            if should_initialize_sync_client(
                 litellm_router_instance=litellm_router_instance
             ):
                 cache_key = f"{model_id}_client"
@@ -222,8 +258,8 @@ class InitalizeOpenAISDKClient:
                     azure_ad_token_provider=azure_ad_token_provider,
                     base_url=api_base,
                     api_version=api_version,
-                    timeout=timeout,
-                    max_retries=max_retries,
+                    timeout=timeout,  # type: ignore
+                    max_retries=max_retries,  # type: ignore
                     http_client=httpx.Client(
                         limits=httpx.Limits(
                             max_connections=1000, max_keepalive_connections=100
@@ -245,8 +281,8 @@ class InitalizeOpenAISDKClient:
                 azure_ad_token_provider=azure_ad_token_provider,
                 base_url=api_base,
                 api_version=api_version,
-                timeout=stream_timeout,
-                max_retries=max_retries,
+                timeout=stream_timeout,  # type: ignore
+                max_retries=max_retries,  # type: ignore
                 http_client=httpx.AsyncClient(
                     limits=httpx.Limits(
                         max_connections=1000, max_keepalive_connections=100
@@ -261,7 +297,7 @@ class InitalizeOpenAISDKClient:
                 local_only=True,
             )  # cache for 1 hr

-            if InitalizeOpenAISDKClient.should_initialize_sync_client(
+            if should_initialize_sync_client(
                 litellm_router_instance=litellm_router_instance
             ):
                 cache_key = f"{model_id}_stream_client"
@@ -271,8 +307,8 @@ class InitalizeOpenAISDKClient:
                     azure_ad_token_provider=azure_ad_token_provider,
                     base_url=api_base,
                     api_version=api_version,
-                    timeout=stream_timeout,
-                    max_retries=max_retries,
+                    timeout=stream_timeout,  # type: ignore
+                    max_retries=max_retries,  # type: ignore
                     http_client=httpx.Client(
                         limits=httpx.Limits(
                             max_connections=1000, max_keepalive_connections=100
@@ -319,8 +355,8 @@ class InitalizeOpenAISDKClient:
                 cache_key = f"{model_id}_async_client"
                 _client = openai.AsyncAzureOpenAI(  # type: ignore
                     **azure_client_params,
-                    timeout=timeout,
-                    max_retries=max_retries,
+                    timeout=timeout,  # type: ignore
+                    max_retries=max_retries,  # type: ignore
                     http_client=httpx.AsyncClient(
                         limits=httpx.Limits(
                             max_connections=1000, max_keepalive_connections=100
@@ -334,14 +370,14 @@ class InitalizeOpenAISDKClient:
                     ttl=client_ttl,
                     local_only=True,
                 )  # cache for 1 hr
-                if InitalizeOpenAISDKClient.should_initialize_sync_client(
+                if should_initialize_sync_client(
                     litellm_router_instance=litellm_router_instance
                 ):
                     cache_key = f"{model_id}_client"
                     _client = openai.AzureOpenAI(  # type: ignore
                         **azure_client_params,
-                        timeout=timeout,
-                        max_retries=max_retries,
+                        timeout=timeout,  # type: ignore
+                        max_retries=max_retries,  # type: ignore
                         http_client=httpx.Client(
                             limits=httpx.Limits(
                                 max_connections=1000, max_keepalive_connections=100
@@ -360,8 +396,8 @@ class InitalizeOpenAISDKClient:
                 cache_key = f"{model_id}_stream_async_client"
                 _client = openai.AsyncAzureOpenAI(  # type: ignore
                     **azure_client_params,
-                    timeout=stream_timeout,
-                    max_retries=max_retries,
+                    timeout=stream_timeout,  # type: ignore
+                    max_retries=max_retries,  # type: ignore
                     http_client=httpx.AsyncClient(
                         limits=httpx.Limits(
                             max_connections=1000, max_keepalive_connections=100
@@ -376,14 +412,14 @@ class InitalizeOpenAISDKClient:
                     local_only=True,
                 )  # cache for 1 hr

-                if InitalizeOpenAISDKClient.should_initialize_sync_client(
+                if should_initialize_sync_client(
                     litellm_router_instance=litellm_router_instance
                 ):
                     cache_key = f"{model_id}_stream_client"
                     _client = openai.AzureOpenAI(  # type: ignore
                         **azure_client_params,
-                        timeout=stream_timeout,
-                        max_retries=max_retries,
+                        timeout=stream_timeout,  # type: ignore
+                        max_retries=max_retries,  # type: ignore
                         http_client=httpx.Client(
                             limits=httpx.Limits(
                                 max_connections=1000, max_keepalive_connections=100
@@ -410,8 +446,8 @@ class InitalizeOpenAISDKClient:
             _client = openai.AsyncOpenAI(  # type: ignore
                 api_key=api_key,
                 base_url=api_base,
-                timeout=timeout,
-                max_retries=max_retries,
+                timeout=timeout,  # type: ignore
+                max_retries=max_retries,  # type: ignore
                 organization=organization,
                 http_client=httpx.AsyncClient(
                     limits=httpx.Limits(
@@ -427,15 +463,15 @@ class InitalizeOpenAISDKClient:
                 local_only=True,
             )  # cache for 1 hr

-            if InitalizeOpenAISDKClient.should_initialize_sync_client(
+            if should_initialize_sync_client(
                 litellm_router_instance=litellm_router_instance
             ):
                 cache_key = f"{model_id}_client"
                 _client = openai.OpenAI(  # type: ignore
                     api_key=api_key,
                     base_url=api_base,
-                    timeout=timeout,
-                    max_retries=max_retries,
+                    timeout=timeout,  # type: ignore
+                    max_retries=max_retries,  # type: ignore
                     organization=organization,
                     http_client=httpx.Client(
                         limits=httpx.Limits(
@@ -456,8 +492,8 @@ class InitalizeOpenAISDKClient:
             _client = openai.AsyncOpenAI(  # type: ignore
                 api_key=api_key,
                 base_url=api_base,
-                timeout=stream_timeout,
-                max_retries=max_retries,
+                timeout=stream_timeout,  # type: ignore
+                max_retries=max_retries,  # type: ignore
                 organization=organization,
                 http_client=httpx.AsyncClient(
                     limits=httpx.Limits(
@@ -473,7 +509,7 @@ class InitalizeOpenAISDKClient:
                 local_only=True,
             )  # cache for 1 hr

-            if InitalizeOpenAISDKClient.should_initialize_sync_client(
+            if should_initialize_sync_client(
                 litellm_router_instance=litellm_router_instance
             ):
                 # streaming clients should have diff timeouts
@@ -481,8 +517,8 @@ class InitalizeOpenAISDKClient:
                 _client = openai.OpenAI(  # type: ignore
                     api_key=api_key,
                     base_url=api_base,
-                    timeout=stream_timeout,
-                    max_retries=max_retries,
+                    timeout=stream_timeout,  # type: ignore
+                    max_retries=max_retries,  # type: ignore
                     organization=organization,
                     http_client=httpx.Client(
                         limits=httpx.Limits(
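Note: the hunks above repeat one structure four times per deployment: an async and a sync client, each in a default and a streaming variant (the streaming variants are constructed with `timeout=stream_timeout`), all sharing httpx pools capped at 1000 connections / 100 keep-alive. As a reading aid, the cache-key pairing, with names taken from the diff:

    # Reading aid, not literal litellm source. Sync variants are only
    # created when should_initialize_sync_client(...) returns True.
    CLIENT_CACHE_KEYS = {
        "async": "{model_id}_async_client",          # AsyncOpenAI / AsyncAzureOpenAI
        "sync": "{model_id}_client",                 # OpenAI / AzureOpenAI
        "stream_async": "{model_id}_stream_async_client",
        "stream_sync": "{model_id}_stream_client",
    }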
@@ -498,157 +534,9 @@ class InitalizeOpenAISDKClient:
                 local_only=True,
             )  # cache for 1 hr

-    @staticmethod
-    def _get_client_initialization_params(
-        model: dict,
-        model_name: str,
-        custom_llm_provider: Optional[str],
-        litellm_params: dict,
-        default_api_key: Optional[str],
-        default_api_base: Optional[str],
-    ) -> OpenAISDKClientInitializationParams:
-        """
-        Returns params to use for initializing OpenAI SDK clients (for OpenAI, Azure OpenAI, OpenAI Compatible Providers)
-
-        Args:
-            model: model dict
-            model_name: model name
-            custom_llm_provider: custom llm provider
-            litellm_params: litellm params
-            default_api_key: default api key
-            default_api_base: default api base
-
-        Returns:
-            OpenAISDKClientInitializationParams
-        """
-
-        is_azure_ai_studio_model: bool = False
-        if custom_llm_provider == "azure":
-            if litellm.utils._is_non_openai_azure_model(model_name):
-                is_azure_ai_studio_model = True
-                custom_llm_provider = "openai"
-                # remove azure prefx from model_name
-                model_name = model_name.replace("azure/", "")
-        # glorified / complicated reading of configs
-        # user can pass vars directly or they can pas os.environ/AZURE_API_KEY, in which case we will read the env
-        # we do this here because we init clients for Azure, OpenAI and we need to set the right key
-        api_key = litellm_params.get("api_key") or default_api_key
-        if api_key and isinstance(api_key, str) and api_key.startswith("os.environ/"):
-            api_key = get_secret_str(api_key)
-            litellm_params["api_key"] = api_key
-
-        api_base = litellm_params.get("api_base")
-        base_url: Optional[str] = litellm_params.get("base_url")
-        api_base = (
-            api_base or base_url or default_api_base
-        )  # allow users to pass in `api_base` or `base_url` for azure
-        if api_base and api_base.startswith("os.environ/"):
-            api_base = get_secret_str(api_base)
-            litellm_params["api_base"] = api_base
-
-        ## AZURE AI STUDIO MISTRAL CHECK ##
-        """
-        Make sure api base ends in /v1/
-
-        if not, add it - https://github.com/BerriAI/litellm/issues/2279
-        """
-        if (
-            is_azure_ai_studio_model is True
-            and api_base is not None
-            and isinstance(api_base, str)
-            and not api_base.endswith("/v1/")
-        ):
-            # check if it ends with a trailing slash
-            if api_base.endswith("/"):
-                api_base += "v1/"
-            elif api_base.endswith("/v1"):
-                api_base += "/"
-            else:
-                api_base += "/v1/"
-
-        api_version = litellm_params.get("api_version")
-        if api_version and api_version.startswith("os.environ/"):
-            api_version = get_secret_str(api_version)
-            litellm_params["api_version"] = api_version
-
-        timeout: Optional[Union[float, str, httpxTimeout]] = (
-            litellm_params.pop("timeout", None) or litellm.request_timeout
-        )
-        if isinstance(timeout, str) and timeout.startswith("os.environ/"):
-            timeout = float(get_secret(timeout))  # type: ignore
-            litellm_params["timeout"] = timeout
-
-        stream_timeout: Optional[Union[float, str, httpxTimeout]] = litellm_params.pop(
-            "stream_timeout", timeout
-        )  # if no stream_timeout is set, default to timeout
-        if isinstance(stream_timeout, str) and stream_timeout.startswith("os.environ/"):
-            stream_timeout = float(get_secret(stream_timeout))  # type: ignore
-            litellm_params["stream_timeout"] = stream_timeout
-
-        max_retries: Optional[int] = litellm_params.pop(
-            "max_retries", 0
-        )  # router handles retry logic
-        if isinstance(max_retries, str) and max_retries.startswith("os.environ/"):
-            max_retries = get_secret(max_retries)  # type: ignore
-            litellm_params["max_retries"] = max_retries
-
-        organization = litellm_params.get("organization", None)
-        if isinstance(organization, str) and organization.startswith("os.environ/"):
-            organization = get_secret_str(organization)
-            litellm_params["organization"] = organization
-        azure_ad_token_provider: Optional[Callable[[], str]] = None
-        tenant_id = litellm_params.get("tenant_id")
-        if tenant_id is not None:
-            verbose_router_logger.debug("Using Azure AD Token Provider for Azure Auth")
-            azure_ad_token_provider = (
-                InitalizeOpenAISDKClient.get_azure_ad_token_from_entrata_id(
-                    tenant_id=tenant_id,
-                    client_id=litellm_params.get("client_id"),
-                    client_secret=litellm_params.get("client_secret"),
-                )
-            )
-
-        return OpenAISDKClientInitializationParams(
-            api_key=api_key,
-            api_base=api_base,
-            api_version=api_version,
-            azure_ad_token_provider=azure_ad_token_provider,
-            timeout=timeout,  # type: ignore
-            stream_timeout=stream_timeout,  # type: ignore
-            max_retries=max_retries,  # type: ignore
-            organization=organization,
-            custom_llm_provider=custom_llm_provider,
-            model_name=model_name,
-        )
-
-    @staticmethod
-    def _should_create_openai_sdk_client_for_model(
-        model_name: str,
-        custom_llm_provider: str,
-    ) -> bool:
-        """
-        Returns True if a OpenAI SDK client should be created for a given model
-
-        We need a OpenAI SDK client for models that are callsed using OpenAI Python SDK
-        Azure OpenAI, OpenAI, OpenAI Compatible Providers, OpenAI Embedding Models
-        """
-        if (
-            model_name in litellm.open_ai_chat_completion_models
-            or custom_llm_provider in litellm.openai_compatible_providers
-            or custom_llm_provider == "azure"
-            or custom_llm_provider == "azure_text"
-            or custom_llm_provider == "custom_openai"
-            or custom_llm_provider == "openai"
-            or custom_llm_provider == "text-completion-openai"
-            or "ft:gpt-3.5-turbo" in model_name
-            or model_name in litellm.open_ai_embedding_models
-        ):
-            return True
-        return False
-
-    @staticmethod
-    def get_azure_ad_token_from_entrata_id(
-        tenant_id: str, client_id: Optional[str], client_secret: Optional[str]
-    ) -> Callable[[], str]:
+
+
+def get_azure_ad_token_from_entrata_id(
+    tenant_id: str, client_id: str, client_secret: str
+) -> Callable[[], str]:
     from azure.identity import (
         ClientSecretCredential,
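Note: the revert tightens `get_azure_ad_token_from_entrata_id` ("Entra ID" is spelled "entrata" in the identifier) to require `client_id` and `client_secret`, dropping the runtime None-check removed in the next hunk. Its body is only partially shown here, so the following is a hedged sketch of the standard azure.identity flow it wraps; the cognitiveservices scope is an assumption, not visible in this diff:

    from azure.identity import ClientSecretCredential, get_bearer_token_provider

    def make_token_provider(tenant_id: str, client_id: str, client_secret: str):
        # Build a credential from the tenant triple...
        credential = ClientSecretCredential(
            tenant_id=tenant_id, client_id=client_id, client_secret=client_secret
        )
        # ...and wrap it as a zero-arg callable returning a bearer token.
        # Scope assumed: the standard Azure OpenAI / Cognitive Services scope.
        return get_bearer_token_provider(
            credential, "https://cognitiveservices.azure.com/.default"
        )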
@@ -656,11 +544,6 @@ class InitalizeOpenAISDKClient:
         get_bearer_token_provider,
     )

-        if client_id is None or client_secret is None:
-            raise ValueError(
-                "client_id and client_secret must be provided when using `tenant_id`"
-            )
-
     verbose_router_logger.debug("Getting Azure AD Token from Entrata ID")

     if tenant_id.startswith("os.environ/"):
test_router_init.py

@@ -17,10 +17,6 @@ from dotenv import load_dotenv

 import litellm
 from litellm import Router
-from litellm.router_utils.client_initalization_utils import (
-    InitalizeOpenAISDKClient,
-    OpenAISDKClientInitializationParams,
-)

 load_dotenv()
@@ -700,283 +696,3 @@ def test_init_router_with_supported_environments(environment, expected_models):
     assert set(_model_list) == set(expected_models)

     os.environ.pop("LITELLM_ENVIRONMENT")
-
-
-def test_should_initialize_sync_client():
-    from litellm.types.router import RouterGeneralSettings
-
-    # Test case 1: Router instance is None
-    assert InitalizeOpenAISDKClient.should_initialize_sync_client(None) is False
-
-    # Test case 2: Router instance without router_general_settings
-    router = Router(model_list=[])
-    assert InitalizeOpenAISDKClient.should_initialize_sync_client(router) is True
-
-    # Test case 3: Router instance with async_only_mode = False
-    router = Router(
-        model_list=[],
-        router_general_settings=RouterGeneralSettings(async_only_mode=False),
-    )
-    assert InitalizeOpenAISDKClient.should_initialize_sync_client(router) is True
-
-    # Test case 4: Router instance with async_only_mode = True
-    router = Router(
-        model_list=[],
-        router_general_settings=RouterGeneralSettings(async_only_mode=True),
-    )
-    assert InitalizeOpenAISDKClient.should_initialize_sync_client(router) is False
-
-    # Test case 5: Router instance with router_general_settings but without async_only_mode
-    router = Router(model_list=[], router_general_settings=RouterGeneralSettings())
-    assert InitalizeOpenAISDKClient.should_initialize_sync_client(router) is True
-
-    print("All test cases passed!")
-
-
-@pytest.mark.parametrize(
-    "model_name, custom_llm_provider, expected_result",
-    [
-        ("gpt-3.5-turbo", None, True),  # OpenAI chat completion model
-        ("text-embedding-ada-002", None, True),  # OpenAI embedding model
-        ("claude-2", None, False),  # Non-OpenAI model
-        ("gpt-3.5-turbo", "azure", True),  # Azure OpenAI
-        ("text-davinci-003", "azure_text", True),  # Azure OpenAI
-        ("gpt-3.5-turbo", "openai", True),  # OpenAI
-        ("custom-model", "custom_openai", True),  # Custom OpenAI compatible
-        ("text-davinci-003", "text-completion-openai", True),  # OpenAI text completion
-        (
-            "ft:gpt-3.5-turbo-0613:my-org:custom-model:7p4lURel",
-            None,
-            True,
-        ),  # Fine-tuned GPT model
-        ("mistral-7b", "huggingface", False),  # Non-OpenAI provider
-        ("custom-model", "anthropic", False),  # Non-OpenAI compatible provider
-    ],
-)
-def test_should_create_openai_sdk_client_for_model(
-    model_name, custom_llm_provider, expected_result
-):
-    result = InitalizeOpenAISDKClient._should_create_openai_sdk_client_for_model(
-        model_name, custom_llm_provider
-    )
-    assert (
-        result == expected_result
-    ), f"Failed for model: {model_name}, provider: {custom_llm_provider}"
-
-
-def test_should_create_openai_sdk_client_for_model_openai_compatible_providers():
-    # Test with a known OpenAI compatible provider
-    assert InitalizeOpenAISDKClient._should_create_openai_sdk_client_for_model(
-        "custom-model", "groq"
-    ), "Should return True for OpenAI compatible provider"
-
-    # Add a new compatible provider and test
-    litellm.openai_compatible_providers.append("new_provider")
-    assert InitalizeOpenAISDKClient._should_create_openai_sdk_client_for_model(
-        "custom-model", "new_provider"
-    ), "Should return True for newly added OpenAI compatible provider"
-
-    # Clean up
-    litellm.openai_compatible_providers.remove("new_provider")
-
-
-def test_get_client_initialization_params_openai():
-    """Test basic OpenAI configuration with direct parameter passing."""
-    model = {}
-    model_name = "gpt-3.5-turbo"
-    custom_llm_provider = None
-    litellm_params = {"api_key": "sk-openai-key", "timeout": 30, "max_retries": 3}
-    default_api_key = None
-    default_api_base = None
-
-    result = InitalizeOpenAISDKClient._get_client_initialization_params(
-        model=model,
-        model_name=model_name,
-        custom_llm_provider=custom_llm_provider,
-        litellm_params=litellm_params,
-        default_api_key=default_api_key,
-        default_api_base=default_api_base,
-    )
-
-    assert isinstance(result, OpenAISDKClientInitializationParams)
-    assert result.api_key == "sk-openai-key"
-    assert result.timeout == 30
-    assert result.max_retries == 3
-    assert result.model_name == "gpt-3.5-turbo"
-
-
-def test_get_client_initialization_params_azure():
-    """Test Azure OpenAI configuration with specific Azure parameters."""
-    model = {}
-    model_name = "azure/gpt-4"
-    custom_llm_provider = "azure"
-    litellm_params = {
-        "api_key": "azure-key",
-        "api_base": "https://example.azure.openai.com",
-        "api_version": "2023-05-15",
-    }
-    default_api_key = None
-    default_api_base = None
-
-    result = InitalizeOpenAISDKClient._get_client_initialization_params(
-        model=model,
-        model_name=model_name,
-        custom_llm_provider=custom_llm_provider,
-        litellm_params=litellm_params,
-        default_api_key=default_api_key,
-        default_api_base=default_api_base,
-    )
-
-    assert result.api_key == "azure-key"
-    assert result.api_base == "https://example.azure.openai.com"
-    assert result.api_version == "2023-05-15"
-    assert result.custom_llm_provider == "azure"
-
-
-def test_get_client_initialization_params_environment_variable_parsing():
-    """Test parsing of environment variables for configuration."""
-    os.environ["UNIQUE_OPENAI_API_KEY"] = "env-openai-key"
-    os.environ["UNIQUE_TIMEOUT"] = "45"
-
-    model = {}
-    model_name = "gpt-4"
-    custom_llm_provider = None
-    litellm_params = {
-        "api_key": "os.environ/UNIQUE_OPENAI_API_KEY",
-        "timeout": "os.environ/UNIQUE_TIMEOUT",
-        "organization": "os.environ/UNIQUE_ORG_ID",
-    }
-    default_api_key = None
-    default_api_base = None
-
-    result = InitalizeOpenAISDKClient._get_client_initialization_params(
-        model=model,
-        model_name=model_name,
-        custom_llm_provider=custom_llm_provider,
-        litellm_params=litellm_params,
-        default_api_key=default_api_key,
-        default_api_base=default_api_base,
-    )
-
-    assert result.api_key == "env-openai-key"
-    assert result.timeout == 45.0
-    assert result.organization is None  # Since ORG_ID is not set in the environment
-
-
-def test_get_client_initialization_params_azure_ai_studio_mistral():
-    """
-    Test configuration for Azure AI Studio Mistral model.
-
-    - /v1/ is added to the api_base if it is not present
-    - custom_llm_provider is set to openai (Azure AI Studio Mistral models need to use OpenAI route)
-    """
-
-    model = {}
-    model_name = "azure/mistral-large-latest"
-    custom_llm_provider = "azure"
-    litellm_params = {
-        "api_key": "azure-key",
-        "api_base": "https://example.azure.openai.com",
-    }
-    default_api_key = None
-    default_api_base = None
-
-    result = InitalizeOpenAISDKClient._get_client_initialization_params(
-        model,
-        model_name,
-        custom_llm_provider,
-        litellm_params,
-        default_api_key,
-        default_api_base,
-    )
-
-    assert result.custom_llm_provider == "openai"
-    assert result.model_name == "mistral-large-latest"
-    assert result.api_base == "https://example.azure.openai.com/v1/"
-
-
-def test_get_client_initialization_params_default_values():
-    """
-    Test use of default values when specific parameters are not provided.
-
-    This is used typically for OpenAI compatible providers - example Together AI
-    """
-    model = {}
-    model_name = "together/meta-llama-3.1-8b-instruct"
-    custom_llm_provider = None
-    litellm_params = {}
-    default_api_key = "together-api-key"
-    default_api_base = "https://together.xyz/api.openai.com"
-
-    result = InitalizeOpenAISDKClient._get_client_initialization_params(
-        model=model,
-        model_name=model_name,
-        custom_llm_provider=custom_llm_provider,
-        litellm_params=litellm_params,
-        default_api_key=default_api_key,
-        default_api_base=default_api_base,
-    )
-
-    assert result.api_key == "together-api-key"
-    assert result.api_base == "https://together.xyz/api.openai.com"
-    assert result.timeout == litellm.request_timeout
-    assert result.max_retries == 0
-
-
-def test_get_client_initialization_params_all_env_vars():
-    # Set up environment variables
-    os.environ["TEST_API_KEY"] = "test-api-key"
-    os.environ["TEST_API_BASE"] = "https://test.openai.com"
-    os.environ["TEST_API_VERSION"] = "2023-05-15"
-    os.environ["TEST_TIMEOUT"] = "30"
-    os.environ["TEST_STREAM_TIMEOUT"] = "60"
-    os.environ["TEST_MAX_RETRIES"] = "3"
-    os.environ["TEST_ORGANIZATION"] = "test-org"
-
-    model = {}
-    model_name = "gpt-4"
-    custom_llm_provider = None
-    litellm_params = {
-        "api_key": "os.environ/TEST_API_KEY",
-        "api_base": "os.environ/TEST_API_BASE",
-        "api_version": "os.environ/TEST_API_VERSION",
-        "timeout": "os.environ/TEST_TIMEOUT",
-        "stream_timeout": "os.environ/TEST_STREAM_TIMEOUT",
-        "max_retries": "os.environ/TEST_MAX_RETRIES",
-        "organization": "os.environ/TEST_ORGANIZATION",
-    }
-    default_api_key = None
-    default_api_base = None
-
-    result = InitalizeOpenAISDKClient._get_client_initialization_params(
-        model=model,
-        model_name=model_name,
-        custom_llm_provider=custom_llm_provider,
-        litellm_params=litellm_params,
-        default_api_key=default_api_key,
-        default_api_base=default_api_base,
-    )
-
-    assert isinstance(result, OpenAISDKClientInitializationParams)
-    assert result.api_key == "test-api-key"
-    assert result.api_base == "https://test.openai.com"
-    assert result.api_version == "2023-05-15"
-    assert result.timeout == 30.0
-    assert result.stream_timeout == 60.0
-    assert result.max_retries == 3
-    assert result.organization == "test-org"
-    assert result.model_name == "gpt-4"
-    assert result.custom_llm_provider is None
-
-    # Clean up environment variables
-    for key in [
-        "TEST_API_KEY",
-        "TEST_API_BASE",
-        "TEST_API_VERSION",
-        "TEST_TIMEOUT",
-        "TEST_STREAM_TIMEOUT",
-        "TEST_MAX_RETRIES",
-        "TEST_ORGANIZATION",
-    ]:
-        os.environ.pop(key)
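Note: the deleted tests above targeted the class-based helpers (`InitalizeOpenAISDKClient.should_initialize_sync_client`, `_should_create_openai_sdk_client_for_model`, `_get_client_initialization_params`) that no longer exist after this revert. A hedged sketch of how the surviving behavior could still be covered, re-targeting the async-only cases at the restored module-level function:

    import pytest

    from litellm import Router
    from litellm.router_utils.client_initalization_utils import (
        should_initialize_sync_client,
    )
    from litellm.types.router import RouterGeneralSettings

    @pytest.mark.parametrize(
        "general_settings, expected",
        [
            (None, True),  # no settings -> sync clients are built
            (RouterGeneralSettings(async_only_mode=False), True),
            (RouterGeneralSettings(async_only_mode=True), False),
        ],
    )
    def test_should_initialize_sync_client_function(general_settings, expected):
        kwargs = {"model_list": []}
        if general_settings is not None:
            kwargs["router_general_settings"] = general_settings
        router = Router(**kwargs)
        assert should_initialize_sync_client(router) is expected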