From 693efc8e849b2cd44b4acffe6a6e0f4affe3d6e4 Mon Sep 17 00:00:00 2001
From: ishaan-jaff
Date: Wed, 14 Feb 2024 11:00:09 -0800
Subject: [PATCH 1/5] (feat) add moderation on router

---
 litellm/main.py              | 32 ++++++++++---
 litellm/router.py            | 92 ++++++++++++++++++++++++++++++++++++
 litellm/tests/test_router.py | 20 ++++++++
 litellm/utils.py             |  7 +++
 4 files changed, 145 insertions(+), 6 deletions(-)

diff --git a/litellm/main.py b/litellm/main.py
index a7990ecfb..585de7fff 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -2962,16 +2962,36 @@ def text_completion(
 
 
 ##### Moderation #######################
-def moderation(input: str, api_key: Optional[str] = None):
+@client
+def moderation(input: str, model: str, api_key: Optional[str] = None, **kwargs):
     # only supports open ai for now
     api_key = (
         api_key or litellm.api_key or litellm.openai_key or get_secret("OPENAI_API_KEY")
     )
-    openai.api_key = api_key
-    openai.api_type = "open_ai"  # type: ignore
-    openai.api_version = None
-    openai.base_url = "https://api.openai.com/v1/"
-    response = openai.moderations.create(input=input)
+
+    openai_client = kwargs.get("client", None)
+    if openai_client is None:
+        openai_client = openai.OpenAI(
+            api_key=api_key,
+        )
+
+    response = openai_client.moderations.create(input=input, model=model)
+    return response
+
+
+##### Moderation #######################
+@client
+async def amoderation(input: str, model: str, api_key: Optional[str] = None, **kwargs):
+    # only supports open ai for now
+    api_key = (
+        api_key or litellm.api_key or litellm.openai_key or get_secret("OPENAI_API_KEY")
+    )
+    openai_client = kwargs.get("client", None)
+    if openai_client is None:
+        openai_client = openai.AsyncOpenAI(
+            api_key=api_key,
+        )
+    response = await openai_client.moderations.create(input=input, model=model)
     return response
 
 
diff --git a/litellm/router.py b/litellm/router.py
index 21e967576..b64b111a1 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -599,6 +599,98 @@ class Router:
                 self.fail_calls[model_name] += 1
             raise e
 
+    async def amoderation(self, model: str, input: str, **kwargs):
+        try:
+            kwargs["model"] = model
+            kwargs["input"] = input
+            kwargs["original_function"] = self._amoderation
+            kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
+            timeout = kwargs.get("request_timeout", self.timeout)
+            kwargs.setdefault("metadata", {}).update({"model_group": model})
+
+            response = await self.async_function_with_fallbacks(**kwargs)
+
+            return response
+        except Exception as e:
+            raise e
+
+    async def _amoderation(self, model: str, input: str, **kwargs):
+        model_name = None
+        try:
+            verbose_router_logger.debug(
+                f"Inside _moderation()- model: {model}; kwargs: {kwargs}"
+            )
+            deployment = self.get_available_deployment(
+                model=model,
+                input=input,
+                specific_deployment=kwargs.pop("specific_deployment", None),
+            )
+            kwargs.setdefault("metadata", {}).update(
+                {
+                    "deployment": deployment["litellm_params"]["model"],
+                    "model_info": deployment.get("model_info", {}),
+                }
+            )
+            kwargs["model_info"] = deployment.get("model_info", {})
+            data = deployment["litellm_params"].copy()
+            model_name = data["model"]
+            for k, v in self.default_litellm_params.items():
+                if (
+                    k not in kwargs and v is not None
+                ):  # prioritize model-specific params > default router params
+                    kwargs[k] = v
+                elif k == "metadata":
+                    kwargs[k].update(v)
+
+            potential_model_client = self._get_client(
+                deployment=deployment, kwargs=kwargs, client_type="async"
+            )
+            # check if provided keys == client keys #
+            dynamic_api_key = kwargs.get("api_key", None)
+            if (
+                dynamic_api_key is not None
+                and potential_model_client is not None
+                and dynamic_api_key != potential_model_client.api_key
+            ):
+                model_client = None
+            else:
+                model_client = potential_model_client
+            self.total_calls[model_name] += 1
+
+            timeout = (
+                data.get(
+                    "timeout", None
+                )  # timeout set on litellm_params for this deployment
+                or self.timeout  # timeout set on router
+                or kwargs.get(
+                    "timeout", None
+                )  # this uses default_litellm_params when nothing is set
+            )
+
+            response = await litellm.amoderation(
+                **{
+                    **data,
+                    "input": input,
+                    "caching": self.cache_responses,
+                    "client": model_client,
+                    "timeout": timeout,
+                    **kwargs,
+                }
+            )
+
+            self.success_calls[model_name] += 1
+            verbose_router_logger.info(
+                f"litellm.amoderation(model={model_name})\033[32m 200 OK\033[0m"
+            )
+            return response
+        except Exception as e:
+            verbose_router_logger.info(
+                f"litellm.amoderation(model={model_name})\033[31m Exception {str(e)}\033[0m"
+            )
+            if model_name is not None:
+                self.fail_calls[model_name] += 1
+            raise e
+
     def text_completion(
         self,
         model: str,
diff --git a/litellm/tests/test_router.py b/litellm/tests/test_router.py
index b9ca29cee..ab329e14a 100644
--- a/litellm/tests/test_router.py
+++ b/litellm/tests/test_router.py
@@ -991,3 +991,23 @@ def test_router_timeout():
         print(e)
         print(vars(e))
         pass
+
+
+@pytest.mark.asyncio
+async def test_router_amoderation():
+    model_list = [
+        {
+            "model_name": "openai-moderations",
+            "litellm_params": {
+                "model": "text-moderation-stable",
+                "api_key": os.getenv("OPENAI_API_KEY", None),
+            },
+        }
+    ]
+
+    router = Router(model_list=model_list)
+    result = await router.amoderation(
+        model="openai-moderations", input="this is valid good text"
+    )
+
+    print("moderation result", result)
diff --git a/litellm/utils.py b/litellm/utils.py
index e238b84d7..b15be366d 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -738,6 +738,8 @@ class CallTypes(Enum):
     text_completion = "text_completion"
     image_generation = "image_generation"
     aimage_generation = "aimage_generation"
+    moderation = "moderation"
+    amoderation = "amoderation"
 
 
 # Logging function -> log the exact model details + what's being sent | Non-BlockingP
@@ -2100,6 +2102,11 @@ def client(original_function):
                 or call_type == CallTypes.aimage_generation.value
             ):
                 messages = args[0] if len(args) > 0 else kwargs["prompt"]
+            elif (
+                call_type == CallTypes.moderation.value
+                or call_type == CallTypes.amoderation.value
+            ):
+                messages = args[1] if len(args) > 1 else kwargs["input"]
             elif (
                 call_type == CallTypes.atext_completion.value
                 or call_type == CallTypes.text_completion.value

From 5408a6de3d2d95e77f7696de016e00f6362a1abc Mon Sep 17 00:00:00 2001
From: ishaan-jaff
Date: Wed, 14 Feb 2024 11:03:23 -0800
Subject: [PATCH 2/5] (feat) proxy add amoderation endpoint

---
 litellm/proxy/proxy_server.py | 142 ++++++++++++++++++++++++++++++++++
 litellm/proxy/utils.py        |   4 +-
 2 files changed, 145 insertions(+), 1 deletion(-)

diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 3dedc3a71..9b99d08c7 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -2798,6 +2798,148 @@ async def image_generation(
     )
 
 
+@router.post(
+    "/v1/moderations",
+    dependencies=[Depends(user_api_key_auth)],
+    response_class=ORJSONResponse,
+    tags=["moderations"],
+)
+@router.post(
+    "/moderations",
+    dependencies=[Depends(user_api_key_auth)],
+    response_class=ORJSONResponse,
+    tags=["moderations"],
+)
+async def moderations(
+    request: Request,
+    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+):
+    global proxy_logging_obj
+    try:
+        # Use orjson to parse JSON data, orjson speeds up requests significantly
+        body = await request.body()
+        data = orjson.loads(body)
+
+        # Include original request and headers in the data
+        data["proxy_server_request"] = {
+            "url": str(request.url),
+            "method": request.method,
+            "headers": dict(request.headers),
+            "body": copy.copy(data),  # use copy instead of deepcopy
+        }
+
+        if data.get("user", None) is None and user_api_key_dict.user_id is not None:
+            data["user"] = user_api_key_dict.user_id
+
+        data["model"] = (
+            general_settings.get("moderation_model", None)  # server default
+            or user_model  # model name passed via cli args
+            or data["model"]  # default passed in http request
+        )
+        if user_model:
+            data["model"] = user_model
+
+        if "metadata" not in data:
+            data["metadata"] = {}
+        data["metadata"]["user_api_key"] = user_api_key_dict.api_key
+        data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata
+        _headers = dict(request.headers)
+        _headers.pop(
+            "authorization", None
+        )  # do not store the original `sk-..` api key in the db
+        data["metadata"]["headers"] = _headers
+        data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
+        data["metadata"]["endpoint"] = str(request.url)
+
+        ### TEAM-SPECIFIC PARAMS ###
+        if user_api_key_dict.team_id is not None:
+            team_config = await proxy_config.load_team_config(
+                team_id=user_api_key_dict.team_id
+            )
+            if len(team_config) == 0:
+                pass
+            else:
+                team_id = team_config.pop("team_id", None)
+                data["metadata"]["team_id"] = team_id
+                data = {
+                    **team_config,
+                    **data,
+                }  # add the team-specific configs to the completion call
+
+        router_model_names = (
+            [m["model_name"] for m in llm_model_list]
+            if llm_model_list is not None
+            else []
+        )
+
+        ### CALL HOOKS ### - modify incoming data / reject request before calling the model
+        data = await proxy_logging_obj.pre_call_hook(
+            user_api_key_dict=user_api_key_dict, data=data, call_type="moderation"
+        )
+
+        start_time = time.time()
+
+        ## ROUTE TO CORRECT ENDPOINT ##
+        # skip router if user passed their key
+        if "api_key" in data:
+            response = await litellm.amoderation(**data)
+        elif (
+            llm_router is not None and data["model"] in router_model_names
+        ):  # model in router model list
+            response = await llm_router.amoderation(**data)
+        elif (
+            llm_router is not None and data["model"] in llm_router.deployment_names
+        ):  # model in router deployments, calling a specific deployment on the router
+            response = await llm_router.amoderation(**data, specific_deployment=True)
+        elif (
+            llm_router is not None
+            and llm_router.model_group_alias is not None
+            and data["model"] in llm_router.model_group_alias
+        ):  # model set in model_group_alias
+            response = await llm_router.amoderation(
+                **data
+            )  # ensure this goes the llm_router, router will do the correct alias mapping
+        elif user_model is not None:  # `litellm --model `
+            response = await litellm.amoderation(**data)
+        else:
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail={"error": "Invalid model name passed in"},
+            )
+
+        ### ALERTING ###
+        data["litellm_status"] = "success"  # used for alerting
+        end_time = time.time()
+        asyncio.create_task(
+            proxy_logging_obj.response_taking_too_long(
+                start_time=start_time, end_time=end_time, type="slow_response"
+            )
+        )
+
+        return response
+    except Exception as e:
+        await proxy_logging_obj.post_call_failure_hook(
+            user_api_key_dict=user_api_key_dict, original_exception=e
+        )
+        traceback.print_exc()
+        if isinstance(e, HTTPException):
+            raise ProxyException(
+                message=getattr(e, "message", str(e)),
+                type=getattr(e, "type", "None"),
+                param=getattr(e, "param", "None"),
+                code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
+            )
+        else:
+            error_traceback = traceback.format_exc()
+            error_msg = f"{str(e)}\n\n{error_traceback}"
+            raise ProxyException(
+                message=getattr(e, "message", error_msg),
+                type=getattr(e, "type", "None"),
+                param=getattr(e, "param", "None"),
+                code=getattr(e, "status_code", 500),
+            )
 
 
 #### KEY MANAGEMENT ####
diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py
index 0350d54bd..4f28e2adb 100644
--- a/litellm/proxy/utils.py
+++ b/litellm/proxy/utils.py
@@ -93,7 +93,9 @@ class ProxyLogging:
         self,
         user_api_key_dict: UserAPIKeyAuth,
         data: dict,
-        call_type: Literal["completion", "embeddings", "image_generation"],
+        call_type: Literal[
+            "completion", "embeddings", "image_generation", "moderation"
+        ],
     ):
         """
         Allows users to modify/reject the incoming request to the proxy, without having to deal with parsing Request body.

From f1dc656491c68aa4d15bda367ad4f27c5cbc7140 Mon Sep 17 00:00:00 2001
From: ishaan-jaff
Date: Wed, 14 Feb 2024 11:34:06 -0800
Subject: [PATCH 3/5] (docs) add moderations to docs

---
 docs/my-website/docs/proxy/user_keys.md | 80 ++++++++++++++++++++++++-
 1 file changed, 79 insertions(+), 1 deletion(-)

diff --git a/docs/my-website/docs/proxy/user_keys.md b/docs/my-website/docs/proxy/user_keys.md
index 47cfef9c3..fcccffaa0 100644
--- a/docs/my-website/docs/proxy/user_keys.md
+++ b/docs/my-website/docs/proxy/user_keys.md
@@ -197,7 +197,7 @@ from openai import OpenAI
 # set api_key to send to proxy server
 client = OpenAI(api_key="", base_url="http://0.0.0.0:8000")
 
-response = openai.embeddings.create(
+response = client.embeddings.create(
     input=["hello from litellm"],
     model="text-embedding-ada-002"
 )
@@ -281,6 +281,84 @@
 print(query_result[:5])
 ```
 
+## `/moderations`
+
+
+### Request Format
+Input, Output and Exceptions are mapped to the OpenAI format for all supported models
+
+
+
+
+```python
+import openai
+from openai import OpenAI
+
+# set base_url to your proxy server
+# set api_key to send to proxy server
+client = OpenAI(api_key="", base_url="http://0.0.0.0:8000")
+
+response = client.moderations.create(
+    input="hello from litellm",
+    model="text-moderation-stable"
+)
+
+print(response)
+
+```
+
+
+
+```shell
+curl --location 'http://0.0.0.0:8000/moderations' \
+    --header 'Content-Type: application/json' \
+    --header 'Authorization: Bearer sk-1234' \
+    --data '{"input": "Sample text goes here", "model": "text-moderation-stable"}'
+```
+
+
+
+
+### Response Format
+
+```json
+{
+  "id": "modr-8sFEN22QCziALOfWTa77TodNLgHwA",
+  "model": "text-moderation-007",
+  "results": [
+    {
+      "categories": {
+        "harassment": false,
+        "harassment/threatening": false,
+        "hate": false,
+        "hate/threatening": false,
+        "self-harm": false,
+        "self-harm/instructions": false,
+        "self-harm/intent": false,
+        "sexual": false,
+        "sexual/minors": false,
+        "violence": false,
+        "violence/graphic": false
+      },
+      "category_scores": {
+        "harassment": 0.000019947197870351374,
+        "harassment/threatening": 5.5971017900446896e-6,
+        "hate": 0.000028560316422954202,
+        "hate/threatening": 2.2631787999216613e-8,
+        "self-harm": 2.9121162015144364e-7,
+        "self-harm/instructions": 9.314219084899378e-8,
+        "self-harm/intent": 8.093739012338119e-8,
+        "sexual": 0.00004414955765241757,
+        "sexual/minors": 0.0000156943697220413,
+        "violence": 0.00022354527027346194,
+        "violence/graphic": 8.804164281173144e-6
+      },
+      "flagged": false
+    }
+  ]
+}
+```
+
 
 ## Advanced

From 573515ec3d9aded7ee76b59cad793bd5057effe8 Mon Sep 17 00:00:00 2001
From: ishaan-jaff
Date: Wed, 14 Feb 2024 11:34:18 -0800
Subject: [PATCH 4/5] (docs) add moderations endpoint to docs

---
 litellm/proxy/proxy_server.py | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 9b99d08c7..23c25b081 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -2814,6 +2814,17 @@ async def moderations(
     request: Request,
     user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
 ):
+    """
+    The moderations endpoint is a tool you can use to check whether content complies with an LLM provider's policies.
+
+    Quick Start
+    ```
+    curl --location 'http://0.0.0.0:4000/moderations' \
+    --header 'Content-Type: application/json' \
+    --header 'Authorization: Bearer sk-1234' \
+    --data '{"input": "Sample text goes here", "model": "text-moderation-stable"}'
+    ```
+    """
     global proxy_logging_obj
     try:
         # Use orjson to parse JSON data, orjson speeds up requests significantly

From a575efb4eeb95d2686e60cfffa52506a108698c9 Mon Sep 17 00:00:00 2001
From: ishaan-jaff
Date: Wed, 14 Feb 2024 11:40:58 -0800
Subject: [PATCH 5/5] (fix) fix moderation test

---
 litellm/main.py                  | 7 +++++--
 litellm/proxy/proxy_config.yaml  | 4 ++++
 litellm/tests/test_completion.py | 4 ----
 3 files changed, 9 insertions(+), 6 deletions(-)

diff --git a/litellm/main.py b/litellm/main.py
index 585de7fff..352ce1882 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -2962,8 +2962,11 @@ def text_completion(
 
 ##### Moderation #######################
-@client
-def moderation(input: str, model: str, api_key: Optional[str] = None, **kwargs):
+
+
+def moderation(
+    input: str, model: Optional[str] = None, api_key: Optional[str] = None, **kwargs
+):
     # only supports open ai for now
     api_key = (
         api_key or litellm.api_key or litellm.openai_key or get_secret("OPENAI_API_KEY")
     )
diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml
index 8d35bcae8..d94f987db 100644
--- a/litellm/proxy/proxy_config.yaml
+++ b/litellm/proxy/proxy_config.yaml
@@ -32,6 +32,10 @@ model_list:
       api_key: os.environ/AZURE_API_KEY # The `os.environ/` prefix tells litellm to read this from the env. See https://docs.litellm.ai/docs/simple_proxy#load-api-keys-from-vault
     model_info:
       base_model: azure/gpt-4
+  - model_name: text-moderation-stable
+    litellm_params:
+      model: text-moderation-stable
+      api_key: os.environ/OPENAI_API_KEY
 litellm_settings:
   fallbacks: [{"openai-gpt-3.5": ["azure-gpt-3.5"]}]
   success_callback: ['langfuse']
diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index e93b00ef6..17ced7382 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -2093,10 +2093,6 @@ def test_completion_cloudflare():
 
 
 def test_moderation():
-    import openai
-
-    openai.api_type = "azure"
-    openai.api_version = "GM"
     response = litellm.moderation(input="i'm ishaan cto of litellm")
     print(response)
     output = response.results[0]