diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 6f8e0f6ab5..9d597ac01f 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -168,7 +168,7 @@ def log_input_output(request, response, custom_logger=None):
 
 from typing import Dict
 
-oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
+api_key_header = APIKeyHeader(name="Authorization", auto_error=False)
 user_api_base = None
 user_model = None
 user_debug = False
@@ -213,9 +213,13 @@ def usage_telemetry(
 
 
-async def user_api_key_auth(request: Request, api_key: str = Depends(oauth2_scheme)) -> UserAPIKeyAuth:
+async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(api_key_header)) -> UserAPIKeyAuth:
     global master_key, prisma_client, llm_model_list, user_custom_auth
     try:
+        if isinstance(api_key, str):
+            assert api_key.startswith("Bearer ") # ensure Bearer token passed in
+            api_key = api_key.replace("Bearer ", "") # extract the token
+        print(f"api_key: {api_key}; master_key: {master_key}; user_custom_auth: {user_custom_auth}")
         ### USER-DEFINED AUTH FUNCTION ###
         if user_custom_auth:
             response = await user_custom_auth(request=request, api_key=api_key)
@@ -223,15 +227,16 @@ async def user_api_key_auth(request: Request, api_key: str = Depends(oauth2_sche
         if master_key is None:
             if isinstance(api_key, str):
-                return UserAPIKeyAuth(api_key=api_key.replace("Bearer ", ""))
-            else:
-                return UserAPIKeyAuth()
-        if api_key is None:
+                return UserAPIKeyAuth(api_key=api_key)
+            else:
+                return UserAPIKeyAuth()
+
+        if api_key is None: # only require api key if master key is set
             raise Exception("No api key passed in.")
         route = request.url.path
 
         # note: never string compare api keys, this is vulenerable to a time attack. Use secrets.compare_digest instead
-        is_master_key_valid = secrets.compare_digest(api_key, master_key) or secrets.compare_digest(api_key, "Bearer " + master_key)
+        is_master_key_valid = secrets.compare_digest(api_key, master_key)
         if is_master_key_valid:
             return UserAPIKeyAuth(api_key=master_key)
@@ -241,9 +246,9 @@ async def user_api_key_auth(request: Request, api_key: str = Depends(oauth2_sche
         if prisma_client:
             ## check for cache hit (In-Memory Cache)
             valid_token = user_api_key_cache.get_cache(key=api_key)
-            if valid_token is None and "Bearer " in api_key:
+            if valid_token is None:
                 ## check db
-                cleaned_api_key = api_key[len("Bearer "):]
+                cleaned_api_key = api_key
                 valid_token = await prisma_client.get_data(token=cleaned_api_key, expires=datetime.utcnow())
                 user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=60)
             elif valid_token is not None:
@@ -275,10 +280,10 @@ async def user_api_key_auth(request: Request, api_key: str = Depends(oauth2_sche
             raise Exception(f"Invalid token")
     except Exception as e:
         print(f"An exception occurred - {traceback.format_exc()}")
-        raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail="invalid user key",
-        )
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail="invalid user key",
+        )
 
 def prisma_setup(database_url: Optional[str]):
     global prisma_client
@@ -597,13 +602,17 @@ def initialize(
     config,
     use_queue
 ):
-    global user_model, user_api_base, user_debug, user_max_tokens, user_request_timeout, user_temperature, user_telemetry, user_headers, experimental, llm_model_list, llm_router, general_settings
+    global user_model, user_api_base, user_debug, user_max_tokens, user_request_timeout, user_temperature, user_telemetry, user_headers, experimental, llm_model_list, llm_router, general_settings, master_key, user_custom_auth
     generate_feedback_box()
     user_model = model
     user_debug = debug
     dynamic_config = {"general": {}, user_model: {}}
     if config:
         llm_router, llm_model_list, general_settings = load_router_config(router=llm_router, config_file_path=config)
+    else:
+        # reset auth if config not passed, needed for consecutive tests on proxy
+        master_key = None
+        user_custom_auth = None
     if headers: # model-specific param
         user_headers = headers
         dynamic_config[user_model]["headers"] = headers
@@ -810,7 +819,6 @@ async def completion(request: Request, model: Optional[str] = None, user_api_key
             detail=error_msg
         )
 
-
 @router.post("/v1/chat/completions", dependencies=[Depends(user_api_key_auth)], tags=["chat/completions"])
 @router.post("/chat/completions", dependencies=[Depends(user_api_key_auth)], tags=["chat/completions"])
 @router.post("/openai/deployments/{model:path}/chat/completions", dependencies=[Depends(user_api_key_auth)], tags=["chat/completions"]) # azure compatible endpoint
diff --git a/litellm/tests/test_configs/test_config.yaml b/litellm/tests/test_configs/test_config.yaml
index 34b3d928a4..fa2079666c 100644
--- a/litellm/tests/test_configs/test_config.yaml
+++ b/litellm/tests/test_configs/test_config.yaml
@@ -1,24 +1,27 @@
-model_list:
-  - model_name: "azure-model"
-    litellm_params:
-      model: "azure/gpt-35-turbo"
-      api_key: "os.environ/AZURE_EUROPE_API_KEY"
-      api_base: "https://my-endpoint-europe-berri-992.openai.azure.com/"
-  - model_name: "azure-model"
-    litellm_params:
-      model: "azure/gpt-35-turbo"
-      api_key: "os.environ/AZURE_CANADA_API_KEY"
-      api_base: "https://my-endpoint-canada-berri992.openai.azure.com"
-  - model_name: "azure-model"
-    litellm_params:
-      model: "azure/gpt-turbo"
-      api_key: "os.environ/AZURE_FRANCE_API_KEY"
-      api_base: "https://openai-france-1234.openai.azure.com"
-
-litellm_settings:
-  drop_params: True
-  set_verbose: True
-
 general_settings:
-  master_key: "os.environ/PROXY_MASTER_KEY"
-  database_url: "os.environ/PROXY_DATABASE_URL" # [OPTIONAL] use for token-based auth to proxy
+  database_url: os.environ/PROXY_DATABASE_URL
+  master_key: os.environ/PROXY_MASTER_KEY
+litellm_settings:
+  drop_params: true
+  set_verbose: true
+model_list:
+- litellm_params:
+    api_base: https://my-endpoint-europe-berri-992.openai.azure.com/
+    api_key: os.environ/AZURE_EUROPE_API_KEY
+    model: azure/gpt-35-turbo
+  model_name: azure-model
+- litellm_params:
+    api_base: https://my-endpoint-canada-berri992.openai.azure.com
+    api_key: os.environ/AZURE_CANADA_API_KEY
+    model: azure/gpt-35-turbo
+  model_name: azure-model
+- litellm_params:
+    api_base: https://openai-france-1234.openai.azure.com
+    api_key: os.environ/AZURE_FRANCE_API_KEY
+    model: azure/gpt-turbo
+  model_name: azure-model
+- litellm_params:
+    model: gpt-3.5-turbo
+  model_info:
+    description: this is a test openai model
+  model_name: test_openai_models
diff --git a/litellm/tests/test_proxy_custom_auth.py b/litellm/tests/test_proxy_custom_auth.py
index fa1b5f6dd7..5708b1c41c 100644
--- a/litellm/tests/test_proxy_custom_auth.py
+++ b/litellm/tests/test_proxy_custom_auth.py
@@ -18,7 +18,7 @@ from litellm import RateLimitError
 # test /chat/completion request to the proxy
 from fastapi.testclient import TestClient
 from fastapi import FastAPI
-from litellm.proxy.proxy_server import router, save_worker_config, startup_event # Replace with the actual module where your FastAPI router is defined
+from litellm.proxy.proxy_server import router, save_worker_config, initialize # Replace with the actual module where your FastAPI router is defined
 filepath = os.path.dirname(os.path.abspath(__file__))
 config_fp = f"{filepath}/test_configs/test_config_custom_auth.yaml"
 save_worker_config(config=config_fp, model=None, alias=None, api_base=None, api_version=None, debug=False, temperature=None, max_tokens=None, request_timeout=600, max_budget=None, telemetry=False, drop_params=True, add_function_to_prompt=False, headers=None, save=False, use_queue=False)
@@ -26,7 +26,7 @@ app = FastAPI()
 app.include_router(router) # Include your router in the test app
 @app.on_event("startup")
 async def wrapper_startup_event():
-    await startup_event()
+    initialize(config=config_fp, model=None, alias=None, api_base=None, api_version=None, debug=False, temperature=None, max_tokens=None, request_timeout=600, max_budget=None, telemetry=False, drop_params=True, add_function_to_prompt=False, headers=None, save=False, use_queue=False)
 
 # Here you create a fixture that will be used by your tests
 # Make sure the fixture returns TestClient(app)
diff --git a/litellm/tests/test_proxy_server.py b/litellm/tests/test_proxy_server.py
index a525f01bf0..b15ee83075 100644
--- a/litellm/tests/test_proxy_server.py
+++ b/litellm/tests/test_proxy_server.py
@@ -18,11 +18,22 @@ from litellm import RateLimitError
 # test /chat/completion request to the proxy
 from fastapi.testclient import TestClient
 from fastapi import FastAPI
-from litellm.proxy.proxy_server import router # Replace with the actual module where your FastAPI router is defined
+from litellm.proxy.proxy_server import router, save_worker_config, initialize # Replace with the actual module where your FastAPI router is defined
+save_worker_config(config=None, model=None, alias=None, api_base=None, api_version=None, debug=False, temperature=None, max_tokens=None, request_timeout=600, max_budget=None, telemetry=False, drop_params=True, add_function_to_prompt=False, headers=None, save=False, use_queue=False)
 app = FastAPI()
 app.include_router(router) # Include your router in the test app
-client = TestClient(app)
-def test_chat_completion():
+@app.on_event("startup")
+async def wrapper_startup_event(): # required to reset config on app init - b/c pytest collects across multiple files - which sets the fastapi client + WORKER CONFIG to whatever was collected last
+    initialize(config=None, model=None, alias=None, api_base=None, api_version=None, debug=False, temperature=None, max_tokens=None, request_timeout=600, max_budget=None, telemetry=False, drop_params=True, add_function_to_prompt=False, headers=None, save=False, use_queue=False)
+
+# Here you create a fixture that will be used by your tests
+# Make sure the fixture returns TestClient(app)
+@pytest.fixture(autouse=True)
+def client():
+    with TestClient(app) as client:
+        yield client
+
+def test_chat_completion(client):
     try:
         # Your test data
         test_data = {
@@ -37,18 +48,16 @@ def test_chat_completion():
         }
         print("testing proxy server")
         response = client.post("/v1/chat/completions", json=test_data)
-
+        print(f"response - {response.text}")
         assert response.status_code == 200
         result = response.json()
         print(f"Received response: {result}")
     except Exception as e:
-        pytest.fail("LiteLLM Proxy test failed. Exception", e)
+        pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")
 
 
 # Run the test
-# test_chat_completion()
-
-def test_chat_completion_azure():
+def test_chat_completion_azure(client):
     try:
         # Your test data
         test_data = {
@@ -69,13 +78,13 @@
         print(f"Received response: {result}")
         assert len(result["choices"][0]["message"]["content"]) > 0
     except Exception as e:
-        pytest.fail("LiteLLM Proxy test failed. Exception", e)
+        pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")
 
 
 # Run the test
 # test_chat_completion_azure()
 
-def test_embedding():
+def test_embedding(client):
     try:
         test_data = {
             "model": "azure/azure-embedding-model",
@@ -89,13 +98,13 @@
         print(len(result["data"][0]["embedding"]))
         assert len(result["data"][0]["embedding"]) > 10 # this usually has len==1536 so
     except Exception as e:
-        pytest.fail("LiteLLM Proxy test failed. Exception", e)
+        pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")
 
 
 # Run the test
 # test_embedding()
 
-def test_add_new_model():
+def test_add_new_model(client):
     try:
         test_data = {
             "model_name": "test_openai_models",
@@ -135,7 +144,7 @@ class MyCustomHandler(CustomLogger):
 
 customHandler = MyCustomHandler()
 
-def test_chat_completion_optional_params():
+def test_chat_completion_optional_params(client):
     # [PROXY: PROD TEST] - DO NOT DELETE
     # This tests if all the /chat/completion params are passed to litellm
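For context, a minimal sketch (not taken from the patch) of how the reworked auth is exercised end to end: with `APIKeyHeader(name="Authorization")`, `user_api_key_auth` now expects a `Bearer ` prefix, strips it, and compares the remainder against the master key (via `secrets.compare_digest`) or the DB token. The proxy URL, port, and `sk-1234` master key below are assumptions for illustration only.

```python
# Illustrative sketch, assuming a proxy started locally (e.g. `litellm --config test_config.yaml`)
# and listening on http://0.0.0.0:8000 with master_key "sk-1234" -- both values are assumptions.
import requests

resp = requests.post(
    "http://0.0.0.0:8000/chat/completions",
    # The key must be sent as a Bearer token in the Authorization header;
    # user_api_key_auth() asserts the "Bearer " prefix and strips it before validation.
    headers={"Authorization": "Bearer sk-1234"},
    json={
        "model": "azure-model",
        "messages": [{"role": "user", "content": "hi"}],
        "max_tokens": 10,
    },
)
print(resp.status_code, resp.json())
```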