add async_post_call_success_hook

This commit is contained in:
Ishaan Jaff 2024-08-22 16:34:43 -07:00
parent 499b6b3368
commit 9e3d573bcb
2 changed files with 50 additions and 35 deletions

View file

@ -67,10 +67,12 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
def convert_to_bedrock_format(
self,
messages: Optional[List[Dict[str, str]]] = None,
response: Optional[Union[Any, litellm.ModelResponse]] = None,
) -> BedrockRequest:
bedrock_request: BedrockRequest = BedrockRequest(source="INPUT")
bedrock_request_content: List[BedrockContentItem] = []
if messages:
bedrock_request_content: List[BedrockContentItem] = []
for message in messages:
content = message.get("content")
if isinstance(content, str):
@ -80,7 +82,19 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
bedrock_request_content.append(bedrock_content_item)
bedrock_request["content"] = bedrock_request_content
if response:
bedrock_request["source"] = "OUTPUT"
if isinstance(response, litellm.ModelResponse):
for choice in response.choices:
if isinstance(choice, litellm.Choices):
if choice.message.content and isinstance(
choice.message.content, str
):
bedrock_content_item = BedrockContentItem(
text=BedrockTextContent(text=choice.message.content)
)
bedrock_request_content.append(bedrock_content_item)
bedrock_request["content"] = bedrock_request_content
return bedrock_request
#### CALL HOOKS - proxy only ####
@ -172,11 +186,13 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
return prepped_request
async def make_bedrock_api_request(self, kwargs: dict):
async def make_bedrock_api_request(
self, kwargs: dict, response: Optional[Union[Any, litellm.ModelResponse]] = None
):
credentials, aws_region_name = self._load_credentials()
request_data: BedrockRequest = self.convert_to_bedrock_format(
messages=kwargs.get("messages")
messages=kwargs.get("messages"), response=response
)
prepared_request = self._prepare_request(
credentials=credentials,
@ -242,32 +258,32 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
)
pass
# async def async_post_call_success_hook(
# self,
# data: dict,
# user_api_key_dict: UserAPIKeyAuth,
# response,
# ):
# from litellm.proxy.common_utils.callback_utils import (
# add_guardrail_to_applied_guardrails_header,
# )
# from litellm.types.guardrails import GuardrailEventHooks
async def async_post_call_success_hook(
self,
data: dict,
user_api_key_dict: UserAPIKeyAuth,
response,
):
from litellm.proxy.common_utils.callback_utils import (
add_guardrail_to_applied_guardrails_header,
)
from litellm.types.guardrails import GuardrailEventHooks
# """
# Use this for the post call moderation with Guardrails
# """
# event_type: GuardrailEventHooks = GuardrailEventHooks.post_call
# if self.should_run_guardrail(data=data, event_type=event_type) is not True:
# return
if (
self.should_run_guardrail(
data=data, event_type=GuardrailEventHooks.post_call
)
is not True
):
return
# response_str: Optional[str] = convert_litellm_response_object_to_str(response)
# if response_str is not None:
# await self.make_bedrock_api_request(
# response_string=response_str, new_messages=data.get("messages", [])
# )
# add_guardrail_to_applied_guardrails_header(
# request_data=data, guardrail_name=self.guardrail_name
# )
# pass
new_messages: Optional[List[dict]] = data.get("messages")
if new_messages is not None:
await self.make_bedrock_api_request(kwargs=data, response=response)
add_guardrail_to_applied_guardrails_header(
request_data=data, guardrail_name=self.guardrail_name
)
else:
verbose_proxy_logger.warning(
"Bedrock AI: not running guardrail. No messages in data"
)

View file

@ -1,14 +1,13 @@
model_list:
- model_name: gpt-4
litellm_params:
model: openai/fake
api_key: fake-key
api_base: https://exampleopenaiendpoint-production.up.railway.app/
model: openai/gpt-4
api_key: os.environ/OPENAI_API_KEY
guardrails:
- guardrail_name: "bedrock-pre-guard"
litellm_params:
guardrail: bedrock # supported values: "aporia", "bedrock", "lakera"
mode: "during_call"
mode: "post_call"
guardrailIdentifier: ff6ujrregl1q
guardrailVersion: "DRAFT"