mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
test(test_pass_through_endpoints.py): add test for rpm limit support
This commit is contained in:
parent
0cc273d77b
commit
55e153556a
2 changed files with 62 additions and 1 deletions
|
@ -96,7 +96,6 @@ async def user_api_key_auth(
|
|||
anthropic_api_key_header
|
||||
),
|
||||
) -> UserAPIKeyAuth:
|
||||
|
||||
from litellm.proxy.proxy_server import (
|
||||
allowed_routes_check,
|
||||
common_checks,
|
||||
|
|
|
@ -85,6 +85,68 @@ async def test_pass_through_endpoint_rerank(client):
|
|||
assert response.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"auth, rpm_limit, expected_error_code",
|
||||
[(True, 0, 429), (True, 1, 200), (False, 0, 401)],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_pass_through_endpoint_rpm_limit(
|
||||
client, auth, expected_error_code, rpm_limit
|
||||
):
|
||||
import litellm
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.proxy.proxy_server import ProxyLogging, hash_token, user_api_key_cache
|
||||
|
||||
mock_api_key = "sk-my-test-key"
|
||||
cache_value = UserAPIKeyAuth(token=hash_token(mock_api_key), rpm_limit=rpm_limit)
|
||||
|
||||
_cohere_api_key = os.environ.get("COHERE_API_KEY")
|
||||
|
||||
user_api_key_cache.set_cache(key=hash_token(mock_api_key), value=cache_value)
|
||||
|
||||
proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache)
|
||||
proxy_logging_obj._init_litellm_callbacks()
|
||||
|
||||
setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache)
|
||||
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
|
||||
setattr(litellm.proxy.proxy_server, "prisma_client", "FAKE-VAR")
|
||||
setattr(litellm.proxy.proxy_server, "proxy_logging_obj", proxy_logging_obj)
|
||||
|
||||
# Define a pass-through endpoint
|
||||
pass_through_endpoints = [
|
||||
{
|
||||
"path": "/v1/rerank",
|
||||
"target": "https://api.cohere.com/v1/rerank",
|
||||
"auth": auth,
|
||||
"headers": {"Authorization": f"bearer {_cohere_api_key}"},
|
||||
}
|
||||
]
|
||||
|
||||
# Initialize the pass-through endpoint
|
||||
await initialize_pass_through_endpoints(pass_through_endpoints)
|
||||
|
||||
_json_data = {
|
||||
"model": "rerank-english-v3.0",
|
||||
"query": "What is the capital of the United States?",
|
||||
"top_n": 3,
|
||||
"documents": [
|
||||
"Carson City is the capital city of the American state of Nevada."
|
||||
],
|
||||
}
|
||||
|
||||
# Make a request to the pass-through endpoint
|
||||
response = client.post(
|
||||
"/v1/rerank",
|
||||
json=_json_data,
|
||||
headers={"Authorization": "Bearer {}".format(mock_api_key)},
|
||||
)
|
||||
|
||||
print("JSON response: ", _json_data)
|
||||
|
||||
# Assert the response
|
||||
assert response.status_code == expected_error_code
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pass_through_endpoint_anthropic(client):
|
||||
import litellm
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue