feat: run Aporia as post call success hook

Ishaan Jaff 2024-08-19 11:25:31 -07:00
parent 601be5cb44
commit 8cb62213e1
3 changed files with 130 additions and 91 deletions


@@ -5,12 +5,13 @@
# +-------------------------------------------------------------+
# Thank you users! We ❤️ you! - Krrish & Ishaan
import sys, os
import sys
import os
sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
from typing import Optional, Literal, Union
from typing import Optional, Literal, Union, Any
import litellm, traceback, sys, uuid
from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
@@ -18,6 +19,9 @@ from litellm.integrations.custom_logger import CustomLogger
from fastapi import HTTPException
from litellm._logging import verbose_proxy_logger
from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata
from litellm.litellm_core_utils.logging_utils import (
    convert_litellm_response_object_to_str,
)
from typing import List
from datetime import datetime
import aiohttp, asyncio
@@ -57,6 +61,92 @@ class _ENTERPRISE_Aporio(CustomLogger):
        return new_messages

    async def prepare_aporia_request(
        self, new_messages: List[dict], response_string: Optional[str] = None
    ) -> dict:
        data: dict[str, Any] = {}
        if new_messages is not None:
            data["messages"] = new_messages
        if response_string is not None:
            data["response"] = response_string

        # Set validation target
        if new_messages and response_string:
            data["validation_target"] = "both"
        elif new_messages:
            data["validation_target"] = "prompt"
        elif response_string:
            data["validation_target"] = "response"

        verbose_proxy_logger.debug("Aporia AI request: %s", data)
        return data
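For reference, a minimal runnable sketch of the payload shapes this helper produces; the standalone function below restates the branching above and is illustrative, not part of the commit:

import asyncio
from typing import Any, List, Optional

async def prepare_aporia_request_sketch(
    new_messages: List[dict], response_string: Optional[str] = None
) -> dict:
    # Same branching as prepare_aporia_request above, restated standalone.
    data: dict[str, Any] = {}
    if new_messages is not None:
        data["messages"] = new_messages
    if response_string is not None:
        data["response"] = response_string
    if new_messages and response_string:
        data["validation_target"] = "both"
    elif new_messages:
        data["validation_target"] = "prompt"
    elif response_string:
        data["validation_target"] = "response"
    return data

# Prompt-only (the async_moderation_hook path) -> validation_target == "prompt"
print(asyncio.run(prepare_aporia_request_sketch(
    new_messages=[{"role": "user", "content": "hi"}]
)))
# Response-only (async_post_call_success_hook passes new_messages=[], which is
# falsy) -> validation_target == "response"
print(asyncio.run(prepare_aporia_request_sketch(
    new_messages=[], response_string="Hello!"
)))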
    async def make_aporia_api_request(
        self, new_messages: List[dict], response_string: Optional[str] = None
    ):
        data = await self.prepare_aporia_request(
            new_messages=new_messages, response_string=response_string
        )
        _json_data = json.dumps(data)

        """
        export APORIO_API_KEY=<your key>
        curl https://gr-prd-trial.aporia.com/some-id \
            -X POST \
            -H "X-APORIA-API-KEY: $APORIO_API_KEY" \
            -H "Content-Type: application/json" \
            -d '{
                "messages": [
                    {
                        "role": "user",
                        "content": "This is a test prompt"
                    }
                ]
            }'
        """

        response = await self.async_handler.post(
            url=self.aporio_api_base + "/validate",
            data=_json_data,
            headers={
                "X-APORIA-API-KEY": self.aporio_api_key,
                "Content-Type": "application/json",
            },
        )
        verbose_proxy_logger.debug("Aporia AI response: %s", response.text)
        if response.status_code == 200:
            # check if the response was flagged
            _json_response = response.json()
            action: str = _json_response.get(
                "action"
            )  # possible values are modify, passthrough, block, rephrase
            if action == "block":
                raise HTTPException(
                    status_code=400,
                    detail={
                        "error": "Violated guardrail policy",
                        "aporio_ai_response": _json_response,
                    },
                )
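The status-200 handling above is easy to exercise in isolation. A minimal sketch, assuming only that Aporia returns an `action` field with the values noted in the comment; the sample response body is illustrative, not a documented schema:

from fastapi import HTTPException

def raise_if_blocked_sketch(_json_response: dict) -> None:
    # Mirrors the handler above: only "block" raises; "modify",
    # "passthrough", and "rephrase" fall through untouched.
    if _json_response.get("action") == "block":
        raise HTTPException(
            status_code=400,
            detail={
                "error": "Violated guardrail policy",
                "aporio_ai_response": _json_response,
            },
        )

try:
    raise_if_blocked_sketch({"action": "block"})  # illustrative payload
except HTTPException as e:
    print(e.status_code, e.detail["error"])  # 400 Violated guardrail policy

raise_if_blocked_sketch({"action": "passthrough"})  # no exception raised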
    async def async_post_call_success_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        response,
    ):
        """
        Use this for post-call moderation with Guardrails
        """
        response_str: Optional[str] = convert_litellm_response_object_to_str(response)
        if response_str is not None:
            await self.make_aporia_api_request(
                response_string=response_str, new_messages=[]
            )

        pass
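A self-contained sketch of the flow this new hook adds: flatten the model response to a string, skip validation when flattening yields None (e.g. non-text responses), otherwise validate the text as the response target. `_PostCallSketch` is a hypothetical stand-in for `_ENTERPRISE_Aporio`, with the HTTP call replaced by a print:

import asyncio
from typing import List, Optional

class _PostCallSketch:
    async def make_aporia_api_request(
        self, new_messages: List[dict], response_string: Optional[str] = None
    ):
        # Stands in for the real HTTP call to Aporia.
        print(f"would validate response: {response_string!r}")

    async def async_post_call_success_hook(self, response_str: Optional[str]):
        # The real hook derives response_str via convert_litellm_response_object_to_str.
        if response_str is not None:
            await self.make_aporia_api_request(
                response_string=response_str, new_messages=[]
            )

asyncio.run(_PostCallSketch().async_post_call_success_hook("Hello!"))  # validates
asyncio.run(_PostCallSketch().async_post_call_success_hook(None))      # skipped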
    async def async_moderation_hook(  ### 👈 KEY CHANGE ###
        self,
        data: dict,
@@ -78,47 +168,9 @@ class _ENTERPRISE_Aporio(CustomLogger):
        new_messages = self.transform_messages(messages=data["messages"])
        if new_messages is not None:
            data = {"messages": new_messages, "validation_target": "prompt"}
            _json_data = json.dumps(data)
            """
            export APORIO_API_KEY=<your key>
            curl https://gr-prd-trial.aporia.com/some-id \
                -X POST \
                -H "X-APORIA-API-KEY: $APORIO_API_KEY" \
                -H "Content-Type: application/json" \
                -d '{
                    "messages": [
                        {
                            "role": "user",
                            "content": "This is a test prompt"
                        }
                    ],
                }
                '
            """
            response = await self.async_handler.post(
                url=self.aporio_api_base + "/validate",
                data=_json_data,
                headers={
                    "X-APORIA-API-KEY": self.aporio_api_key,
                    "Content-Type": "application/json",
                },
            await self.make_aporia_api_request(new_messages=new_messages)
        else:
            verbose_proxy_logger.warning(
                "Aporia AI: not running guardrail. No messages in data"
            )
            verbose_proxy_logger.debug("Aporio AI response: %s", response.text)
            if response.status_code == 200:
                # check if the response was flagged
                _json_response = response.json()
                action: str = _json_response.get(
                    "action"
                )  # possible values are modify, passthrough, block, rephrase
                if action == "block":
                    raise HTTPException(
                        status_code=400,
                        detail={
                            "error": "Violated guardrail policy",
                            "aporio_ai_response": _json_response,
                        },
                    )

        pass


@@ -1,4 +1,12 @@
from typing import Any
from typing import TYPE_CHECKING, Any, Optional, Union

if TYPE_CHECKING:
    from litellm import ModelResponse as _ModelResponse

    LiteLLMModelResponse = _ModelResponse
else:
    LiteLLMModelResponse = Any

import litellm
@@ -20,3 +28,21 @@ def convert_litellm_response_object_to_dict(response_obj: Any) -> dict:
    # If it's not a LiteLLM type, return the object as is
    return dict(response_obj)


def convert_litellm_response_object_to_str(
    response_obj: Union[Any, LiteLLMModelResponse]
) -> Optional[str]:
    """
    Get the string of the response object from LiteLLM
    """
    if isinstance(response_obj, litellm.ModelResponse):
        response_str = ""
        for choice in response_obj.choices:
            if isinstance(choice, litellm.Choices):
                if choice.message.content and isinstance(choice.message.content, str):
                    response_str += choice.message.content
        return response_str

    return None
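A quick check of the new helper; this assumes `litellm.ModelResponse` coerces dict-shaped choices into `Choices` objects (which is how the isinstance checks above get satisfied), and anything that is not a `ModelResponse` comes back as None:

import litellm
from litellm.litellm_core_utils.logging_utils import (
    convert_litellm_response_object_to_str,
)

resp = litellm.ModelResponse(
    choices=[{"message": {"role": "assistant", "content": "Hello"}}]
)
print(convert_litellm_response_object_to_str(resp))      # "Hello"
print(convert_litellm_response_object_to_str({"a": 1}))  # None: not a ModelResponse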


@@ -1,50 +1,11 @@
model_list:
  - model_name: gpt-4
  - model_name: gpt-3.5-turbo
    litellm_params:
      model: openai/fake
      api_key: fake-key
      api_base: https://exampleopenaiendpoint-production.up.railway.app/
    model_info:
      access_groups: ["beta-models"]
  - model_name: fireworks-llama-v3-70b-instruct
    litellm_params:
      model: fireworks_ai/accounts/fireworks/models/llama-v3-70b-instruct
      api_key: "os.environ/FIREWORKS"
    model_info:
      access_groups: ["beta-models"]
  - model_name: "*"
    litellm_params:
      model: "*"
  - model_name: "*"
    litellm_params:
      model: openai/*
      model: openai/gpt-3.5-turbo
      api_key: os.environ/OPENAI_API_KEY
  - model_name: mistral-small-latest
    litellm_params:
      model: mistral/mistral-small-latest
      api_key: "os.environ/MISTRAL_API_KEY"
  - model_name: bedrock-anthropic
    litellm_params:
      model: bedrock/anthropic.claude-3-sonnet-20240229-v1:0
  - model_name: gemini-1.5-pro-001
    litellm_params:
      model: vertex_ai_beta/gemini-1.5-pro-001
      vertex_project: "adroit-crow-413218"
      vertex_location: "us-central1"
      vertex_credentials: "adroit-crow-413218-a956eef1a2a8.json"
      # Add path to service account.json

default_vertex_config:
  vertex_project: "adroit-crow-413218"
  vertex_location: "us-central1"
  vertex_credentials: "adroit-crow-413218-a956eef1a2a8.json"  # Add path to service account.json

general_settings:
  master_key: sk-1234
  alerting: ["slack"]

litellm_settings:
  fallbacks: [{"gemini-1.5-pro-001": ["gpt-4o"]}]
  success_callback: ["langfuse", "prometheus"]
  langfuse_default_tags: ["cache_hit", "cache_key", "user_api_key_alias", "user_api_key_team_alias"]
  guardrails:
    - prompt_injection:
        callbacks: [aporio_prompt_injection]
        default_on: true
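A hedged smoke test for this config: send one chat request through the proxy so the default-on guardrail runs on the prompt. The base URL assumes the proxy's default port 4000; the key is the `master_key` above. A prompt Aporia blocks should come back as HTTP 400 with "Violated guardrail policy" in the body:

import requests  # any HTTP client works; requests is assumed available

resp = requests.post(
    "http://localhost:4000/chat/completions",  # default proxy port, assumed
    headers={"Authorization": "Bearer sk-1234"},  # master_key from the config above
    json={
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "This is a test prompt"}],
    },
)
print(resp.status_code, resp.json())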