From aceab2669f7918e3fa1d89f5740946f9bfe76552 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Tue, 20 Aug 2024 08:38:14 -0700 Subject: [PATCH] test guardrails with API Key --- litellm/proxy/litellm_pre_call_utils.py | 24 ++++++- tests/otel_tests/test_guardrails.py | 89 +++++++++++++++++++++++-- 2 files changed, 106 insertions(+), 7 deletions(-) diff --git a/litellm/proxy/litellm_pre_call_utils.py b/litellm/proxy/litellm_pre_call_utils.py index 78f3e0949..ff16c0661 100644 --- a/litellm/proxy/litellm_pre_call_utils.py +++ b/litellm/proxy/litellm_pre_call_utils.py @@ -318,13 +318,33 @@ async def add_litellm_data_to_request( # Guardrails move_guardrails_to_metadata( - data=data, _metadata_variable_name=_metadata_variable_name + data=data, + _metadata_variable_name=_metadata_variable_name, + user_api_key_dict=user_api_key_dict, ) return data -def move_guardrails_to_metadata(data: dict, _metadata_variable_name: str): +def move_guardrails_to_metadata( + data: dict, + _metadata_variable_name: str, + user_api_key_dict: UserAPIKeyAuth, +): + """ + Heper to add guardrails from request to metadata + + - If guardrails set on API Key metadata then sets guardrails on request metadata + - If guardrails not set on API key, then checks request metadata + + """ + if user_api_key_dict.metadata: + if "guardrails" in user_api_key_dict.metadata: + data[_metadata_variable_name]["guardrails"] = user_api_key_dict.metadata[ + "guardrails" + ] + return + if "guardrails" in data: data[_metadata_variable_name]["guardrails"] = data["guardrails"] del data["guardrails"] diff --git a/tests/otel_tests/test_guardrails.py b/tests/otel_tests/test_guardrails.py index c48a5ba79..52f616bed 100644 --- a/tests/otel_tests/test_guardrails.py +++ b/tests/otel_tests/test_guardrails.py @@ -22,10 +22,6 @@ async def chat_completion( data = { "model": model, "messages": messages, - "guardrails": [ - "aporia-post-guard", - "aporia-pre-guard", - ], # default guardrails for all tests } if guardrails is not None: @@ -41,7 +37,7 @@ async def chat_completion( print() if status != 200: - return response_text + raise Exception(response_text) # response headers response_headers = response.headers @@ -50,6 +46,29 @@ async def chat_completion( return await response.json(), response_headers +async def generate_key(session, guardrails): + url = "http://0.0.0.0:4000/key/generate" + headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"} + if guardrails: + data = { + "guardrails": guardrails, + } + else: + data = {} + + async with session.post(url, headers=headers, json=data) as response: + status = response.status + response_text = await response.text() + + print(response_text) + print() + + if status != 200: + raise Exception(f"Request did not return a 200 status code: {status}") + + return await response.json() + + @pytest.mark.asyncio async def test_llm_guard_triggered_safe_request(): """ @@ -62,6 +81,10 @@ async def test_llm_guard_triggered_safe_request(): "sk-1234", model="fake-openai-endpoint", messages=[{"role": "user", "content": f"Hello what's the weather"}], + guardrails=[ + "aporia-post-guard", + "aporia-pre-guard", + ], ) await asyncio.sleep(3) @@ -90,6 +113,10 @@ async def test_llm_guard_triggered(): messages=[ {"role": "user", "content": f"Hello my name is ishaan@berri.ai"} ], + guardrails=[ + "aporia-post-guard", + "aporia-pre-guard", + ], ) pytest.fail("Should have thrown an exception") except Exception as e: @@ -116,3 +143,55 @@ async def test_no_llm_guard_triggered(): print("response=", response, "response headers", headers) assert "x-litellm-applied-guardrails" not in headers + + +@pytest.mark.asyncio +async def test_guardrails_with_api_key_controls(): + """ + - Make two API Keys + - Key 1 with no guardrails + - Key 2 with guardrails + - Request to Key 1 -> should be success with no guardrails + - Request to Key 2 -> should be error since guardrails are triggered + """ + async with aiohttp.ClientSession() as session: + key_with_guardrails = await generate_key( + session=session, + guardrails=[ + "aporia-post-guard", + "aporia-pre-guard", + ], + ) + + key_with_guardrails = key_with_guardrails["key"] + + key_without_guardrails = await generate_key(session=session, guardrails=None) + + key_without_guardrails = key_without_guardrails["key"] + + # test no guardrails triggered for key without guardrails + response, headers = await chat_completion( + session, + key_without_guardrails, + model="fake-openai-endpoint", + messages=[{"role": "user", "content": f"Hello what's the weather"}], + ) + await asyncio.sleep(3) + + print("response=", response, "response headers", headers) + assert "x-litellm-applied-guardrails" not in headers + + # test guardrails triggered for key with guardrails + try: + response, headers = await chat_completion( + session, + key_with_guardrails, + model="fake-openai-endpoint", + messages=[ + {"role": "user", "content": f"Hello my name is ishaan@berri.ai"} + ], + ) + pytest.fail("Should have thrown an exception") + except Exception as e: + print(e) + assert "Aporia detected and blocked PII" in str(e)