forked from phoenix/litellm-mirror
test guardrails with API Key
This commit is contained in:
parent
4ec6d8ff50
commit
aceab2669f
2 changed files with 106 additions and 7 deletions
|
@ -318,13 +318,33 @@ async def add_litellm_data_to_request(
|
||||||
|
|
||||||
# Guardrails
|
# Guardrails
|
||||||
move_guardrails_to_metadata(
|
move_guardrails_to_metadata(
|
||||||
data=data, _metadata_variable_name=_metadata_variable_name
|
data=data,
|
||||||
|
_metadata_variable_name=_metadata_variable_name,
|
||||||
|
user_api_key_dict=user_api_key_dict,
|
||||||
)
|
)
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
def move_guardrails_to_metadata(
    data: dict,
    _metadata_variable_name: str,
    user_api_key_dict: "UserAPIKeyAuth",
) -> None:
    """
    Helper to move guardrails into the request metadata.

    - If guardrails are set on the API key metadata, those are written to the
      request metadata and any request-level "guardrails" field is discarded.
    - If guardrails are not set on the API key, a request-level "guardrails"
      field (when present) is moved into the request metadata.

    In both cases the top-level "guardrails" key is removed from ``data``.

    Args:
        data: Request body; mutated in place.
        _metadata_variable_name: Key in ``data`` under which the metadata
            dict lives. It is created as an empty dict if missing and a
            guardrails value has to be stored.
        user_api_key_dict: Auth object of the calling key; only its
            ``metadata`` attribute is read.

    Returns:
        None. ``data`` is updated in place.
    """
    # Remember whether the request carried the field at all, so an explicit
    # ``None`` value is still moved exactly as before.
    request_has_guardrails = "guardrails" in data
    # Strip the top-level field unconditionally: it must not stay in the
    # request body regardless of which source of guardrails wins.
    request_guardrails = data.pop("guardrails", None)

    key_metadata = user_api_key_dict.metadata
    if key_metadata and "guardrails" in key_metadata:
        # Key-level guardrails take precedence over request-level ones.
        data.setdefault(_metadata_variable_name, {})["guardrails"] = key_metadata[
            "guardrails"
        ]
        return

    if request_has_guardrails:
        data.setdefault(_metadata_variable_name, {})["guardrails"] = request_guardrails
|
||||||
|
|
|
@ -22,10 +22,6 @@ async def chat_completion(
|
||||||
data = {
|
data = {
|
||||||
"model": model,
|
"model": model,
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
"guardrails": [
|
|
||||||
"aporia-post-guard",
|
|
||||||
"aporia-pre-guard",
|
|
||||||
], # default guardrails for all tests
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if guardrails is not None:
|
if guardrails is not None:
|
||||||
|
@ -41,7 +37,7 @@ async def chat_completion(
|
||||||
print()
|
print()
|
||||||
|
|
||||||
if status != 200:
|
if status != 200:
|
||||||
return response_text
|
raise Exception(response_text)
|
||||||
|
|
||||||
# response headers
|
# response headers
|
||||||
response_headers = response.headers
|
response_headers = response.headers
|
||||||
|
@ -50,6 +46,29 @@ async def chat_completion(
|
||||||
return await response.json(), response_headers
|
return await response.json(), response_headers
|
||||||
|
|
||||||
|
|
||||||
|
async def generate_key(session, guardrails):
    """
    POST to the proxy's /key/generate endpoint and return the parsed JSON body.

    The request body is {"guardrails": guardrails} when `guardrails` is truthy,
    otherwise an empty JSON object. The raw response text is printed.

    Raises:
        Exception: if the response status is not 200.
    """
    endpoint = "http://0.0.0.0:4000/key/generate"
    request_headers = {
        "Authorization": "Bearer sk-1234",
        "Content-Type": "application/json",
    }
    payload = {"guardrails": guardrails} if guardrails else {}

    async with session.post(
        endpoint, headers=request_headers, json=payload
    ) as response:
        status_code = response.status
        body_text = await response.text()

        print(body_text)
        print()

        if status_code != 200:
            raise Exception(
                f"Request did not return a 200 status code: {status_code}"
            )

        return await response.json()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_llm_guard_triggered_safe_request():
|
async def test_llm_guard_triggered_safe_request():
|
||||||
"""
|
"""
|
||||||
|
@ -62,6 +81,10 @@ async def test_llm_guard_triggered_safe_request():
|
||||||
"sk-1234",
|
"sk-1234",
|
||||||
model="fake-openai-endpoint",
|
model="fake-openai-endpoint",
|
||||||
messages=[{"role": "user", "content": f"Hello what's the weather"}],
|
messages=[{"role": "user", "content": f"Hello what's the weather"}],
|
||||||
|
guardrails=[
|
||||||
|
"aporia-post-guard",
|
||||||
|
"aporia-pre-guard",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
await asyncio.sleep(3)
|
await asyncio.sleep(3)
|
||||||
|
|
||||||
|
@ -90,6 +113,10 @@ async def test_llm_guard_triggered():
|
||||||
messages=[
|
messages=[
|
||||||
{"role": "user", "content": f"Hello my name is ishaan@berri.ai"}
|
{"role": "user", "content": f"Hello my name is ishaan@berri.ai"}
|
||||||
],
|
],
|
||||||
|
guardrails=[
|
||||||
|
"aporia-post-guard",
|
||||||
|
"aporia-pre-guard",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
pytest.fail("Should have thrown an exception")
|
pytest.fail("Should have thrown an exception")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -116,3 +143,55 @@ async def test_no_llm_guard_triggered():
|
||||||
print("response=", response, "response headers", headers)
|
print("response=", response, "response headers", headers)
|
||||||
|
|
||||||
assert "x-litellm-applied-guardrails" not in headers
|
assert "x-litellm-applied-guardrails" not in headers
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_guardrails_with_api_key_controls():
    """
    Guardrails controlled per API key:

    - Create one key with guardrails and one key without
    - Request with the key without guardrails -> succeeds, no guardrails header
    - Request with the key with guardrails -> raises, since guardrails trigger
    """
    key_guardrails = [
        "aporia-post-guard",
        "aporia-pre-guard",
    ]

    async with aiohttp.ClientSession() as session:
        guarded_key_response = await generate_key(
            session=session,
            guardrails=key_guardrails,
        )
        guarded_key = guarded_key_response["key"]

        plain_key_response = await generate_key(session=session, guardrails=None)
        plain_key = plain_key_response["key"]

        # Key without guardrails: the response must not report applied guardrails.
        safe_messages = [{"role": "user", "content": f"Hello what's the weather"}]
        response, headers = await chat_completion(
            session,
            plain_key,
            model="fake-openai-endpoint",
            messages=safe_messages,
        )
        await asyncio.sleep(3)

        print("response=", response, "response headers", headers)
        assert "x-litellm-applied-guardrails" not in headers

        # Key with guardrails: the call is expected to raise with the PII message.
        pii_messages = [
            {"role": "user", "content": f"Hello my name is ishaan@berri.ai"}
        ]
        try:
            response, headers = await chat_completion(
                session,
                guarded_key,
                model="fake-openai-endpoint",
                messages=pii_messages,
            )
            pytest.fail("Should have thrown an exception")
        except Exception as err:
            print(err)
            assert "Aporia detected and blocked PII" in str(err)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue