test(test_pass_through_endpoints.py): correctly reset test

This commit is contained in:
Krrish Dholakia 2024-08-14 10:48:42 -07:00
parent 6bcfa90ba8
commit 5af9794b9d

View file

@ -172,83 +172,113 @@ async def test_pass_through_endpoint_rpm_limit(auth, expected_error_code, rpm_li
async def test_aaapass_through_endpoint_pass_through_keys_langfuse(
auth, expected_error_code, rpm_limit
):
client = TestClient(app)
import litellm
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.proxy_server import ProxyLogging, hash_token, user_api_key_cache
mock_api_key = "sk-my-test-key"
cache_value = UserAPIKeyAuth(token=hash_token(mock_api_key), rpm_limit=rpm_limit)
# Store original values
original_user_api_key_cache = getattr(
litellm.proxy.proxy_server, "user_api_key_cache", None
)
original_master_key = getattr(litellm.proxy.proxy_server, "master_key", None)
original_prisma_client = getattr(litellm.proxy.proxy_server, "prisma_client", None)
original_proxy_logging_obj = getattr(
litellm.proxy.proxy_server, "proxy_logging_obj", None
)
_cohere_api_key = os.environ.get("COHERE_API_KEY")
try:
user_api_key_cache.set_cache(key=hash_token(mock_api_key), value=cache_value)
mock_api_key = "sk-my-test-key"
cache_value = UserAPIKeyAuth(
token=hash_token(mock_api_key), rpm_limit=rpm_limit
)
proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache)
proxy_logging_obj._init_litellm_callbacks()
_cohere_api_key = os.environ.get("COHERE_API_KEY")
setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
setattr(litellm.proxy.proxy_server, "prisma_client", "FAKE-VAR")
setattr(litellm.proxy.proxy_server, "proxy_logging_obj", proxy_logging_obj)
user_api_key_cache.set_cache(key=hash_token(mock_api_key), value=cache_value)
# Define a pass-through endpoint
pass_through_endpoints = [
{
"path": "/api/public/ingestion",
"target": "https://cloud.langfuse.com/api/public/ingestion",
"auth": auth,
"custom_auth_parser": "langfuse",
"headers": {
"LANGFUSE_PUBLIC_KEY": "os.environ/LANGFUSE_PUBLIC_KEY",
"LANGFUSE_SECRET_KEY": "os.environ/LANGFUSE_SECRET_KEY",
proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache)
proxy_logging_obj._init_litellm_callbacks()
setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
setattr(litellm.proxy.proxy_server, "prisma_client", "FAKE-VAR")
setattr(litellm.proxy.proxy_server, "proxy_logging_obj", proxy_logging_obj)
# Define a pass-through endpoint
pass_through_endpoints = [
{
"path": "/api/public/ingestion",
"target": "https://us.cloud.langfuse.com/api/public/ingestion",
"auth": auth,
"custom_auth_parser": "langfuse",
"headers": {
"LANGFUSE_PUBLIC_KEY": "os.environ/LANGFUSE_PUBLIC_KEY",
"LANGFUSE_SECRET_KEY": "os.environ/LANGFUSE_SECRET_KEY",
},
}
]
# Initialize the pass-through endpoint
await initialize_pass_through_endpoints(pass_through_endpoints)
general_settings: Optional[dict] = (
getattr(litellm.proxy.proxy_server, "general_settings", {}) or {}
)
old_general_settings = general_settings
general_settings.update({"pass_through_endpoints": pass_through_endpoints})
setattr(litellm.proxy.proxy_server, "general_settings", general_settings)
_json_data = {
"batch": [
{
"id": "80e2141f-0ca6-47b7-9c06-dde5e97de690",
"type": "trace-create",
"body": {
"id": "0687af7b-4a75-4de8-a4f6-cba1cdc00865",
"timestamp": "2024-08-14T02:38:56.092950Z",
"name": "test-trace-litellm-proxy-passthrough",
},
"timestamp": "2024-08-14T02:38:56.093352Z",
}
],
"metadata": {
"batch_size": 1,
"sdk_integration": "default",
"sdk_name": "python",
"sdk_version": "2.27.0",
"public_key": "anything",
},
}
]
# Initialize the pass-through endpoint
await initialize_pass_through_endpoints(pass_through_endpoints)
general_settings: Optional[dict] = (
getattr(litellm.proxy.proxy_server, "general_settings", {}) or {}
)
general_settings.update({"pass_through_endpoints": pass_through_endpoints})
setattr(litellm.proxy.proxy_server, "general_settings", general_settings)
# Make a request to the pass-through endpoint
response = client.post(
"/api/public/ingestion",
json=_json_data,
headers={"Authorization": "Basic c2stbXktdGVzdC1rZXk6YW55dGhpbmc="},
)
_json_data = {
"batch": [
{
"id": "80e2141f-0ca6-47b7-9c06-dde5e97de690",
"type": "trace-create",
"body": {
"id": "0687af7b-4a75-4de8-a4f6-cba1cdc00865",
"timestamp": "2024-08-14T02:38:56.092950Z",
"name": "test-trace-litellm-proxy-passthrough",
},
"timestamp": "2024-08-14T02:38:56.093352Z",
}
],
"metadata": {
"batch_size": 1,
"sdk_integration": "default",
"sdk_name": "python",
"sdk_version": "2.27.0",
"public_key": "anything",
},
}
print("JSON response: ", _json_data)
# Make a request to the pass-through endpoint
response = client.post(
"/api/public/ingestion",
json=_json_data,
headers={"Authorization": "Basic c2stbXktdGVzdC1rZXk6YW55dGhpbmc="},
)
print("RESPONSE RECEIVED - {}".format(response.text))
print("JSON response: ", _json_data)
# Assert the response
assert response.status_code == expected_error_code
print("RESPONSE RECEIVED - {}".format(response.text))
# Assert the response
assert response.status_code == expected_error_code
setattr(litellm.proxy.proxy_server, "general_settings", old_general_settings)
finally:
# Reset to original values
setattr(
litellm.proxy.proxy_server,
"user_api_key_cache",
original_user_api_key_cache,
)
setattr(litellm.proxy.proxy_server, "master_key", original_master_key)
setattr(litellm.proxy.proxy_server, "prisma_client", original_prisma_client)
setattr(
litellm.proxy.proxy_server, "proxy_logging_obj", original_proxy_logging_obj
)
@pytest.mark.asyncio