diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 0008239cf..08918925b 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -234,7 +234,8 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
             return UserAPIKeyAuth()
         if api_key is None: # only require api key if master key is set
-            raise Exception("No api key passed in.")
+            raise Exception(f"No api key passed in.")
+
         route = request.url.path

         # note: never string compare api keys, this is vulenerable to a time attack. Use secrets.compare_digest instead
@@ -816,11 +817,12 @@ async def startup_event():

 @router.on_event("shutdown")
 async def shutdown_event():
-    global prisma_client
+    global prisma_client, master_key, user_custom_auth
     if prisma_client:
         print("Disconnecting from Prisma")
         await prisma_client.disconnect()
-
+    master_key = None
+    user_custom_auth = None

 #### API ENDPOINTS ####
 @router.get("/v1/models", dependencies=[Depends(user_api_key_auth)])
 @router.get("/models", dependencies=[Depends(user_api_key_auth)]) # if project requires model list
diff --git a/litellm/tests/test_proxy_server.py b/litellm/tests/test_proxy_server.py
index b9017987c..ce962430b 100644
--- a/litellm/tests/test_proxy_server.py
+++ b/litellm/tests/test_proxy_server.py
@@ -23,7 +23,7 @@ logging.basicConfig(
 # test /chat/completion request to the proxy
 from fastapi.testclient import TestClient
 from fastapi import FastAPI
-from litellm.proxy.proxy_server import router, save_worker_config, startup_event # Replace with the actual module where your FastAPI router is defined
+from litellm.proxy.proxy_server import router, save_worker_config, initialize # Replace with the actual module where your FastAPI router is defined
 filepath = os.path.dirname(os.path.abspath(__file__))
 config_fp = f"{filepath}/test_configs/test_config_no_auth.yaml"
 save_worker_config(config=config_fp, model=None, alias=None, api_base=None, api_version=None, debug=False, temperature=None, max_tokens=None, request_timeout=600, max_budget=None, telemetry=False, drop_params=True, add_function_to_prompt=False, headers=None, save=False, use_queue=False)
@@ -31,8 +31,15 @@ app = FastAPI()
 app.include_router(router) # Include your router in the test app
 @app.on_event("startup")
 async def wrapper_startup_event():
-    await startup_event()
+    initialize(config=config_fp)

+# Your bearer token
+token = os.getenv("PROXY_MASTER_KEY")
+
+headers = {
+    "Authorization": f"Bearer {token}"
+}
+
 # Here you create a fixture that will be used by your tests
 # Make sure the fixture returns TestClient(app)
 @pytest.fixture(autouse=True)
@@ -41,6 +48,7 @@ def client():
     yield client

 def test_chat_completion(client):
+    global headers
     try:
         # Your test data
         test_data = {
@@ -53,8 +61,9 @@ def test_chat_completion(client):
             ],
             "max_tokens": 10,
         }
+        print("testing proxy server")

-        response = client.post("/v1/chat/completions", json=test_data)
+        response = client.post("/v1/chat/completions", json=test_data, headers=headers)
         print(f"response - {response.text}")
         assert response.status_code == 200
         result = response.json()
@@ -65,6 +74,7 @@ def test_chat_completion(client):

 # Run the test
 def test_chat_completion_azure(client):
+    global headers
     try:
         # Your test data
         test_data = {
@@ -77,8 +87,9 @@ def test_chat_completion_azure(client):
             ],
             "max_tokens": 10,
         }
+        print("testing proxy server with Azure Request")

-        response = client.post("/v1/chat/completions", json=test_data)
+        response = client.post("/v1/chat/completions", json=test_data, headers=headers)
         assert response.status_code == 200
         result = response.json()
@@ -92,13 +103,14 @@ def test_chat_completion_azure(client):

 def test_embedding(client):
+    global headers
     try:
         test_data = {
             "model": "azure/azure-embedding-model",
             "input": ["good morning from litellm"],
         }
         print("testing proxy server with OpenAI embedding")
-        response = client.post("/v1/embeddings", json=test_data)
+        response = client.post("/v1/embeddings", json=test_data, headers=headers)

         assert response.status_code == 200
         result = response.json()
@@ -112,6 +124,7 @@ def test_embedding(client):

 def test_add_new_model(client):
+    global headers
     try:
         test_data = {
             "model_name": "test_openai_models",
@@ -122,8 +135,8 @@
             "litellm_params": {
                 "model": "gpt-3.5-turbo",
             },
             "model_info": {
                 "id": "234",
                 "description": "this is a test openai model"
             }
         }
-        client.post("/model/new", json=test_data)
-        response = client.get("/model/info")
+        client.post("/model/new", json=test_data, headers=headers)
+        response = client.get("/model/info", headers=headers)
         assert response.status_code == 200
         result = response.json()
         print(f"response: {result}")
@@ -172,7 +185,7 @@ def test_chat_completion_optional_params(client):
         litellm.callbacks = [customHandler]
         print("testing proxy server: optional params")
-        response = client.post("/v1/chat/completions", json=test_data)
+        response = client.post("/v1/chat/completions", json=test_data, headers=headers)
         assert response.status_code == 200
         result = response.json()
         print(f"Received response: {result}")
diff --git a/litellm/utils.py b/litellm/utils.py
index 0b2f5243d..b37108bfc 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -554,6 +554,7 @@ class Logging:
             "litellm_params": self.litellm_params,
             "start_time": self.start_time,
             "stream": self.stream,
+            "user": user,
             **self.optional_params
         }
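
Note on the api-key comparison: the context line above ("never string compare api keys, this is vulenerable to a time attack. Use secrets.compare_digest instead") refers to timing attacks, where a plain == comparison can return as soon as the strings diverge and so leaks how much of the key matched. A minimal sketch of the constant-time check that comment recommends (illustrative only, not the proxy's actual implementation; is_valid_key, provided_key and master_key are placeholder names):

    import secrets

    def is_valid_key(provided_key: str, master_key: str) -> bool:
        # secrets.compare_digest compares the strings without short-circuiting,
        # so response timing does not reveal how many leading characters matched
        return secrets.compare_digest(provided_key, master_key)

A caller would gate the request with something like is_valid_key(api_key, master_key) before returning UserAPIKeyAuth().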