test guardrails with API Key

This commit is contained in:
Ishaan Jaff 2024-08-20 08:38:14 -07:00
parent 4ec6d8ff50
commit aceab2669f
2 changed files with 106 additions and 7 deletions

View file

@ -318,13 +318,33 @@ async def add_litellm_data_to_request(
# Guardrails # Guardrails
move_guardrails_to_metadata( move_guardrails_to_metadata(
data=data, _metadata_variable_name=_metadata_variable_name data=data,
_metadata_variable_name=_metadata_variable_name,
user_api_key_dict=user_api_key_dict,
) )
return data return data
def move_guardrails_to_metadata(data: dict, _metadata_variable_name: str): def move_guardrails_to_metadata(
data: dict,
_metadata_variable_name: str,
user_api_key_dict: UserAPIKeyAuth,
):
"""
Heper to add guardrails from request to metadata
- If guardrails set on API Key metadata then sets guardrails on request metadata
- If guardrails not set on API key, then checks request metadata
"""
if user_api_key_dict.metadata:
if "guardrails" in user_api_key_dict.metadata:
data[_metadata_variable_name]["guardrails"] = user_api_key_dict.metadata[
"guardrails"
]
return
if "guardrails" in data: if "guardrails" in data:
data[_metadata_variable_name]["guardrails"] = data["guardrails"] data[_metadata_variable_name]["guardrails"] = data["guardrails"]
del data["guardrails"] del data["guardrails"]

View file

@ -22,10 +22,6 @@ async def chat_completion(
data = { data = {
"model": model, "model": model,
"messages": messages, "messages": messages,
"guardrails": [
"aporia-post-guard",
"aporia-pre-guard",
], # default guardrails for all tests
} }
if guardrails is not None: if guardrails is not None:
@ -41,7 +37,7 @@ async def chat_completion(
print() print()
if status != 200: if status != 200:
return response_text raise Exception(response_text)
# response headers # response headers
response_headers = response.headers response_headers = response.headers
@ -50,6 +46,29 @@ async def chat_completion(
return await response.json(), response_headers return await response.json(), response_headers
async def generate_key(session, guardrails):
url = "http://0.0.0.0:4000/key/generate"
headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
if guardrails:
data = {
"guardrails": guardrails,
}
else:
data = {}
async with session.post(url, headers=headers, json=data) as response:
status = response.status
response_text = await response.text()
print(response_text)
print()
if status != 200:
raise Exception(f"Request did not return a 200 status code: {status}")
return await response.json()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_llm_guard_triggered_safe_request(): async def test_llm_guard_triggered_safe_request():
""" """
@ -62,6 +81,10 @@ async def test_llm_guard_triggered_safe_request():
"sk-1234", "sk-1234",
model="fake-openai-endpoint", model="fake-openai-endpoint",
messages=[{"role": "user", "content": f"Hello what's the weather"}], messages=[{"role": "user", "content": f"Hello what's the weather"}],
guardrails=[
"aporia-post-guard",
"aporia-pre-guard",
],
) )
await asyncio.sleep(3) await asyncio.sleep(3)
@ -90,6 +113,10 @@ async def test_llm_guard_triggered():
messages=[ messages=[
{"role": "user", "content": f"Hello my name is ishaan@berri.ai"} {"role": "user", "content": f"Hello my name is ishaan@berri.ai"}
], ],
guardrails=[
"aporia-post-guard",
"aporia-pre-guard",
],
) )
pytest.fail("Should have thrown an exception") pytest.fail("Should have thrown an exception")
except Exception as e: except Exception as e:
@ -116,3 +143,55 @@ async def test_no_llm_guard_triggered():
print("response=", response, "response headers", headers) print("response=", response, "response headers", headers)
assert "x-litellm-applied-guardrails" not in headers assert "x-litellm-applied-guardrails" not in headers
@pytest.mark.asyncio
async def test_guardrails_with_api_key_controls():
"""
- Make two API Keys
- Key 1 with no guardrails
- Key 2 with guardrails
- Request to Key 1 -> should be success with no guardrails
- Request to Key 2 -> should be error since guardrails are triggered
"""
async with aiohttp.ClientSession() as session:
key_with_guardrails = await generate_key(
session=session,
guardrails=[
"aporia-post-guard",
"aporia-pre-guard",
],
)
key_with_guardrails = key_with_guardrails["key"]
key_without_guardrails = await generate_key(session=session, guardrails=None)
key_without_guardrails = key_without_guardrails["key"]
# test no guardrails triggered for key without guardrails
response, headers = await chat_completion(
session,
key_without_guardrails,
model="fake-openai-endpoint",
messages=[{"role": "user", "content": f"Hello what's the weather"}],
)
await asyncio.sleep(3)
print("response=", response, "response headers", headers)
assert "x-litellm-applied-guardrails" not in headers
# test guardrails triggered for key with guardrails
try:
response, headers = await chat_completion(
session,
key_with_guardrails,
model="fake-openai-endpoint",
messages=[
{"role": "user", "content": f"Hello my name is ishaan@berri.ai"}
],
)
pytest.fail("Should have thrown an exception")
except Exception as e:
print(e)
assert "Aporia detected and blocked PII" in str(e)