From 5408a6de3d2d95e77f7696de016e00f6362a1abc Mon Sep 17 00:00:00 2001
From: ishaan-jaff
Date: Wed, 14 Feb 2024 11:03:23 -0800
Subject: [PATCH] (feat) proxy add amoderation endpoint

---
 litellm/proxy/proxy_server.py | 142 ++++++++++++++++++++++++++++++++++
 litellm/proxy/utils.py        |   4 +-
 2 files changed, 145 insertions(+), 1 deletion(-)

diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 3dedc3a71..9b99d08c7 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -2798,6 +2798,148 @@ async def image_generation(
     )
 
 
+@router.post(
+    "/v1/moderations",
+    dependencies=[Depends(user_api_key_auth)],
+    response_class=ORJSONResponse,
+    tags=["moderations"],
+)
+@router.post(
+    "/moderations",
+    dependencies=[Depends(user_api_key_auth)],
+    response_class=ORJSONResponse,
+    tags=["moderations"],
+)
+async def moderations(
+    request: Request,
+    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+):
+    global proxy_logging_obj
+    try:
+        # Use orjson to parse JSON data, orjson speeds up requests significantly
+        body = await request.body()
+        data = orjson.loads(body)
+
+        # Include original request and headers in the data
+        data["proxy_server_request"] = {
+            "url": str(request.url),
+            "method": request.method,
+            "headers": dict(request.headers),
+            "body": copy.copy(data),  # use copy instead of deepcopy
+        }
+
+        if data.get("user", None) is None and user_api_key_dict.user_id is not None:
+            data["user"] = user_api_key_dict.user_id
+
+        data["model"] = (
+            general_settings.get("moderation_model", None)  # server default
+            or user_model  # model name passed via cli args
+            or data["model"]  # default passed in http request
+        )
+        if user_model:
+            data["model"] = user_model
+
+        if "metadata" not in data:
+            data["metadata"] = {}
+        data["metadata"]["user_api_key"] = user_api_key_dict.api_key
+        data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata
+        _headers = dict(request.headers)
+        _headers.pop(
+            "authorization", None
+        )  # do not store the original `sk-..` api key in the db
+        data["metadata"]["headers"] = _headers
+        data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
+        data["metadata"]["endpoint"] = str(request.url)
+
+        ### TEAM-SPECIFIC PARAMS ###
+        if user_api_key_dict.team_id is not None:
+            team_config = await proxy_config.load_team_config(
+                team_id=user_api_key_dict.team_id
+            )
+            if len(team_config) == 0:
+                pass
+            else:
+                team_id = team_config.pop("team_id", None)
+                data["metadata"]["team_id"] = team_id
+                data = {
+                    **team_config,
+                    **data,
+                }  # add the team-specific configs to the completion call
+
+        router_model_names = (
+            [m["model_name"] for m in llm_model_list]
+            if llm_model_list is not None
+            else []
+        )
+
+        ### CALL HOOKS ### - modify incoming data / reject request before calling the model
+        data = await proxy_logging_obj.pre_call_hook(
+            user_api_key_dict=user_api_key_dict, data=data, call_type="moderation"
+        )
+
+        start_time = time.time()
+
+        ## ROUTE TO CORRECT ENDPOINT ##
+        # skip router if user passed their key
+        if "api_key" in data:
+            response = await litellm.amoderation(**data)
+        elif (
+            llm_router is not None and data["model"] in router_model_names
+        ):  # model in router model list
+            response = await llm_router.amoderation(**data)
+        elif (
+            llm_router is not None and data["model"] in llm_router.deployment_names
+        ):  # model in router deployments, calling a specific deployment on the router
+            response = await llm_router.amoderation(**data, specific_deployment=True)
+        elif (
+            llm_router is not None
+            and llm_router.model_group_alias is not None
+            and data["model"] in llm_router.model_group_alias
+        ):  # model set in model_group_alias
+            response = await llm_router.amoderation(
+                **data
+            )  # ensure this goes through the llm_router, router will do the correct alias mapping
+        elif user_model is not None:  # `litellm --model`
+            response = await litellm.amoderation(**data)
+        else:
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail={"error": "Invalid model name passed in"},
+            )
+
+        ### ALERTING ###
+        data["litellm_status"] = "success"  # used for alerting
+        end_time = time.time()
+        asyncio.create_task(
+            proxy_logging_obj.response_taking_too_long(
+                start_time=start_time, end_time=end_time, type="slow_response"
+            )
+        )
+
+        return response
+    except Exception as e:
+        await proxy_logging_obj.post_call_failure_hook(
+            user_api_key_dict=user_api_key_dict, original_exception=e
+        )
+        traceback.print_exc()
+        if isinstance(e, HTTPException):
+            raise ProxyException(
+                message=getattr(e, "message", str(e)),
+                type=getattr(e, "type", "None"),
+                param=getattr(e, "param", "None"),
+                code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
+            )
+        else:
+            error_traceback = traceback.format_exc()
+            error_msg = f"{str(e)}\n\n{error_traceback}"
+            raise ProxyException(
+                message=getattr(e, "message", error_msg),
+                type=getattr(e, "type", "None"),
+                param=getattr(e, "param", "None"),
+                code=getattr(e, "status_code", 500),
+            )
+
+
 #### KEY MANAGEMENT ####
diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py
index 0350d54bd..4f28e2adb 100644
--- a/litellm/proxy/utils.py
+++ b/litellm/proxy/utils.py
@@ -93,7 +93,9 @@ class ProxyLogging:
         self,
         user_api_key_dict: UserAPIKeyAuth,
         data: dict,
-        call_type: Literal["completion", "embeddings", "image_generation"],
+        call_type: Literal[
+            "completion", "embeddings", "image_generation", "moderation"
+        ],
     ):
         """
         Allows users to modify/reject the incoming request to the proxy, without having to deal with parsing Request body.
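
A minimal sketch of exercising the new endpoint from a client once the proxy
is running with this patch. The base URL, port, and `sk-1234` key below are
placeholder assumptions; adjust them to your deployment. The request body
mirrors OpenAI's moderation API, and `model` can be omitted only when the
server sets `moderation_model` in `general_settings` or was started with
`litellm --model`:

    import requests

    # Placeholder proxy address and virtual key; substitute your own values.
    PROXY_BASE_URL = "http://localhost:8000"
    API_KEY = "sk-1234"

    response = requests.post(
        f"{PROXY_BASE_URL}/moderations",  # /v1/moderations is an equivalent route
        headers={"Authorization": f"Bearer {API_KEY}"},
        json={
            "input": "text to classify",
            "model": "text-moderation-stable",  # must resolve via the proxy's model list
        },
    )
    print(response.json())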