forked from phoenix/litellm-mirror

feat: run Aporia as a post-call success hook
commit 8cb62213e1 (parent 601be5cb44)

3 changed files with 130 additions and 91 deletions
File 1 of 3: the enterprise Aporia guardrail hook (class _ENTERPRISE_Aporio)

@@ -5,12 +5,13 @@
 # +-------------------------------------------------------------+
 #  Thank you users! We ❤️ you! - Krrish & Ishaan

-import sys, os
+import sys
+import os

 sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
-from typing import Optional, Literal, Union
+from typing import Optional, Literal, Union, Any
 import litellm, traceback, sys, uuid
 from litellm.caching import DualCache
 from litellm.proxy._types import UserAPIKeyAuth
@@ -18,6 +19,9 @@ from litellm.integrations.custom_logger import CustomLogger
 from fastapi import HTTPException
 from litellm._logging import verbose_proxy_logger
 from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata
+from litellm.litellm_core_utils.logging_utils import (
+    convert_litellm_response_object_to_str,
+)
 from typing import List
 from datetime import datetime
 import aiohttp, asyncio
@@ -57,28 +61,32 @@ class _ENTERPRISE_Aporio(CustomLogger):

         return new_messages

-    async def async_moderation_hook(  ### 👈 KEY CHANGE ###
-        self,
-        data: dict,
-        user_api_key_dict: UserAPIKeyAuth,
-        call_type: Literal["completion", "embeddings", "image_generation"],
-    ):
-
-        if (
-            await should_proceed_based_on_metadata(
-                data=data,
-                guardrail_name=GUARDRAIL_NAME,
-            )
-            is False
-        ):
-            return
-
-        new_messages: Optional[List[dict]] = None
-        if "messages" in data and isinstance(data["messages"], list):
-            new_messages = self.transform_messages(messages=data["messages"])

+    async def prepare_aporia_request(
+        self, new_messages: List[dict], response_string: Optional[str] = None
+    ) -> dict:
+        data: dict[str, Any] = {}
         if new_messages is not None:
-            data = {"messages": new_messages, "validation_target": "prompt"}
+            data["messages"] = new_messages
+        if response_string is not None:
+            data["response"] = response_string
+
+        # Set validation target
+        if new_messages and response_string:
+            data["validation_target"] = "both"
+        elif new_messages:
+            data["validation_target"] = "prompt"
+        elif response_string:
+            data["validation_target"] = "response"
+
+        verbose_proxy_logger.debug("Aporia AI request: %s", data)
+        return data

+    async def make_aporia_api_request(
+        self, new_messages: List[dict], response_string: Optional[str] = None
+    ):
+        data = await self.prepare_aporia_request(
+            new_messages=new_messages, response_string=response_string
+        )

         _json_data = json.dumps(data)

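The new request builder is what lets one code path serve both hooks: the pre-call prompt check and the post-call response check differ only in which fields are supplied. Below is a standalone sketch of that selection logic; the function body mirrors the method in the diff, but lifting it out of the class and the sample payloads are assumptions for illustration.

```python
# Standalone sketch of prepare_aporia_request's validation_target selection.
# Mirrors the method added above; the free-standing form is illustrative only.
import asyncio
from typing import Any, List, Optional


async def prepare_aporia_request(
    new_messages: List[dict], response_string: Optional[str] = None
) -> dict:
    data: dict[str, Any] = {}
    if new_messages is not None:
        data["messages"] = new_messages
    if response_string is not None:
        data["response"] = response_string

    # Prompt-only, response-only, or both, depending on what was supplied.
    if new_messages and response_string:
        data["validation_target"] = "both"
    elif new_messages:
        data["validation_target"] = "prompt"
    elif response_string:
        data["validation_target"] = "response"
    return data


# Pre-call path: messages only -> validation_target == "prompt"
print(asyncio.run(prepare_aporia_request([{"role": "user", "content": "hi"}])))
# Post-call path: empty messages plus a response string -> "response"
print(asyncio.run(prepare_aporia_request([], "Hi! How can I help?")))
```

Note that the post-call hook passes an empty list rather than None, so `data["messages"]` is still set, but the empty list is falsy and the target correctly resolves to "response".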
@@ -107,7 +115,7 @@ class _ENTERPRISE_Aporio(CustomLogger):
                 "Content-Type": "application/json",
             },
         )
-        verbose_proxy_logger.debug("Aporio AI response: %s", response.text)
+        verbose_proxy_logger.debug("Aporia AI response: %s", response.text)
         if response.status_code == 200:
             # check if the response was flagged
             _json_response = response.json()
@@ -122,3 +130,47 @@ class _ENTERPRISE_Aporio(CustomLogger):
                     "aporio_ai_response": _json_response,
                 },
             )
+
+    async def async_post_call_success_hook(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        response,
+    ):
+        """
+        Use this for the post call moderation with Guardrails
+        """
+        response_str: Optional[str] = convert_litellm_response_object_to_str(response)
+        if response_str is not None:
+            await self.make_aporia_api_request(
+                response_string=response_str, new_messages=[]
+            )
+
+        pass
+
+    async def async_moderation_hook(  ### 👈 KEY CHANGE ###
+        self,
+        data: dict,
+        user_api_key_dict: UserAPIKeyAuth,
+        call_type: Literal["completion", "embeddings", "image_generation"],
+    ):
+
+        if (
+            await should_proceed_based_on_metadata(
+                data=data,
+                guardrail_name=GUARDRAIL_NAME,
+            )
+            is False
+        ):
+            return
+
+        new_messages: Optional[List[dict]] = None
+        if "messages" in data and isinstance(data["messages"], list):
+            new_messages = self.transform_messages(messages=data["messages"])
+
+        if new_messages is not None:
+            await self.make_aporia_api_request(new_messages=new_messages)
+        else:
+            verbose_proxy_logger.warning(
+                "Aporia AI: not running guardrail. No messages in data"
+            )
+        pass
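The net effect of this file: `async_moderation_hook` still guards the prompt before the LLM call, while the new `async_post_call_success_hook` converts the finished response to a string and validates it with `new_messages=[]`. A stubbed sketch of that post-call control flow follows; only the flow is taken from the diff, and the stub bodies and print are placeholders, not the real network call.

```python
# Stubbed sketch of the new post-call path; stand-in bodies, real control flow.
import asyncio
from typing import List, Optional


def convert_litellm_response_object_to_str(response) -> Optional[str]:
    # Stand-in for the helper added in the next file: a string for text
    # completions, None for anything it cannot stringify.
    return response if isinstance(response, str) else None


async def make_aporia_api_request(
    new_messages: List[dict], response_string: Optional[str] = None
):
    # Stand-in for the real method, which POSTs the prepared payload to Aporia.
    print("validating:", {"messages": new_messages, "response": response_string})


async def async_post_call_success_hook(response):
    # Same shape as the hook added above: convert, then validate response-only.
    response_str = convert_litellm_response_object_to_str(response)
    if response_str is not None:
        await make_aporia_api_request(response_string=response_str, new_messages=[])


asyncio.run(async_post_call_success_hook("The model's answer"))
```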
File 2 of 3: litellm/litellm_core_utils/logging_utils.py

@@ -1,4 +1,12 @@
-from typing import Any
+from typing import TYPE_CHECKING, Any, Optional, Union
+
+if TYPE_CHECKING:
+    from litellm import ModelResponse as _ModelResponse
+
+    LiteLLMModelResponse = _ModelResponse
+else:
+    LiteLLMModelResponse = Any
+

 import litellm

@@ -20,3 +28,21 @@ def convert_litellm_response_object_to_dict(response_obj: Any) -> dict:

     # If it's not a LiteLLM type, return the object as is
     return dict(response_obj)
+
+
+def convert_litellm_response_object_to_str(
+    response_obj: Union[Any, LiteLLMModelResponse]
+) -> Optional[str]:
+    """
+    Get the string of the response object from LiteLLM
+    """
+    if isinstance(response_obj, litellm.ModelResponse):
+        response_str = ""
+        for choice in response_obj.choices:
+            if isinstance(choice, litellm.Choices):
+                if choice.message.content and isinstance(choice.message.content, str):
+                    response_str += choice.message.content
+        return response_str
+
+    return None
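A quick usage check for the new helper. This is a sketch, not a test from this commit: it assumes `litellm` is importable and that `litellm.ModelResponse` accepts OpenAI-style choice dicts, which the surrounding codebase relies on.

```python
# Hypothetical usage of the new helper; assumes ModelResponse accepts
# OpenAI-style choice dicts.
import litellm
from litellm.litellm_core_utils.logging_utils import (
    convert_litellm_response_object_to_str,
)

response = litellm.ModelResponse(
    choices=[{"message": {"role": "assistant", "content": "All good."}}]
)
print(convert_litellm_response_object_to_str(response))  # "All good."

# Anything that isn't a ModelResponse (e.g. an embedding result) yields None,
# so the post-call hook simply skips the Aporia request for those.
print(convert_litellm_response_object_to_str({"data": [0.1, 0.2]}))  # None
```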
File 3 of 3: the proxy config YAML

@@ -1,50 +1,11 @@
 model_list:
-  - model_name: gpt-4
+  - model_name: gpt-3.5-turbo
     litellm_params:
-      model: openai/fake
-      api_key: fake-key
-      api_base: https://exampleopenaiendpoint-production.up.railway.app/
-    model_info:
-      access_groups: ["beta-models"]
-  - model_name: fireworks-llama-v3-70b-instruct
-    litellm_params:
-      model: fireworks_ai/accounts/fireworks/models/llama-v3-70b-instruct
-      api_key: "os.environ/FIREWORKS"
-    model_info:
-      access_groups: ["beta-models"]
-  - model_name: "*"
-    litellm_params:
-      model: "*"
-  - model_name: "*"
-    litellm_params:
-      model: openai/*
+      model: openai/gpt-3.5-turbo
+      api_key: os.environ/OPENAI_API_KEY
-  - model_name: mistral-small-latest
-    litellm_params:
-      model: mistral/mistral-small-latest
-      api_key: "os.environ/MISTRAL_API_KEY"
-  - model_name: bedrock-anthropic
-    litellm_params:
-      model: bedrock/anthropic.claude-3-sonnet-20240229-v1:0
-  - model_name: gemini-1.5-pro-001
-    litellm_params:
-      model: vertex_ai_beta/gemini-1.5-pro-001
-      vertex_project: "adroit-crow-413218"
-      vertex_location: "us-central1"
-      vertex_credentials: "adroit-crow-413218-a956eef1a2a8.json"  # Add path to service account.json
-
-default_vertex_config:
-  vertex_project: "adroit-crow-413218"
-  vertex_location: "us-central1"
-  vertex_credentials: "adroit-crow-413218-a956eef1a2a8.json"  # Add path to service account.json
-

 general_settings:
   master_key: sk-1234
-  alerting: ["slack"]

 litellm_settings:
-  fallbacks: [{"gemini-1.5-pro-001": ["gpt-4o"]}]
-  success_callback: ["langfuse", "prometheus"]
-  langfuse_default_tags: ["cache_hit", "cache_key", "user_api_key_alias", "user_api_key_team_alias"]
+  guardrails:
+    - prompt_injection:
+        callbacks: [aporio_prompt_injection]
+        default_on: true
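With `default_on: true`, the guardrail runs for every request through the proxy, so both the pre-call prompt check and the new post-call response check fire on an ordinary completion. A hypothetical client call is sketched below; the base URL and port are litellm's usual local defaults, not part of this diff.

```python
# Hypothetical request against a locally running proxy started with the
# config above; any OpenAI-compatible client works the same way.
import openai

client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")

response = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "What's the weather like today?"}],
)
# If Aporia flags the prompt or the response, the proxy raises HTTP 400
# carrying the aporio_ai_response detail from the hook above, instead of
# returning a completion.
print(response.choices[0].message.content)
```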