"""Tests for ``PrometheusAuthMiddleware``.

Each test mounts the middleware on a throwaway FastAPI app and replaces the
``user_api_key_auth`` name that the middleware module looks up, so only the
middleware's gating logic is exercised, not real key validation.
"""

import json  # noqa: F401
import os
import sys

import pytest
from fastapi import FastAPI
from fastapi.responses import JSONResponse  # noqa: F401
from fastapi.testclient import TestClient

sys.path.insert(
    0, os.path.abspath("../../..")
)  # Adds the parent directory to the system path

import litellm
from litellm.proxy._types import SpecialHeaders
from litellm.proxy.middleware.prometheus_auth_middleware import PrometheusAuthMiddleware
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth  # noqa: F401

# Dotted path of the auth function as seen by the middleware module; patching
# it here (rather than the defining module) is what the middleware calls.
_AUTH_TARGET = "litellm.proxy.middleware.prometheus_auth_middleware.user_api_key_auth"


# Fake auth functions to simulate valid and invalid auth behavior.
async def fake_valid_auth(request, api_key):
    """Stand-in for ``user_api_key_auth`` that accepts every key."""
    return


async def fake_invalid_auth(request, api_key):
    """Stand-in for ``user_api_key_auth`` that rejects every key by raising."""
    raise Exception("Invalid API key")


@pytest.fixture
def app_with_middleware():
    """Create a FastAPI app with the PrometheusAuthMiddleware and dummy endpoints."""
    app = FastAPI()
    app.add_middleware(PrometheusAuthMiddleware)

    @app.get("/metrics")
    async def metrics():
        return {"msg": "metrics OK"}

    # Also allow /metrics/ (trailing slash)
    @app.get("/metrics/")
    async def metrics_slash():
        return {"msg": "metrics OK"}

    @app.get("/chat/completions")
    async def chat():
        return {"msg": "chat completions OK"}

    @app.get("/embeddings")
    async def embeddings():
        return {"msg": "embeddings OK"}

    return app


def _set_require_auth(monkeypatch, enabled):
    """Toggle the global metrics-auth flag; monkeypatch restores it after the test."""
    # raising=False so the test still works if the attribute is not predefined;
    # monkeypatch then removes it again on teardown.
    monkeypatch.setattr(
        litellm, "require_auth_for_metrics_endpoint", enabled, raising=False
    )


def test_valid_auth_metrics(app_with_middleware, monkeypatch):
    """
    Test that a request to /metrics (and /metrics/) with valid auth headers passes.
    """
    _set_require_auth(monkeypatch, True)
    monkeypatch.setattr(_AUTH_TARGET, fake_valid_auth)

    client = TestClient(app_with_middleware)
    headers = {SpecialHeaders.openai_authorization.value: "valid"}

    for path in ("/metrics", "/metrics/"):
        response = client.get(path, headers=headers)
        assert response.status_code == 200, response.text
        assert response.json() == {"msg": "metrics OK"}


def test_invalid_auth_metrics(app_with_middleware, monkeypatch):
    """
    Test that a request to /metrics with invalid auth headers fails with a 401.
    """
    _set_require_auth(monkeypatch, True)
    monkeypatch.setattr(_AUTH_TARGET, fake_invalid_auth)

    client = TestClient(app_with_middleware)
    headers = {SpecialHeaders.openai_authorization.value: "invalid"}

    response = client.get("/metrics", headers=headers)
    assert response.status_code == 401, response.text
    assert "Unauthorized access to metrics endpoint" in response.text


def test_no_auth_metrics_when_disabled(app_with_middleware, monkeypatch):
    """
    Test that when require_auth_for_metrics_endpoint is False,
    requests to /metrics bypass the auth check.
    """
    _set_require_auth(monkeypatch, False)

    calls = []

    # Record any invocation so the test fails on the call itself, not only on
    # whatever status code the middleware maps the raised error to.
    def should_not_be_called(*args, **kwargs):
        calls.append((args, kwargs))
        raise Exception("Auth should not be called")

    monkeypatch.setattr(_AUTH_TARGET, should_not_be_called)

    client = TestClient(app_with_middleware)
    response = client.get("/metrics")
    assert response.status_code == 200, response.text
    assert response.json() == {"msg": "metrics OK"}
    assert calls == []