forked from phoenix/litellm-mirror

add async_post_call_success_hook

parent 499b6b3368
commit 9e3d573bcb

2 changed files with 50 additions and 35 deletions
@@ -67,10 +67,12 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
     def convert_to_bedrock_format(
         self,
         messages: Optional[List[Dict[str, str]]] = None,
+        response: Optional[Union[Any, litellm.ModelResponse]] = None,
     ) -> BedrockRequest:
         bedrock_request: BedrockRequest = BedrockRequest(source="INPUT")
+        bedrock_request_content: List[BedrockContentItem] = []
+
         if messages:
-            bedrock_request_content: List[BedrockContentItem] = []
             for message in messages:
                 content = message.get("content")
                 if isinstance(content, str):
@@ -80,7 +82,19 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
                     bedrock_request_content.append(bedrock_content_item)
 
             bedrock_request["content"] = bedrock_request_content
-
+        if response:
+            bedrock_request["source"] = "OUTPUT"
+            if isinstance(response, litellm.ModelResponse):
+                for choice in response.choices:
+                    if isinstance(choice, litellm.Choices):
+                        if choice.message.content and isinstance(
+                            choice.message.content, str
+                        ):
+                            bedrock_content_item = BedrockContentItem(
+                                text=BedrockTextContent(text=choice.message.content)
+                            )
+                            bedrock_request_content.append(bedrock_content_item)
+            bedrock_request["content"] = bedrock_request_content
         return bedrock_request
 
     #### CALL HOOKS - proxy only ####
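Taken together, the two hunks above hoist bedrock_request_content out of the if messages: branch and, whenever a response is passed, append the model's reply to it and flip source to "OUTPUT". For reference, a minimal sketch of the payload shape the converter should now emit for a post-call check; the literal strings are illustrative, and the dict layout is inferred from the BedrockTextContent(text=...) nesting visible in the diff:

# Illustrative only: expected shape of the BedrockRequest built when both
# request messages and a model response are supplied.
example_request = {
    "source": "OUTPUT",  # stays "INPUT" when only messages are scanned
    "content": [
        {"text": {"text": "What is the weather today?"}},  # user message
        {"text": {"text": "Sunny, with a high of 75."}},   # assistant reply
    ],
}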
@@ -172,11 +186,13 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
 
         return prepped_request
 
-    async def make_bedrock_api_request(self, kwargs: dict):
+    async def make_bedrock_api_request(
+        self, kwargs: dict, response: Optional[Union[Any, litellm.ModelResponse]] = None
+    ):
 
         credentials, aws_region_name = self._load_credentials()
         request_data: BedrockRequest = self.convert_to_bedrock_format(
-            messages=kwargs.get("messages")
+            messages=kwargs.get("messages"), response=response
         )
         prepared_request = self._prepare_request(
             credentials=credentials,
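Because response defaults to None, existing pre-call callers of make_bedrock_api_request are unaffected. A hedged sketch of how both paths might be invoked from an async context; guardrail is assumed to be a configured BedrockGuardrail instance and model_response a litellm.ModelResponse:

# Hypothetical usage inside an async function.
data = {"messages": [{"role": "user", "content": "hello"}]}

# Pre-call behavior, unchanged: only the request messages are scanned,
# and the guardrail request goes out with source="INPUT".
await guardrail.make_bedrock_api_request(kwargs=data)

# Post-call behavior, new: the model's reply is scanned too, with
# source="OUTPUT".
await guardrail.make_bedrock_api_request(kwargs=data, response=model_response)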
@@ -242,32 +258,32 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
             )
         pass
 
-    # async def async_post_call_success_hook(
-    #     self,
-    #     data: dict,
-    #     user_api_key_dict: UserAPIKeyAuth,
-    #     response,
-    # ):
-    #     from litellm.proxy.common_utils.callback_utils import (
-    #         add_guardrail_to_applied_guardrails_header,
-    #     )
-    #     from litellm.types.guardrails import GuardrailEventHooks
+    async def async_post_call_success_hook(
+        self,
+        data: dict,
+        user_api_key_dict: UserAPIKeyAuth,
+        response,
+    ):
+        from litellm.proxy.common_utils.callback_utils import (
+            add_guardrail_to_applied_guardrails_header,
+        )
+        from litellm.types.guardrails import GuardrailEventHooks
 
-    #     """
-    #     Use this for the post call moderation with Guardrails
-    #     """
-    #     event_type: GuardrailEventHooks = GuardrailEventHooks.post_call
-    #     if self.should_run_guardrail(data=data, event_type=event_type) is not True:
-    #         return
+        if (
+            self.should_run_guardrail(
+                data=data, event_type=GuardrailEventHooks.post_call
+            )
+            is not True
+        ):
+            return
 
-    #     response_str: Optional[str] = convert_litellm_response_object_to_str(response)
-    #     if response_str is not None:
-    #         await self.make_bedrock_api_request(
-    #             response_string=response_str, new_messages=data.get("messages", [])
-    #         )
-
-    #         add_guardrail_to_applied_guardrails_header(
-    #             request_data=data, guardrail_name=self.guardrail_name
-    #         )
-
-    #     pass
+        new_messages: Optional[List[dict]] = data.get("messages")
+        if new_messages is not None:
+            await self.make_bedrock_api_request(kwargs=data, response=response)
+            add_guardrail_to_applied_guardrails_header(
+                request_data=data, guardrail_name=self.guardrail_name
+            )
+        else:
+            verbose_proxy_logger.warning(
+                "Bedrock AI: not running guardrail. No messages in data"
+            )
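A minimal, hypothetical test sketch for the new hook, mocking out the network call so nothing hits AWS. The import path, the BedrockGuardrail constructor arguments, and the UserAPIKeyAuth stub are assumptions for illustration, not the repository's own test code:

import asyncio
from unittest.mock import AsyncMock, patch

import litellm
from litellm.proxy._types import UserAPIKeyAuth
# Import path assumed from the class name shown in the hunk headers.
from litellm.proxy.guardrails.guardrail_hooks.bedrock_guardrails import BedrockGuardrail

async def main():
    # Assumed constructor arguments; adjust to the actual BedrockGuardrail API.
    guardrail = BedrockGuardrail(
        guardrail_name="bedrock-pre-guard",
        event_hook="post_call",
        guardrailIdentifier="ff6ujrregl1q",
        guardrailVersion="DRAFT",
    )
    data = {"messages": [{"role": "user", "content": "hi"}], "metadata": {}}
    model_response = litellm.ModelResponse()
    # Force the event-hook check to pass and stub out the Bedrock call.
    with patch.object(guardrail, "should_run_guardrail", return_value=True), \
         patch.object(guardrail, "make_bedrock_api_request", new=AsyncMock()) as mocked:
        await guardrail.async_post_call_success_hook(
            data=data,
            user_api_key_dict=UserAPIKeyAuth(),
            response=model_response,
        )
        # With messages present, the Bedrock request should be awaited once
        # with the response attached.
        mocked.assert_awaited_once_with(kwargs=data, response=model_response)

asyncio.run(main())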
@@ -1,14 +1,13 @@
 model_list:
   - model_name: gpt-4
     litellm_params:
-      model: openai/fake
-      api_key: fake-key
-      api_base: https://exampleopenaiendpoint-production.up.railway.app/
+      model: openai/gpt-4
+      api_key: os.environ/OPENAI_API_KEY
 
 guardrails:
   - guardrail_name: "bedrock-pre-guard"
     litellm_params:
       guardrail: bedrock  # supported values: "aporia", "bedrock", "lakera"
-      mode: "during_call"
+      mode: "post_call"
       guardrailIdentifier: ff6ujrregl1q
       guardrailVersion: "DRAFT"
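With the guardrail's mode switched from "during_call" to "post_call", the check now runs against the finished completion instead of in parallel with it. A hedged sketch of a client call through a proxy started with this config; the base URL, API key, and per-request guardrails field follow common litellm proxy conventions and are illustrative:

import openai

# Hypothetical: point the OpenAI SDK at a locally running litellm proxy.
client = openai.OpenAI(api_key="sk-1234", base_url="http://localhost:4000")
response = client.chat.completions.create(
    model="gpt-4",
    messages=[{"role": "user", "content": "hi"}],
    # Assumed convention: select guardrails per request via extra_body.
    extra_body={"guardrails": ["bedrock-pre-guard"]},
)
print(response.choices[0].message.content)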