forked from phoenix/litellm-mirror
Merge branch 'main' into litellm_default_router_retries
This commit is contained in:
commit
1a06f009d1
20 changed files with 1663 additions and 44 deletions
|
@ -14,6 +14,7 @@ from litellm.router import Deployment, LiteLLM_Params, ModelInfo
|
|||
from concurrent.futures import ThreadPoolExecutor
|
||||
from collections import defaultdict
|
||||
from dotenv import load_dotenv
|
||||
import os, httpx
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
@ -56,6 +57,87 @@ def test_router_num_retries_init(num_retries, max_retries):
|
|||
else:
|
||||
assert getattr(model_client, "max_retries") == 0
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"timeout", [10, 1.0, httpx.Timeout(timeout=300.0, connect=20.0)]
|
||||
)
|
||||
@pytest.mark.parametrize("ssl_verify", [True, False])
|
||||
def test_router_timeout_init(timeout, ssl_verify):
|
||||
"""
|
||||
Allow user to pass httpx.Timeout
|
||||
|
||||
related issue - https://github.com/BerriAI/litellm/issues/3162
|
||||
"""
|
||||
litellm.ssl_verify = ssl_verify
|
||||
|
||||
router = Router(
|
||||
model_list=[
|
||||
{
|
||||
"model_name": "test-model",
|
||||
"litellm_params": {
|
||||
"model": "azure/chatgpt-v-2",
|
||||
"api_key": os.getenv("AZURE_API_KEY"),
|
||||
"api_base": os.getenv("AZURE_API_BASE"),
|
||||
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||
"timeout": timeout,
|
||||
},
|
||||
"model_info": {"id": 1234},
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
model_client = router._get_client(
|
||||
deployment={"model_info": {"id": 1234}}, client_type="sync_client", kwargs={}
|
||||
)
|
||||
|
||||
assert getattr(model_client, "timeout") == timeout
|
||||
|
||||
print(f"vars model_client: {vars(model_client)}")
|
||||
http_client = getattr(model_client, "_client")
|
||||
print(f"http client: {vars(http_client)}, ssl_Verify={ssl_verify}")
|
||||
if ssl_verify == False:
|
||||
assert http_client._transport._pool._ssl_context.verify_mode.name == "CERT_NONE"
|
||||
else:
|
||||
assert (
|
||||
http_client._transport._pool._ssl_context.verify_mode.name
|
||||
== "CERT_REQUIRED"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"mistral_api_base",
|
||||
[
|
||||
"os.environ/AZURE_MISTRAL_API_BASE",
|
||||
"https://Mistral-large-nmefg-serverless.eastus2.inference.ai.azure.com/v1/",
|
||||
"https://Mistral-large-nmefg-serverless.eastus2.inference.ai.azure.com/v1",
|
||||
"https://Mistral-large-nmefg-serverless.eastus2.inference.ai.azure.com/",
|
||||
"https://Mistral-large-nmefg-serverless.eastus2.inference.ai.azure.com",
|
||||
],
|
||||
)
|
||||
def test_router_azure_ai_studio_init(mistral_api_base):
|
||||
router = Router(
|
||||
model_list=[
|
||||
{
|
||||
"model_name": "test-model",
|
||||
"litellm_params": {
|
||||
"model": "azure/mistral-large-latest",
|
||||
"api_key": "os.environ/AZURE_MISTRAL_API_KEY",
|
||||
"api_base": mistral_api_base,
|
||||
},
|
||||
"model_info": {"id": 1234},
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
model_client = router._get_client(
|
||||
deployment={"model_info": {"id": 1234}}, client_type="sync_client", kwargs={}
|
||||
)
|
||||
url = getattr(model_client, "_base_url")
|
||||
uri_reference = str(getattr(url, "_uri_reference"))
|
||||
|
||||
print(f"uri_reference: {uri_reference}")
|
||||
|
||||
assert "/v1/" in uri_reference
|
||||
|
||||
|
||||
def test_exception_raising():
|
||||
# this tests if the router raises an exception when invalid params are set
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue