From 693efc8e849b2cd44b4acffe6a6e0f4affe3d6e4 Mon Sep 17 00:00:00 2001
From: ishaan-jaff
Date: Wed, 14 Feb 2024 11:00:09 -0800
Subject: [PATCH] (feat) add moderation on router
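
Moderation calls can now go through the Router, picking up the same
retry/fallback handling and deployment selection as chat completions.
A minimal usage sketch, mirroring the new test below (the
"openai-moderations" alias and the env var are illustrative, not
required names):

    import asyncio
    import os

    from litellm import Router

    model_list = [
        {
            "model_name": "openai-moderations",
            "litellm_params": {
                "model": "text-moderation-stable",
                "api_key": os.getenv("OPENAI_API_KEY"),
            },
        }
    ]

    async def main():
        router = Router(model_list=model_list)
        # Router.amoderation applies retries/fallbacks, then dispatches
        # to litellm.amoderation on the selected deployment.
        response = await router.amoderation(
            model="openai-moderations", input="this is valid good text"
        )
        print(response)

    asyncio.run(main())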
---
 litellm/main.py              | 32 ++++++++++---
 litellm/router.py            | 92 ++++++++++++++++++++++++++++++++++++
 litellm/tests/test_router.py | 20 ++++++++
 litellm/utils.py             |  7 +++
 4 files changed, 145 insertions(+), 6 deletions(-)

diff --git a/litellm/main.py b/litellm/main.py
index a7990ecfb..585de7fff 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -2962,16 +2962,36 @@ def text_completion(
 
 
 ##### Moderation #######################
-def moderation(input: str, api_key: Optional[str] = None):
+@client
+def moderation(input: str, model: str, api_key: Optional[str] = None, **kwargs):
     # only supports open ai for now
     api_key = (
         api_key or litellm.api_key or litellm.openai_key or get_secret("OPENAI_API_KEY")
     )
-    openai.api_key = api_key
-    openai.api_type = "open_ai"  # type: ignore
-    openai.api_version = None
-    openai.base_url = "https://api.openai.com/v1/"
-    response = openai.moderations.create(input=input)
+
+    openai_client = kwargs.get("client", None)
+    if openai_client is None:
+        openai_client = openai.OpenAI(
+            api_key=api_key,
+        )
+
+    response = openai_client.moderations.create(input=input, model=model)
+    return response
+
+
+##### Moderation #######################
+@client
+async def amoderation(input: str, model: str, api_key: Optional[str] = None, **kwargs):
+    # only supports open ai for now
+    api_key = (
+        api_key or litellm.api_key or litellm.openai_key or get_secret("OPENAI_API_KEY")
+    )
+    openai_client = kwargs.get("client", None)
+    if openai_client is None:
+        openai_client = openai.AsyncOpenAI(
+            api_key=api_key,
+        )
+    response = await openai_client.moderations.create(input=input, model=model)
     return response
 
 
diff --git a/litellm/router.py b/litellm/router.py
index 21e967576..b64b111a1 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -599,6 +599,98 @@ class Router:
             self.fail_calls[model_name] += 1
             raise e
 
+    async def amoderation(self, model: str, input: str, **kwargs):
+        try:
+            kwargs["model"] = model
+            kwargs["input"] = input
+            kwargs["original_function"] = self._amoderation
+            kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
+            timeout = kwargs.get("request_timeout", self.timeout)
+            kwargs.setdefault("metadata", {}).update({"model_group": model})
+
+            response = await self.async_function_with_fallbacks(**kwargs)
+
+            return response
+        except Exception as e:
+            raise e
+
+    async def _amoderation(self, model: str, input: str, **kwargs):
+        model_name = None
+        try:
+            verbose_router_logger.debug(
+                f"Inside _amoderation()- model: {model}; kwargs: {kwargs}"
+            )
+            deployment = self.get_available_deployment(
+                model=model,
+                input=input,
+                specific_deployment=kwargs.pop("specific_deployment", None),
+            )
+            kwargs.setdefault("metadata", {}).update(
+                {
+                    "deployment": deployment["litellm_params"]["model"],
+                    "model_info": deployment.get("model_info", {}),
+                }
+            )
+            kwargs["model_info"] = deployment.get("model_info", {})
+            data = deployment["litellm_params"].copy()
+            model_name = data["model"]
+            for k, v in self.default_litellm_params.items():
+                if (
+                    k not in kwargs and v is not None
+                ):  # prioritize model-specific params > default router params
+                    kwargs[k] = v
+                elif k == "metadata":
+                    kwargs[k].update(v)
+
+            potential_model_client = self._get_client(
+                deployment=deployment, kwargs=kwargs, client_type="async"
+            )
+            # check if provided keys == client keys #
+            dynamic_api_key = kwargs.get("api_key", None)
+            if (
+                dynamic_api_key is not None
+                and potential_model_client is not None
+                and dynamic_api_key != potential_model_client.api_key
+            ):
+                model_client = None
+            else:
+                model_client = potential_model_client
+            self.total_calls[model_name] += 1
+
+            timeout = (
+                data.get(
+                    "timeout", None
+                )  # timeout set on litellm_params for this deployment
+                or self.timeout  # timeout set on router
+                or kwargs.get(
+                    "timeout", None
+                )  # this uses default_litellm_params when nothing is set
+            )
+
+            response = await litellm.amoderation(
+                **{
+                    **data,
+                    "input": input,
+                    "caching": self.cache_responses,
+                    "client": model_client,
+                    "timeout": timeout,
+                    **kwargs,
+                }
+            )
+
+            self.success_calls[model_name] += 1
+            verbose_router_logger.info(
+                f"litellm.amoderation(model={model_name})\033[32m 200 OK\033[0m"
+            )
+            return response
+        except Exception as e:
+            verbose_router_logger.info(
+                f"litellm.amoderation(model={model_name})\033[31m Exception {str(e)}\033[0m"
+            )
+            if model_name is not None:
+                self.fail_calls[model_name] += 1
+            raise e
+
     def text_completion(
         self,
         model: str,
diff --git a/litellm/tests/test_router.py b/litellm/tests/test_router.py
index b9ca29cee..ab329e14a 100644
--- a/litellm/tests/test_router.py
+++ b/litellm/tests/test_router.py
@@ -991,3 +991,23 @@ def test_router_timeout():
         print(e)
         print(vars(e))
         pass
+
+
+@pytest.mark.asyncio
+async def test_router_amoderation():
+    model_list = [
+        {
+            "model_name": "openai-moderations",
+            "litellm_params": {
+                "model": "text-moderation-stable",
+                "api_key": os.getenv("OPENAI_API_KEY", None),
+            },
+        }
+    ]
+
+    router = Router(model_list=model_list)
+    result = await router.amoderation(
+        model="openai-moderations", input="this is valid good text"
+    )
+
+    print("moderation result", result)
diff --git a/litellm/utils.py b/litellm/utils.py
index e238b84d7..b15be366d 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -738,6 +738,8 @@ class CallTypes(Enum):
     text_completion = "text_completion"
     image_generation = "image_generation"
     aimage_generation = "aimage_generation"
+    moderation = "moderation"
+    amoderation = "amoderation"
 
 
 # Logging function -> log the exact model details + what's being sent | Non-Blocking
@@ -2100,6 +2102,11 @@ def client(original_function):
                 or call_type == CallTypes.aimage_generation.value
             ):
                 messages = args[0] if len(args) > 0 else kwargs["prompt"]
+            elif (
+                call_type == CallTypes.moderation.value
+                or call_type == CallTypes.amoderation.value
+            ):
+                messages = args[0] if len(args) > 0 else kwargs["input"]
             elif (
                 call_type == CallTypes.atext_completion.value
                 or call_type == CallTypes.text_completion.value