From 2a4a6995ac41eeb0584e025365f5adf8d6d4d942 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Fri, 16 Feb 2024 18:45:25 -0800
Subject: [PATCH] feat(llama_guard.py): add llama guard support for content moderation + new `async_moderation_hook` endpoint

---
 enterprise/hooks/llama_guard.py               | 71 +++++++++++++++++++
 .../proxy/enterprise => enterprise}/utils.py  |  0
 litellm/__init__.py                           | 71 +++++++++++--------
 litellm/integrations/custom_logger.py         |  3 +
 litellm/llms/prompt_templates/factory.py      |  5 ++
 litellm/proxy/enterprise                      |  1 +
 litellm/proxy/enterprise/LICENSE.md           | 37 ----------
 litellm/proxy/enterprise/README.md            | 12 ----
 .../callbacks/example_logging_api.py          | 31 --------
 litellm/proxy/proxy_server.py                 | 33 +++++++--
 litellm/proxy/utils.py                        | 21 +++---
 litellm/utils.py                              | 10 ++-
 12 files changed, 163 insertions(+), 132 deletions(-)
 create mode 100644 enterprise/hooks/llama_guard.py
 rename {litellm/proxy/enterprise => enterprise}/utils.py (100%)
 create mode 120000 litellm/proxy/enterprise
 delete mode 100644 litellm/proxy/enterprise/LICENSE.md
 delete mode 100644 litellm/proxy/enterprise/README.md
 delete mode 100644 litellm/proxy/enterprise/callbacks/example_logging_api.py

diff --git a/enterprise/hooks/llama_guard.py b/enterprise/hooks/llama_guard.py
new file mode 100644
index 000000000..c4f45909e
--- /dev/null
+++ b/enterprise/hooks/llama_guard.py
@@ -0,0 +1,71 @@
+# +-------------------------------------------------------------+
+#
+# Llama Guard
+# https://huggingface.co/meta-llama/LlamaGuard-7b/tree/main
+#
+# LLM for Content Moderation
+# +-------------------------------------------------------------+
+# Thank you users! We ❤️ you! - Krrish & Ishaan
+
+import sys, os
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+from typing import Optional, Literal, Union
+import litellm, traceback, sys, uuid
+from litellm.caching import DualCache
+from litellm.proxy._types import UserAPIKeyAuth
+from litellm.integrations.custom_logger import CustomLogger
+from fastapi import HTTPException
+from litellm._logging import verbose_proxy_logger
+from litellm.utils import (
+    ModelResponse,
+    EmbeddingResponse,
+    ImageResponse,
+    StreamingChoices,
+)
+from datetime import datetime
+import aiohttp, asyncio
+
+litellm.set_verbose = True
+
+
+class _ENTERPRISE_LlamaGuard(CustomLogger):
+    # Class variables or attributes
+    def __init__(self, model_name: Optional[str] = None):
+        self.model = model_name or litellm.llamaguard_model_name
+
+    def print_verbose(self, print_statement):
+        try:
+            verbose_proxy_logger.debug(print_statement)
+            if litellm.set_verbose:
+                print(print_statement)  # noqa
+        except:
+            pass
+
+    async def async_moderation_hook(
+        self,
+        data: dict,
+    ):
+        """
+        - Calls the Llama Guard Endpoint
+        - Rejects request if it fails safety check
+
+        The llama guard prompt template is applied automatically in factory.py
+        """
+        safety_check_messages = data["messages"][
+            -1
+        ]  # get the last response - llama guard has a 4k token limit
+        response = await litellm.acompletion(
+            model=self.model,
+            messages=[safety_check_messages],
+            hf_model_name="meta-llama/LlamaGuard-7b",
+        )
+
+        if "unsafe" in response.choices[0].message.content:
+            raise HTTPException(
+                status_code=400, detail={"error": "Violated content safety policy"}
+            )
+
+        return data
diff --git a/litellm/proxy/enterprise/utils.py b/enterprise/utils.py
similarity index 100%
rename from litellm/proxy/enterprise/utils.py
rename to enterprise/utils.py
diff --git a/litellm/__init__.py
b/litellm/__init__.py index a7f232f76..c263c8e8e 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -16,23 +16,23 @@ input_callback: List[Union[str, Callable]] = [] success_callback: List[Union[str, Callable]] = [] failure_callback: List[Union[str, Callable]] = [] callbacks: List[Callable] = [] -_async_input_callback: List[ - Callable -] = [] # internal variable - async custom callbacks are routed here. -_async_success_callback: List[ - Union[str, Callable] -] = [] # internal variable - async custom callbacks are routed here. -_async_failure_callback: List[ - Callable -] = [] # internal variable - async custom callbacks are routed here. +_async_input_callback: List[Callable] = ( + [] +) # internal variable - async custom callbacks are routed here. +_async_success_callback: List[Union[str, Callable]] = ( + [] +) # internal variable - async custom callbacks are routed here. +_async_failure_callback: List[Callable] = ( + [] +) # internal variable - async custom callbacks are routed here. pre_call_rules: List[Callable] = [] post_call_rules: List[Callable] = [] -email: Optional[ - str -] = None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648 -token: Optional[ - str -] = None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648 +email: Optional[str] = ( + None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648 +) +token: Optional[str] = ( + None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648 +) telemetry = True max_tokens = 256 # OpenAI Defaults drop_params = False @@ -55,18 +55,23 @@ baseten_key: Optional[str] = None aleph_alpha_key: Optional[str] = None nlp_cloud_key: Optional[str] = None use_client: bool = False +llamaguard_model_name: Optional[str] = None logging: bool = True -caching: bool = False # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648 -caching_with_models: bool = False # # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648 -cache: Optional[ - Cache -] = None # cache object <- use this - https://docs.litellm.ai/docs/caching +caching: bool = ( + False # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648 +) +caching_with_models: bool = ( + False # # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648 +) +cache: Optional[Cache] = ( + None # cache object <- use this - https://docs.litellm.ai/docs/caching +) model_alias_map: Dict[str, str] = {} model_group_alias_map: Dict[str, str] = {} max_budget: float = 0.0 # set the max budget across all providers -budget_duration: Optional[ - str -] = None # proxy only - resets budget after fixed duration. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d"). +budget_duration: Optional[str] = ( + None # proxy only - resets budget after fixed duration. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d"). 
+) _openai_finish_reasons = ["stop", "length", "function_call", "content_filter", "null"] _openai_completion_params = [ "functions", @@ -138,11 +143,15 @@ _litellm_completion_params = [ ] _current_cost = 0 # private variable, used if max budget is set error_logs: Dict = {} -add_function_to_prompt: bool = False # if function calling not supported by api, append function call details to system prompt +add_function_to_prompt: bool = ( + False # if function calling not supported by api, append function call details to system prompt +) client_session: Optional[httpx.Client] = None aclient_session: Optional[httpx.AsyncClient] = None model_fallbacks: Optional[List] = None # Deprecated for 'litellm.fallbacks' -model_cost_map_url: str = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json" +model_cost_map_url: str = ( + "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json" +) suppress_debug_info = False dynamodb_table_name: Optional[str] = None s3_callback_params: Optional[Dict] = None @@ -157,13 +166,13 @@ num_retries: Optional[int] = None # per model endpoint fallbacks: Optional[List] = None context_window_fallbacks: Optional[List] = None allowed_fails: int = 0 -num_retries_per_request: Optional[ - int -] = None # for the request overall (incl. fallbacks + model retries) +num_retries_per_request: Optional[int] = ( + None # for the request overall (incl. fallbacks + model retries) +) ####### SECRET MANAGERS ##################### -secret_manager_client: Optional[ - Any -] = None # list of instantiated key management clients - e.g. azure kv, infisical, etc. +secret_manager_client: Optional[Any] = ( + None # list of instantiated key management clients - e.g. azure kv, infisical, etc. +) _google_kms_resource_name: Optional[str] = None _key_management_system: Optional[KeyManagementSystem] = None #### PII MASKING #### diff --git a/litellm/integrations/custom_logger.py b/litellm/integrations/custom_logger.py index d0cdd7702..c29c964fc 100644 --- a/litellm/integrations/custom_logger.py +++ b/litellm/integrations/custom_logger.py @@ -72,6 +72,9 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac ): pass + async def async_moderation_hook(self, data: dict): + pass + #### SINGLE-USE #### - https://docs.litellm.ai/docs/observability/custom_callback#using-your-custom-callback-function def log_input_event(self, model, messages, kwargs, print_verbose, callback_func): diff --git a/litellm/llms/prompt_templates/factory.py b/litellm/llms/prompt_templates/factory.py index 7896d7c96..4ed4d9295 100644 --- a/litellm/llms/prompt_templates/factory.py +++ b/litellm/llms/prompt_templates/factory.py @@ -706,6 +706,11 @@ def prompt_factory( "gryphe/mythologic-l2-13b", ]: return alpaca_pt(messages=messages) + elif model.strip() == "meta-llama/llamaguard-7b": + chat_template = "{% if messages|length % 2 == 0 %}{% set role = 'Agent' %}{% else %}{% set role = 'User' %}{% endif %}{{ \"[INST] Task: Check if there is unsafe content in '\" + role + \"' messages in conversations according our safety policy with the below categories.\n\n\nO1: Violence and Hate. 
\nShould not\n- Help people plan or engage in violence.\n- Encourage people to engage in violence.\n- Express hateful, derogatory, or demeaning sentiments against people based on sensitive personal characteristics like their race, color, religion, national origin, sexual orientation, gender, gender identity, or disability.\n- Encourage discrimination against people based on those characteristics and should not use slurs or create insults based on those characteristics.\nCan\n- Provide information on violence and discrimination.\n- Discuss the topics of hate and violence, and can discuss historical events involving violence.\nO2: Sexual Content.\nShould not\n- Engage in sexually explicit (i.e., erotic) conversations and should not encourage people to engage in sexual activities.\nCan\n- Discuss topics related to sexuality and sexual education.\n- Provide general information regarding sexual activities.\n- Provide guidance on topics related to sexual health.\nO3: Criminal Planning.\nShould not\n- Help people plan or engage in criminal activities like theft, kidnapping, or financial crimes, especially when the person indicates that they may act on the information provided to them.\n- Encourage people to engage in criminal activities or condone criminal behavior.\nCan\n- Help people understand how crimes are committed, provided that the response is not so detailed that it could be used to provide a roadmap for future crimes.\n- Help people understand how systems work that are used to prevent crimes, provided those descriptions do not provide detailed information on how to defeat those systems.\nO4: Guns and Illegal Weapons.\nShould not\n- Assist in the planning or execution of firearm or illegal weapons crimes, especially when the person indicates that they may act on the information provided to them.\n- Encourage people to engage in firearm or illegal weapons crimes or condone such behavior.\nCan\n- Discuss firearms and the arguments for and against firearm ownership.\nO5: Regulated or Controlled Substances.\nShould not\n- Encourage or assist people in the trafficking, consumption, or use of regulated or controlled substances, including but not limited to illegal drugs, tobacco, alcohol, and cannabis.\n- Assist or encourage people to create such substances when it is illegal to do so.\nCan\n- Discuss the use of regulated or controlled substances, and can provide general information on those substances (ex: what their chemical composition is and what their effects are).\n- Provide information on how those substances are manufactured, provided that the information is not so detailed as to provide step-by-step instructions for their creation.\nO6: Self-Harm.\nShould not\n- Encourage people to harm themselves.\n- Romanticize or condone self-harm or suicide.\n- Provide information regarding the methods of suicide or self-harm.\n- Provide any information or advice, other than information regarding appropriate health resources, to people who express or imply that they wish to harm themselves.\nCan\n- Provide general information regarding self-harm and suicide (ex: statistics, the mental health challenges faced by those who take these actions).\nShould\n- Provide information regarding appropriate health resources to people who imply or express the desire to harm themselves.\n\n\n\n\n\" }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = 
message['content'] %}{% if message['role'] == 'user' %}{% set role = 'User' %}{% elif message['role'] == 'assistant' %}{% set role = 'Agent' %}{% endif %}{{ role + ': ' + content.strip() + '\n\n' }}{% endfor %}{{ \"\n\nProvide your safety assessment for \" + role + \" in the above conversation:\n- First line must read 'safe' or 'unsafe'.\n- If unsafe, a second line must include a comma-separated list of violated categories. [/INST]\" }}" + return hf_chat_template( + model=model, messages=messages, chat_template=chat_template + ) else: return hf_chat_template(original_model_name, messages) except Exception as e: diff --git a/litellm/proxy/enterprise b/litellm/proxy/enterprise new file mode 120000 index 000000000..6ee73080d --- /dev/null +++ b/litellm/proxy/enterprise @@ -0,0 +1 @@ +../../enterprise \ No newline at end of file diff --git a/litellm/proxy/enterprise/LICENSE.md b/litellm/proxy/enterprise/LICENSE.md deleted file mode 100644 index 5cd298ce6..000000000 --- a/litellm/proxy/enterprise/LICENSE.md +++ /dev/null @@ -1,37 +0,0 @@ - -The BerriAI Enterprise license (the "Enterprise License") -Copyright (c) 2024 - present Berrie AI Inc. - -With regard to the BerriAI Software: - -This software and associated documentation files (the "Software") may only be -used in production, if you (and any entity that you represent) have agreed to, -and are in compliance with, the BerriAI Subscription Terms of Service, available -via [call](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat) or email (info@berri.ai) (the "Enterprise Terms"), or other -agreement governing the use of the Software, as agreed by you and BerriAI, -and otherwise have a valid BerriAI Enterprise license for the -correct number of user seats. Subject to the foregoing sentence, you are free to -modify this Software and publish patches to the Software. You agree that BerriAI -and/or its licensors (as applicable) retain all right, title and interest in and -to all such modifications and/or patches, and all such modifications and/or -patches may only be used, copied, modified, displayed, distributed, or otherwise -exploited with a valid BerriAI Enterprise license for the correct -number of user seats. Notwithstanding the foregoing, you may copy and modify -the Software for development and testing purposes, without requiring a -subscription. You agree that BerriAI and/or its licensors (as applicable) retain -all right, title and interest in and to all such modifications. You are not -granted any other rights beyond what is expressly stated herein. Subject to the -foregoing, it is forbidden to copy, merge, publish, distribute, sublicense, -and/or sell the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. - -For all third party components incorporated into the BerriAI Software, those -components are licensed under the original license provided by the owner of the -applicable component. 
\ No newline at end of file diff --git a/litellm/proxy/enterprise/README.md b/litellm/proxy/enterprise/README.md deleted file mode 100644 index fd7e68fd2..000000000 --- a/litellm/proxy/enterprise/README.md +++ /dev/null @@ -1,12 +0,0 @@ -## LiteLLM Enterprise - -Code in this folder is licensed under a commercial license. Please review the [LICENSE](./LICENSE.md) file within the /enterprise folder - -**These features are covered under the LiteLLM Enterprise contract** - -👉 **Using in an Enterprise / Need specific features ?** Meet with us [here](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat?month=2024-02) - -## Features: -- Custom API / microservice callbacks -- Google Text Moderation API - diff --git a/litellm/proxy/enterprise/callbacks/example_logging_api.py b/litellm/proxy/enterprise/callbacks/example_logging_api.py deleted file mode 100644 index a8c5b5429..000000000 --- a/litellm/proxy/enterprise/callbacks/example_logging_api.py +++ /dev/null @@ -1,31 +0,0 @@ -# this is an example endpoint to receive data from litellm -from fastapi import FastAPI, HTTPException, Request - -app = FastAPI() - - -@app.post("/log-event") -async def log_event(request: Request): - try: - print("Received /log-event request") # noqa - # Assuming the incoming request has JSON data - data = await request.json() - print("Received request data:") # noqa - print(data) # noqa - - # Your additional logic can go here - # For now, just printing the received data - - return {"message": "Request received successfully"} - except Exception as e: - print(f"Error processing request: {str(e)}") # noqa - import traceback - - traceback.print_exc() - raise HTTPException(status_code=500, detail="Internal Server Error") - - -if __name__ == "__main__": - import uvicorn - - uvicorn.run(app, host="127.0.0.1", port=8000) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 37f55072e..b5349fffb 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -1368,7 +1368,7 @@ class ProxyConfig: ) elif key == "callbacks": if isinstance(value, list): - imported_list = [] + imported_list: List[Any] = [] for callback in value: # ["presidio", ] if isinstance(callback, str) and callback == "presidio": from litellm.proxy.hooks.presidio_pii_masking import ( @@ -1377,6 +1377,16 @@ class ProxyConfig: pii_masking_object = _OPTIONAL_PresidioPIIMasking() imported_list.append(pii_masking_object) + elif ( + isinstance(callback, str) + and callback == "llamaguard_moderations" + ): + from litellm.proxy.enterprise.hooks.llama_guard import ( + _ENTERPRISE_LlamaGuard, + ) + + llama_guard_object = _ENTERPRISE_LlamaGuard() + imported_list.append(llama_guard_object) else: imported_list.append( get_instance_fn( @@ -2423,6 +2433,9 @@ async def chat_completion( user_api_key_dict=user_api_key_dict, data=data, call_type="completion" ) + tasks = [] + tasks.append(proxy_logging_obj.during_call_hook(data=data)) + start_time = time.time() ### ROUTE THE REQUEST ### @@ -2433,34 +2446,40 @@ async def chat_completion( ) # skip router if user passed their key if "api_key" in data: - response = await litellm.acompletion(**data) + tasks.append(litellm.acompletion(**data)) elif "user_config" in data: # initialize a new router instance. 
make request using this Router router_config = data.pop("user_config") user_router = litellm.Router(**router_config) - response = await user_router.acompletion(**data) + tasks.append(user_router.acompletion(**data)) elif ( llm_router is not None and data["model"] in router_model_names ): # model in router model list - response = await llm_router.acompletion(**data) + tasks.append(llm_router.acompletion(**data)) elif ( llm_router is not None and llm_router.model_group_alias is not None and data["model"] in llm_router.model_group_alias ): # model set in model_group_alias - response = await llm_router.acompletion(**data) + tasks.append(llm_router.acompletion(**data)) elif ( llm_router is not None and data["model"] in llm_router.deployment_names ): # model in router deployments, calling a specific deployment on the router - response = await llm_router.acompletion(**data, specific_deployment=True) + tasks.append(llm_router.acompletion(**data, specific_deployment=True)) elif user_model is not None: # `litellm --model ` - response = await litellm.acompletion(**data) + tasks.append(litellm.acompletion(**data)) else: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail={"error": "Invalid model name passed in"}, ) + # wait for call to end + responses = await asyncio.gather( + *tasks + ) # run the moderation check in parallel to the actual llm api call + response = responses[1] + # Post Call Processing data["litellm_status"] = "success" # used for alerting if hasattr(response, "_hidden_params"): diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 616f99f40..7cc0f59f1 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -128,19 +128,18 @@ class ProxyLogging: except Exception as e: raise e - async def success_handler( - self, - user_api_key_dict: UserAPIKeyAuth, - response: Any, - call_type: Literal["completion", "embeddings"], - start_time, - end_time, - ): + async def during_call_hook(self, data: dict): """ - Log successful API calls / db read/writes + Runs the CustomLogger's async_moderation_hook() """ - - pass + for callback in litellm.callbacks: + new_data = copy.deepcopy(data) + try: + if isinstance(callback, CustomLogger): + await callback.async_moderation_hook(data=new_data) + except Exception as e: + raise e + return data async def response_taking_too_long( self, diff --git a/litellm/utils.py b/litellm/utils.py index 4dece6ab8..d4406e398 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -1141,7 +1141,7 @@ class Logging: if ( litellm.max_budget - and self.stream + and self.stream == False and result is not None and "content" in result ): @@ -1668,7 +1668,9 @@ class Logging: end_time=end_time, ) if callable(callback): # custom logger functions - print_verbose(f"Making async function logging call") + print_verbose( + f"Making async function logging call - {self.model_call_details}" + ) if self.stream: if "complete_streaming_response" in self.model_call_details: await customLogger.async_log_event( @@ -3451,7 +3453,7 @@ def cost_per_token( return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar else: # if model is not in model_prices_and_context_window.json. Raise an exception-let users know - error_str = f"Model not in model_prices_and_context_window.json. You passed model={model}\n" + error_str = f"Model not in model_prices_and_context_window.json. You passed model={model}. 
Register pricing for model - https://docs.litellm.ai/docs/proxy/custom_pricing\n"
         raise litellm.exceptions.NotFoundError(  # type: ignore
             message=error_str,
             model=model,
@@ -3913,6 +3915,8 @@ def get_optional_params(
             custom_llm_provider != "bedrock" and custom_llm_provider != "sagemaker"
         ):  # allow dynamically setting boto3 init logic
             continue
+        elif k == "hf_model_name" and custom_llm_provider != "sagemaker":
+            continue
         elif (
             k.startswith("vertex_") and custom_llm_provider != "vertex_ai"
         ):  # allow dynamically setting vertex ai init logic
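
Usage note (illustrative, not part of the diff): a minimal sketch of how the new `async_moderation_hook` added above can be exercised directly. It assumes the `enterprise/` package introduced in this patch is importable and that `litellm.llamaguard_model_name` points at a real LlamaGuard deployment; the endpoint name below is only a placeholder.

# Minimal sketch of the new moderation hook (assumptions noted in comments).
import asyncio

import litellm
from fastapi import HTTPException

from enterprise.hooks.llama_guard import _ENTERPRISE_LlamaGuard

# Placeholder: point this at whatever LlamaGuard deployment you actually have.
litellm.llamaguard_model_name = "sagemaker/your-llamaguard-endpoint"


async def main():
    guard = _ENTERPRISE_LlamaGuard()  # falls back to litellm.llamaguard_model_name
    data = {"messages": [{"role": "user", "content": "hello, how's it going?"}]}
    try:
        # Runs the safety check on the last message; per the hook above, the
        # LlamaGuard prompt template is applied in factory.py and an
        # HTTPException(400) is raised if the verdict contains "unsafe".
        await guard.async_moderation_hook(data=data)
        print("request passed the content safety check")
    except HTTPException as e:
        print(f"request blocked: {e.detail}")


asyncio.run(main())

On the proxy, the same hook is wired in by listing `llamaguard_moderations` under `callbacks` in the config (see the `ProxyConfig` change above), with `llamaguard_model_name` set alongside it so the hook knows which deployment to call.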