diff --git a/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py b/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py index 6c7ea4d90..d11f58a3e 100644 --- a/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py +++ b/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py @@ -67,10 +67,12 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM): def convert_to_bedrock_format( self, messages: Optional[List[Dict[str, str]]] = None, + response: Optional[Union[Any, litellm.ModelResponse]] = None, ) -> BedrockRequest: bedrock_request: BedrockRequest = BedrockRequest(source="INPUT") + bedrock_request_content: List[BedrockContentItem] = [] + if messages: - bedrock_request_content: List[BedrockContentItem] = [] for message in messages: content = message.get("content") if isinstance(content, str): @@ -80,7 +82,19 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM): bedrock_request_content.append(bedrock_content_item) bedrock_request["content"] = bedrock_request_content - + if response: + bedrock_request["source"] = "OUTPUT" + if isinstance(response, litellm.ModelResponse): + for choice in response.choices: + if isinstance(choice, litellm.Choices): + if choice.message.content and isinstance( + choice.message.content, str + ): + bedrock_content_item = BedrockContentItem( + text=BedrockTextContent(text=choice.message.content) + ) + bedrock_request_content.append(bedrock_content_item) + bedrock_request["content"] = bedrock_request_content return bedrock_request #### CALL HOOKS - proxy only #### @@ -172,11 +186,13 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM): return prepped_request - async def make_bedrock_api_request(self, kwargs: dict): + async def make_bedrock_api_request( + self, kwargs: dict, response: Optional[Union[Any, litellm.ModelResponse]] = None + ): credentials, aws_region_name = self._load_credentials() request_data: BedrockRequest = self.convert_to_bedrock_format( - messages=kwargs.get("messages") + messages=kwargs.get("messages"), response=response ) prepared_request = self._prepare_request( credentials=credentials, @@ -242,32 +258,32 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM): ) pass - # async def async_post_call_success_hook( - # self, - # data: dict, - # user_api_key_dict: UserAPIKeyAuth, - # response, - # ): - # from litellm.proxy.common_utils.callback_utils import ( - # add_guardrail_to_applied_guardrails_header, - # ) - # from litellm.types.guardrails import GuardrailEventHooks + async def async_post_call_success_hook( + self, + data: dict, + user_api_key_dict: UserAPIKeyAuth, + response, + ): + from litellm.proxy.common_utils.callback_utils import ( + add_guardrail_to_applied_guardrails_header, + ) + from litellm.types.guardrails import GuardrailEventHooks - # """ - # Use this for the post call moderation with Guardrails - # """ - # event_type: GuardrailEventHooks = GuardrailEventHooks.post_call - # if self.should_run_guardrail(data=data, event_type=event_type) is not True: - # return + if ( + self.should_run_guardrail( + data=data, event_type=GuardrailEventHooks.post_call + ) + is not True + ): + return - # response_str: Optional[str] = convert_litellm_response_object_to_str(response) - # if response_str is not None: - # await self.make_bedrock_api_request( - # response_string=response_str, new_messages=data.get("messages", []) - # ) - - # add_guardrail_to_applied_guardrails_header( - # request_data=data, guardrail_name=self.guardrail_name - # ) - - # pass + new_messages: Optional[List[dict]] = data.get("messages") + if new_messages is not None: + await self.make_bedrock_api_request(kwargs=data, response=response) + add_guardrail_to_applied_guardrails_header( + request_data=data, guardrail_name=self.guardrail_name + ) + else: + verbose_proxy_logger.warning( + "Bedrock AI: not running guardrail. No messages in data" + ) diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml index d8e88cec7..d0ed9a699 100644 --- a/litellm/proxy/proxy_config.yaml +++ b/litellm/proxy/proxy_config.yaml @@ -1,14 +1,13 @@ model_list: - model_name: gpt-4 litellm_params: - model: openai/fake - api_key: fake-key - api_base: https://exampleopenaiendpoint-production.up.railway.app/ + model: openai/gpt-4 + api_key: os.environ/OPENAI_API_KEY guardrails: - guardrail_name: "bedrock-pre-guard" litellm_params: guardrail: bedrock # supported values: "aporia", "bedrock", "lakera" - mode: "during_call" + mode: "post_call" guardrailIdentifier: ff6ujrregl1q guardrailVersion: "DRAFT" \ No newline at end of file