add async_post_call_success_hook

Ishaan Jaff 2024-08-22 16:34:43 -07:00
parent 499b6b3368
commit 9e3d573bcb
2 changed files with 50 additions and 35 deletions

@@ -67,10 +67,12 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
     def convert_to_bedrock_format(
         self,
         messages: Optional[List[Dict[str, str]]] = None,
+        response: Optional[Union[Any, litellm.ModelResponse]] = None,
     ) -> BedrockRequest:
         bedrock_request: BedrockRequest = BedrockRequest(source="INPUT")
-        if messages:
-            bedrock_request_content: List[BedrockContentItem] = []
+        bedrock_request_content: List[BedrockContentItem] = []
+
+        if messages:
             for message in messages:
                 content = message.get("content")
                 if isinstance(content, str):
@@ -80,7 +82,19 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
                     bedrock_request_content.append(bedrock_content_item)
             bedrock_request["content"] = bedrock_request_content
+        if response:
+            bedrock_request["source"] = "OUTPUT"
+            if isinstance(response, litellm.ModelResponse):
+                for choice in response.choices:
+                    if isinstance(choice, litellm.Choices):
+                        if choice.message.content and isinstance(
+                            choice.message.content, str
+                        ):
+                            bedrock_content_item = BedrockContentItem(
+                                text=BedrockTextContent(text=choice.message.content)
+                            )
+                            bedrock_request_content.append(bedrock_content_item)
+                bedrock_request["content"] = bedrock_request_content
         return bedrock_request

     #### CALL HOOKS - proxy only ####
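
For orientation (a sketch, not part of the commit): convert_to_bedrock_format builds the body for the Bedrock ApplyGuardrail call, and the change above teaches it to mark response text as OUTPUT. Shown below as plain dicts standing in for the BedrockRequest / BedrockContentItem / BedrockTextContent TypedDicts, the two shapes it now produces look roughly like this:

# Illustrative only: approximate request bodies built by convert_to_bedrock_format.
# Prompt messages keep the default source="INPUT"; a ModelResponse flips the
# source to "OUTPUT" and copies each choice.message.content into the same shape.
input_request = {
    "source": "INPUT",
    "content": [{"text": {"text": "user prompt text goes here"}}],
}

output_request = {
    "source": "OUTPUT",
    "content": [{"text": {"text": "assistant reply text from choice.message.content"}}],
}
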
@@ -172,11 +186,13 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
         return prepped_request

-    async def make_bedrock_api_request(self, kwargs: dict):
+    async def make_bedrock_api_request(
+        self, kwargs: dict, response: Optional[Union[Any, litellm.ModelResponse]] = None
+    ):
         credentials, aws_region_name = self._load_credentials()
         request_data: BedrockRequest = self.convert_to_bedrock_format(
-            messages=kwargs.get("messages")
+            messages=kwargs.get("messages"), response=response
         )
         prepared_request = self._prepare_request(
             credentials=credentials,
@@ -242,32 +258,32 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
             )
         pass

-    # async def async_post_call_success_hook(
-    #     self,
-    #     data: dict,
-    #     user_api_key_dict: UserAPIKeyAuth,
-    #     response,
-    # ):
-    #     from litellm.proxy.common_utils.callback_utils import (
-    #         add_guardrail_to_applied_guardrails_header,
-    #     )
-    #     from litellm.types.guardrails import GuardrailEventHooks
-
-    #     """
-    #     Use this for the post call moderation with Guardrails
-    #     """
-    #     event_type: GuardrailEventHooks = GuardrailEventHooks.post_call
-    #     if self.should_run_guardrail(data=data, event_type=event_type) is not True:
-    #         return
-
-    #     response_str: Optional[str] = convert_litellm_response_object_to_str(response)
-    #     if response_str is not None:
-    #         await self.make_bedrock_api_request(
-    #             response_string=response_str, new_messages=data.get("messages", [])
-    #         )
-
-    #         add_guardrail_to_applied_guardrails_header(
-    #             request_data=data, guardrail_name=self.guardrail_name
-    #         )
-
-    #     pass
+    async def async_post_call_success_hook(
+        self,
+        data: dict,
+        user_api_key_dict: UserAPIKeyAuth,
+        response,
+    ):
+        from litellm.proxy.common_utils.callback_utils import (
+            add_guardrail_to_applied_guardrails_header,
+        )
+        from litellm.types.guardrails import GuardrailEventHooks
+
+        if (
+            self.should_run_guardrail(
+                data=data, event_type=GuardrailEventHooks.post_call
+            )
+            is not True
+        ):
+            return
+
+        new_messages: Optional[List[dict]] = data.get("messages")
+        if new_messages is not None:
+            await self.make_bedrock_api_request(kwargs=data, response=response)
+            add_guardrail_to_applied_guardrails_header(
+                request_data=data, guardrail_name=self.guardrail_name
+            )
+        else:
+            verbose_proxy_logger.warning(
+                "Bedrock AI: not running guardrail. No messages in data"
+            )
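
A rough usage sketch, not from the commit: how a caller might drive the newly enabled hook after a successful completion. The run_post_call_guardrail helper and its arguments are assumptions for illustration; only async_post_call_success_hook, UserAPIKeyAuth, and litellm.ModelResponse appear in the code above.

# Hypothetical driver, for illustration only.
import litellm
from litellm.proxy._types import UserAPIKeyAuth


async def run_post_call_guardrail(guardrail, data: dict, response: litellm.ModelResponse):
    # data is the original request body; it must contain "messages" or the hook
    # logs a warning and skips the guardrail. response is the finished
    # ModelResponse whose text is screened as OUTPUT by Bedrock ApplyGuardrail.
    await guardrail.async_post_call_success_hook(
        data=data,
        user_api_key_dict=UserAPIKeyAuth(),
        response=response,
    )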

@@ -1,14 +1,13 @@
 model_list:
   - model_name: gpt-4
     litellm_params:
-      model: openai/fake
-      api_key: fake-key
-      api_base: https://exampleopenaiendpoint-production.up.railway.app/
+      model: openai/gpt-4
+      api_key: os.environ/OPENAI_API_KEY

 guardrails:
   - guardrail_name: "bedrock-pre-guard"
     litellm_params:
       guardrail: bedrock # supported values: "aporia", "bedrock", "lakera"
-      mode: "during_call"
+      mode: "post_call"
       guardrailIdentifier: ff6ujrregl1q
       guardrailVersion: "DRAFT"
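
A rough client-side sketch, not part of the commit: with mode switched to "post_call", the bedrock-pre-guard guardrail screens the model's reply rather than running during the call. This assumes a LiteLLM proxy running locally on port 4000 with a placeholder key, and that the request opts into the guardrail by name; whether that opt-in is required depends on the rest of the guardrail configuration.

# Illustrative request through the proxy; the key and base_url are placeholders.
import openai

client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")

response = client.chat.completions.create(
    model="gpt-4",
    messages=[{"role": "user", "content": "hi"}],
    # Assumed request-level opt-in to the guardrail configured above.
    extra_body={"guardrails": ["bedrock-pre-guard"]},
)
print(response.choices[0].message.content)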