diff --git a/enterprise/enterprise_hooks/llm_guard.py b/enterprise/enterprise_hooks/llm_guard.py
index d8ea52be5..08ef3e388 100644
--- a/enterprise/enterprise_hooks/llm_guard.py
+++ b/enterprise/enterprise_hooks/llm_guard.py
@@ -103,7 +103,24 @@ class _ENTERPRISE_LLMGuard(CustomLogger):
         - Use the sanitized prompt returned
         - LLM Guard can handle things like PII Masking, etc.
         """
-        return data
+        self.print_verbose(f"Inside LLM Guard Pre-Call Hook")
+        try:
+            assert call_type in [
+                "completion",
+                "embeddings",
+                "image_generation",
+                "moderation",
+                "audio_transcription",
+            ]
+        except Exception as e:
+            self.print_verbose(
+                f"Call Type - {call_type}, not in accepted list - ['completion','embeddings','image_generation','moderation','audio_transcription']"
+            )
+            return data
+
+        formatted_prompt = get_formatted_prompt(data=data, call_type=call_type)  # type: ignore
+        self.print_verbose(f"LLM Guard, formatted_prompt: {formatted_prompt}")
+        return self.moderation_check(text=formatted_prompt)

     async def async_post_call_streaming_hook(
         self, user_api_key_dict: UserAPIKeyAuth, response: str
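
The hook added above follows a simple flow: reject-nothing pass-through for unsupported call types, otherwise flatten the request into a single prompt string and hand it to LLM Guard's moderation check. Below is a minimal standalone sketch of that control flow. `format_prompt` is a hypothetical stand-in for litellm's `get_formatted_prompt`, and the `moderation_check` callable stands in for the class method of the same name; this is an illustration of the pattern, not the PR's actual code.

```python
from typing import Callable

# Call types the pre-call hook accepts, mirroring the list in the diff.
ACCEPTED_CALL_TYPES = [
    "completion",
    "embeddings",
    "image_generation",
    "moderation",
    "audio_transcription",
]


def format_prompt(data: dict, call_type: str) -> str:
    # Hypothetical stand-in: flatten chat messages (or a raw prompt/input)
    # into one string for scanning.
    if call_type == "completion":
        return " ".join(m.get("content", "") for m in data.get("messages", []))
    return str(data.get("prompt", data.get("input", "")))


def pre_call_hook(
    data: dict, call_type: str, moderation_check: Callable[[str], dict]
) -> dict:
    # Unsupported call types short-circuit and return the payload unchanged.
    if call_type not in ACCEPTED_CALL_TYPES:
        return data
    # Otherwise the prompt is formatted and routed through LLM Guard moderation,
    # and the (possibly sanitized) result is returned to the proxy.
    return moderation_check(format_prompt(data, call_type))
```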