diff --git a/enterprise/enterprise_hooks/aporio_ai.py b/enterprise/enterprise_hooks/aporio_ai.py index ce8de6eca..6529ddcba 100644 --- a/enterprise/enterprise_hooks/aporio_ai.py +++ b/enterprise/enterprise_hooks/aporio_ai.py @@ -5,12 +5,13 @@ # +-------------------------------------------------------------+ # Thank you users! We ❤️ you! - Krrish & Ishaan -import sys, os +import sys +import os sys.path.insert( 0, os.path.abspath("../..") ) # Adds the parent directory to the system path -from typing import Optional, Literal, Union +from typing import Optional, Literal, Union, Any import litellm, traceback, sys, uuid from litellm.caching import DualCache from litellm.proxy._types import UserAPIKeyAuth @@ -18,6 +19,9 @@ from litellm.integrations.custom_logger import CustomLogger from fastapi import HTTPException from litellm._logging import verbose_proxy_logger from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata +from litellm.litellm_core_utils.logging_utils import ( + convert_litellm_response_object_to_str, +) from typing import List from datetime import datetime import aiohttp, asyncio @@ -57,6 +61,92 @@ class _ENTERPRISE_Aporio(CustomLogger): return new_messages + async def prepare_aporia_request( + self, new_messages: List[dict], response_string: Optional[str] = None + ) -> dict: + data: dict[str, Any] = {} + if new_messages is not None: + data["messages"] = new_messages + if response_string is not None: + data["response"] = response_string + + # Set validation target + if new_messages and response_string: + data["validation_target"] = "both" + elif new_messages: + data["validation_target"] = "prompt" + elif response_string: + data["validation_target"] = "response" + + verbose_proxy_logger.debug("Aporia AI request: %s", data) + return data + + async def make_aporia_api_request( + self, new_messages: List[dict], response_string: Optional[str] = None + ): + data = await self.prepare_aporia_request( + new_messages=new_messages, response_string=response_string + ) + + _json_data = json.dumps(data) + + """ + export APORIO_API_KEY= + curl https://gr-prd-trial.aporia.com/some-id \ + -X POST \ + -H "X-APORIA-API-KEY: $APORIO_API_KEY" \ + -H "Content-Type: application/json" \ + -d '{ + "messages": [ + { + "role": "user", + "content": "This is a test prompt" + } + ], + } +' + """ + + response = await self.async_handler.post( + url=self.aporio_api_base + "/validate", + data=_json_data, + headers={ + "X-APORIA-API-KEY": self.aporio_api_key, + "Content-Type": "application/json", + }, + ) + verbose_proxy_logger.debug("Aporia AI response: %s", response.text) + if response.status_code == 200: + # check if the response was flagged + _json_response = response.json() + action: str = _json_response.get( + "action" + ) # possible values are modify, passthrough, block, rephrase + if action == "block": + raise HTTPException( + status_code=400, + detail={ + "error": "Violated guardrail policy", + "aporio_ai_response": _json_response, + }, + ) + + async def async_post_call_success_hook( + self, + user_api_key_dict: UserAPIKeyAuth, + response, + ): + """ + Use this for the post call moderation with Guardrails + """ + response_str: Optional[str] = convert_litellm_response_object_to_str(response) + if response_str is not None: + await self.make_aporia_api_request( + response_string=response_str, new_messages=[] + ) + + pass + async def async_moderation_hook( ### 👈 KEY CHANGE ### self, data: dict, @@ -78,47 +168,9 @@ class _ENTERPRISE_Aporio(CustomLogger): new_messages = self.transform_messages(messages=data["messages"]) if new_messages is not None: - data = {"messages": new_messages, "validation_target": "prompt"} - - _json_data = json.dumps(data) - - """ - export APORIO_API_KEY= - curl https://gr-prd-trial.aporia.com/some-id \ - -X POST \ - -H "X-APORIA-API-KEY: $APORIO_API_KEY" \ - -H "Content-Type: application/json" \ - -d '{ - "messages": [ - { - "role": "user", - "content": "This is a test prompt" - } - ], - } -' - """ - - response = await self.async_handler.post( - url=self.aporio_api_base + "/validate", - data=_json_data, - headers={ - "X-APORIA-API-KEY": self.aporio_api_key, - "Content-Type": "application/json", - }, + await self.make_aporia_api_request(new_messages=new_messages) + else: + verbose_proxy_logger.warning( + "Aporia AI: not running guardrail. No messages in data" ) - verbose_proxy_logger.debug("Aporio AI response: %s", response.text) - if response.status_code == 200: - # check if the response was flagged - _json_response = response.json() - action: str = _json_response.get( - "action" - ) # possible values are modify, passthrough, block, rephrase - if action == "block": - raise HTTPException( - status_code=400, - detail={ - "error": "Violated guardrail policy", - "aporio_ai_response": _json_response, - }, - ) + pass diff --git a/litellm/litellm_core_utils/logging_utils.py b/litellm/litellm_core_utils/logging_utils.py index fdc9672a0..7fa1be9d8 100644 --- a/litellm/litellm_core_utils/logging_utils.py +++ b/litellm/litellm_core_utils/logging_utils.py @@ -1,4 +1,12 @@ -from typing import Any +from typing import TYPE_CHECKING, Any, Optional, Union + +if TYPE_CHECKING: + from litellm import ModelResponse as _ModelResponse + + LiteLLMModelResponse = _ModelResponse +else: + LiteLLMModelResponse = Any + import litellm @@ -20,3 +28,21 @@ def convert_litellm_response_object_to_dict(response_obj: Any) -> dict: # If it's not a LiteLLM type, return the object as is return dict(response_obj) + + +def convert_litellm_response_object_to_str( + response_obj: Union[Any, LiteLLMModelResponse] +) -> Optional[str]: + """ + Get the string of the response object from LiteLLM + + """ + if isinstance(response_obj, litellm.ModelResponse): + response_str = "" + for choice in response_obj.choices: + if isinstance(choice, litellm.Choices): + if choice.message.content and isinstance(choice.message.content, str): + response_str += choice.message.content + return response_str + + return None diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml index e08be88aa..902dab7ad 100644 --- a/litellm/proxy/proxy_config.yaml +++ b/litellm/proxy/proxy_config.yaml @@ -1,50 +1,11 @@ model_list: - - model_name: gpt-4 + - model_name: gpt-3.5-turbo litellm_params: - model: openai/fake - api_key: fake-key - api_base: https://exampleopenaiendpoint-production.up.railway.app/ - model_info: - access_groups: ["beta-models"] - - model_name: fireworks-llama-v3-70b-instruct - litellm_params: - model: fireworks_ai/accounts/fireworks/models/llama-v3-70b-instruct - api_key: "os.environ/FIREWORKS" - model_info: - access_groups: ["beta-models"] - - model_name: "*" - litellm_params: - model: "*" - - model_name: "*" - litellm_params: - model: openai/* + model: openai/gpt-3.5-turbo api_key: os.environ/OPENAI_API_KEY - - model_name: mistral-small-latest - litellm_params: - model: mistral/mistral-small-latest - api_key: "os.environ/MISTRAL_API_KEY" - - model_name: bedrock-anthropic - litellm_params: - model: bedrock/anthropic.claude-3-sonnet-20240229-v1:0 - - model_name: gemini-1.5-pro-001 - litellm_params: - model: vertex_ai_beta/gemini-1.5-pro-001 - vertex_project: "adroit-crow-413218" - vertex_location: "us-central1" - vertex_credentials: "adroit-crow-413218-a956eef1a2a8.json" - # Add path to service account.json - -default_vertex_config: - vertex_project: "adroit-crow-413218" - vertex_location: "us-central1" - vertex_credentials: "adroit-crow-413218-a956eef1a2a8.json" # Add path to service account.json - - -general_settings: - master_key: sk-1234 - alerting: ["slack"] litellm_settings: - fallbacks: [{"gemini-1.5-pro-001": ["gpt-4o"]}] - success_callback: ["langfuse", "prometheus"] - langfuse_default_tags: ["cache_hit", "cache_key", "user_api_key_alias", "user_api_key_team_alias"] + guardrails: + - prompt_injection: + callbacks: [aporio_prompt_injection] + default_on: true