mirror of https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00

commit d2fde6b5aa
Merge branch 'litellm_dev_03_10_2025_p3' into litellm_router_client_init_migration

3 changed files with 21 additions and 94 deletions
@@ -5345,6 +5345,13 @@ class Router:
             client = self.cache.get_cache(
                 key=cache_key, local_only=True, parent_otel_span=parent_otel_span
             )
+            if client is None:
+                InitalizeOpenAISDKClient.set_max_parallel_requests_client(
+                    litellm_router_instance=self, model=deployment
+                )
+                client = self.cache.get_cache(
+                    key=cache_key, local_only=True, parent_otel_span=parent_otel_span
+                )
             return client
         elif client_type == "async":
             if kwargs.get("stream") is True:
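
The added branch makes the client lookup in Router._get_client self-healing: on a cache miss (e.g. the cached client expired), the router initializes the max-parallel-requests client on demand and re-reads the cache. Below is a minimal sketch of that lazy, cache-backed initialization pattern; the Cache class and init_client callable are stand-ins for illustration, not litellm's actual cache API.

    from typing import Any, Callable, Dict, Optional


    class Cache:
        """Toy in-memory stand-in for the router's client cache (assumption)."""

        def __init__(self) -> None:
            self._store: Dict[str, Any] = {}

        def get_cache(self, key: str) -> Optional[Any]:
            return self._store.get(key)

        def set_cache(self, key: str, value: Any) -> None:
            self._store[key] = value


    def get_or_init_client(
        cache: Cache, cache_key: str, init_client: Callable[[], Any]
    ) -> Any:
        # First lookup: the client may already have been initialized.
        client = cache.get_cache(key=cache_key)
        if client is None:
            # Cache miss (e.g. entry expired): initialize and store the
            # client, then read it back so callers share the cached instance.
            cache.set_cache(key=cache_key, value=init_client())
            client = cache.get_cache(key=cache_key)
        return client
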
@@ -45,18 +45,11 @@ class InitalizeOpenAISDKClient:
         return True

     @staticmethod
-    def set_client(  # noqa: PLR0915
+    def set_max_parallel_requests_client(
         litellm_router_instance: LitellmRouter, model: dict
     ):
-        """
-        - Initializes Azure/OpenAI clients. Stores them in cache, b/c of this - https://github.com/BerriAI/litellm/issues/1278
-        - Initializes Semaphore for client w/ rpm. Stores them in cache. b/c of this - https://github.com/BerriAI/litellm/issues/2994
-        """
-        client_ttl = litellm_router_instance.client_ttl
         litellm_params = model.get("litellm_params", {})
-        model_name = litellm_params.get("model")
         model_id = model["model_info"]["id"]
-        # ### IF RPM SET - initialize a semaphore ###
         rpm = litellm_params.get("rpm", None)
         tpm = litellm_params.get("tpm", None)
         max_parallel_requests = litellm_params.get("max_parallel_requests", None)
@@ -75,6 +68,19 @@ class InitalizeOpenAISDKClient:
             local_only=True,
         )

+    @staticmethod
+    def set_client(  # noqa: PLR0915
+        litellm_router_instance: LitellmRouter, model: dict
+    ):
+        """
+        - Initializes Azure/OpenAI clients. Stores them in cache, b/c of this - https://github.com/BerriAI/litellm/issues/1278
+        - Initializes Semaphore for client w/ rpm. Stores them in cache. b/c of this - https://github.com/BerriAI/litellm/issues/2994
+        """
+        client_ttl = litellm_router_instance.client_ttl
+        litellm_params = model.get("litellm_params", {})
+        model_name = litellm_params.get("model")
+        model_id = model["model_info"]["id"]
+
         #### for OpenAI / Azure we need to initalize the Client for High Traffic ########
         custom_llm_provider = litellm_params.get("custom_llm_provider")
         custom_llm_provider = custom_llm_provider or model_name.split("/", 1)[0] or ""
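
The body of the extracted set_max_parallel_requests_client is not shown in the diff, but the surrounding context (the rpm/tpm/max_parallel_requests reads in the previous hunk and the set_cache(..., local_only=True) tail at the top of this one) suggests it stores a per-deployment asyncio.Semaphore in the router's cache. A rough sketch of that idea follows; the limit-derivation rule and the cache-key format are assumptions for illustration, not litellm's exact implementation.

    import asyncio
    from typing import Optional


    def _calculate_max_parallel_requests(
        max_parallel_requests: Optional[int],
        rpm: Optional[int],
        tpm: Optional[int],
    ) -> Optional[int]:
        # Assumed precedence: explicit limit > rpm > a coarse tpm-derived
        # guess. litellm's real derivation may differ.
        if max_parallel_requests is not None:
            return max_parallel_requests
        if rpm is not None:
            return rpm
        if tpm is not None:
            return max(tpm // 1000, 1)
        return None


    def set_max_parallel_requests_client(cache, model: dict) -> None:
        """Sketch: store a per-deployment semaphore under a cache key."""
        litellm_params = model.get("litellm_params", {})
        model_id = model["model_info"]["id"]
        limit = _calculate_max_parallel_requests(
            litellm_params.get("max_parallel_requests"),
            litellm_params.get("rpm"),
            litellm_params.get("tpm"),
        )
        if limit is not None:
            # Key format "<model_id>_max_parallel_requests_client" is an
            # assumption; the router must read back with the same key.
            cache.set_cache(
                key=f"{model_id}_max_parallel_requests_client",
                value=asyncio.Semaphore(limit),
                local_only=True,
            )
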
@@ -185,7 +191,6 @@ class InitalizeOpenAISDKClient:
                 organization_env_name = organization.replace("os.environ/", "")
                 organization = get_secret_str(organization_env_name)
                 litellm_params["organization"] = organization
-
             else:
                 _api_key = api_key  # type: ignore
             if _api_key is not None and isinstance(_api_key, str):
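
The context lines in the hunk above show litellm's "os.environ/" convention: a litellm_params value such as "os.environ/OPENAI_ORGANIZATION" is a reference to an environment variable, resolved when the client is initialized. A minimal sketch of that resolution step, using os.environ directly where litellm calls get_secret_str (which also consults configured secret managers):

    import os
    from typing import Optional


    def resolve_env_reference(value: Optional[str]) -> Optional[str]:
        """Resolve 'os.environ/<VAR>' references to the variable's value."""
        if value is not None and value.startswith("os.environ/"):
            env_name = value.replace("os.environ/", "")
            # Plain os.environ lookup is a simplification of get_secret_str.
            return os.environ.get(env_name)
        return value


    # Example usage (hypothetical config value):
    #   litellm_params = {"organization": "os.environ/OPENAI_ORGANIZATION"}
    #   resolve_env_reference(litellm_params["organization"])
    #   -> the value of $OPENAI_ORGANIZATION, or None if unset.
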
@@ -320,91 +320,6 @@ def test_router_order():
     assert response._hidden_params["model_id"] == "1"


-@pytest.mark.parametrize("num_retries", [None, 2])
-@pytest.mark.parametrize("max_retries", [None, 4])
-def test_router_num_retries_init(num_retries, max_retries):
-    """
-    - test when num_retries set v/s not
-    - test client value when max retries set v/s not
-    """
-    router = Router(
-        model_list=[
-            {
-                "model_name": "gpt-3.5-turbo",  # openai model name
-                "litellm_params": {  # params for litellm completion/embedding call
-                    "model": "azure/chatgpt-v-2",
-                    "api_key": "bad-key",
-                    "api_version": os.getenv("AZURE_API_VERSION"),
-                    "api_base": os.getenv("AZURE_API_BASE"),
-                    "max_retries": max_retries,
-                },
-                "model_info": {"id": 12345},
-            },
-        ],
-        num_retries=num_retries,
-    )
-
-    if num_retries is not None:
-        assert router.num_retries == num_retries
-    else:
-        assert router.num_retries == openai.DEFAULT_MAX_RETRIES
-
-    model_client = router._get_client(
-        {"model_info": {"id": 12345}}, client_type="async", kwargs={}
-    )
-
-    if max_retries is not None:
-        assert getattr(model_client, "max_retries") == max_retries
-    else:
-        assert getattr(model_client, "max_retries") == 0
-
-
-@pytest.mark.parametrize(
-    "timeout", [10, 1.0, httpx.Timeout(timeout=300.0, connect=20.0)]
-)
-@pytest.mark.parametrize("ssl_verify", [True, False])
-def test_router_timeout_init(timeout, ssl_verify):
-    """
-    Allow user to pass httpx.Timeout
-
-    related issue - https://github.com/BerriAI/litellm/issues/3162
-    """
-    litellm.ssl_verify = ssl_verify
-
-    router = Router(
-        model_list=[
-            {
-                "model_name": "test-model",
-                "litellm_params": {
-                    "model": "azure/chatgpt-v-2",
-                    "api_key": os.getenv("AZURE_API_KEY"),
-                    "api_base": os.getenv("AZURE_API_BASE"),
-                    "api_version": os.getenv("AZURE_API_VERSION"),
-                    "timeout": timeout,
-                },
-                "model_info": {"id": 1234},
-            }
-        ]
-    )
-
-    model_client = router._get_client(
-        deployment={"model_info": {"id": 1234}}, client_type="sync_client", kwargs={}
-    )
-
-    assert getattr(model_client, "timeout") == timeout
-
-    print(f"vars model_client: {vars(model_client)}")
-    http_client = getattr(model_client, "_client")
-    print(f"http client: {vars(http_client)}, ssl_Verify={ssl_verify}")
-    if ssl_verify == False:
-        assert http_client._transport._pool._ssl_context.verify_mode.name == "CERT_NONE"
-    else:
-        assert (
-            http_client._transport._pool._ssl_context.verify_mode.name
-            == "CERT_REQUIRED"
-        )
-
-
 @pytest.mark.parametrize("sync_mode", [False, True])
 @pytest.mark.asyncio
 async def test_router_retries(sync_mode):