diff --git a/litellm/main.py b/litellm/main.py
index 928fc47d9e..de1874b8b8 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -67,6 +67,7 @@ from litellm.litellm_core_utils.prompt_templates.common_utils import (
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.realtime_api.main import _realtime_health_check
 from litellm.secret_managers.main import get_secret_str
+from litellm.types.router import GenericLiteLLMParams
 from litellm.utils import (
     CustomStreamWrapper,
     Usage,
@@ -4314,7 +4315,11 @@ def moderation(
 
 @client
 async def amoderation(
-    input: str, model: Optional[str] = None, api_key: Optional[str] = None, **kwargs
+    input: str,
+    model: Optional[str] = None,
+    api_key: Optional[str] = None,
+    custom_llm_provider: Optional[str] = None,
+    **kwargs,
 ):
     from openai import AsyncOpenAI
 
@@ -4335,6 +4340,20 @@ async def amoderation(
         )
     else:
         _openai_client = openai_client
+
+    optional_params = GenericLiteLLMParams(**kwargs)
+    try:
+        model, _custom_llm_provider, _dynamic_api_key, _dynamic_api_base = (
+            litellm.get_llm_provider(
+                model=model or "",
+                custom_llm_provider=custom_llm_provider,
+                api_base=optional_params.api_base,
+                api_key=optional_params.api_key,
+            )
+        )
+    except litellm.BadRequestError:
+        # `model` is optional field for moderation - get_llm_provider will throw BadRequestError if model is not set / not recognized
+        pass
     if model is not None:
         response = await _openai_client.moderations.create(input=input, model=model)
     else:
@@ -5095,7 +5114,6 @@ def speech(
             aspeech=aspeech,
         )
     elif custom_llm_provider == "vertex_ai" or custom_llm_provider == "vertex_ai_beta":
-        from litellm.types.router import GenericLiteLLMParams
 
         generic_optional_params = GenericLiteLLMParams(**kwargs)
 
diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml
index dd151e36e4..3890624232 100644
--- a/litellm/proxy/proxy_config.yaml
+++ b/litellm/proxy/proxy_config.yaml
@@ -1,34 +1,11 @@
 model_list:
-  - model_name: openai/*
+  - model_name: "*"
     litellm_params:
-      model: openai/*
+      model: "openai/*"
       api_key: os.environ/OPENAI_API_KEY
-  - model_name: anthropic/*
+  - model_name: "openai/*"
     litellm_params:
-      model: anthropic/*
-      api_key: os.environ/ANTHROPIC_API_KEY
-  - model_name: bedrock/*
-    litellm_params:
-      model: bedrock/*
-
-guardrails:
-  - guardrail_name: "bedrock-pre-guard"
-    litellm_params:
-      guardrail: bedrock  # supported values: "aporia", "bedrock", "lakera"
-      mode: "during_call"
-      guardrailIdentifier: ff6ujrregl1q  # your guardrail ID on bedrock
-      guardrailVersion: "DRAFT"  # your guardrail version on bedrock
-      mode: "post_call"
-      guardrailIdentifier: ff6ujrregl1q
-      guardrailVersion: "DRAFT"
-    guardrail_info:
-      params:
-        - name: "toxicity_score"
-          type: "float"
-          description: "Score between 0-1 indicating content toxicity level"
-        - name: "pii_detection"
-          type: "boolean"
-
-
+      model: "openai/*"
+      api_key: os.environ/OPENAI_API_KEY
 litellm_settings:
   callbacks: ["datadog"]
\ No newline at end of file
diff --git a/tests/router_unit_tests/test_router_endpoints.py b/tests/router_unit_tests/test_router_endpoints.py
index e876d37662..98d8f8f90b 100644
--- a/tests/router_unit_tests/test_router_endpoints.py
+++ b/tests/router_unit_tests/test_router_endpoints.py
@@ -228,6 +228,37 @@ async def test_rerank_endpoint(model_list):
     RerankResponse.model_validate(response)
 
 
+@pytest.mark.asyncio()
+@pytest.mark.parametrize(
+    "model", ["omni-moderation-latest", "openai/omni-moderation-latest", None]
+)
+async def test_moderation_endpoint(model):
+    litellm.set_verbose = True
+    router = Router(
+        model_list=[
+            {
+                "model_name": "openai/*",
+                "litellm_params": {
+                    "model": "openai/*",
+                },
+            },
+            {
+                "model_name": "*",
+                "litellm_params": {
+                    "model": "openai/*",
+                },
+            },
+        ]
+    )
+
+    if model is None:
+        response = await router.amoderation(input="hello this is a test")
+    else:
+        response = await router.amoderation(model=model, input="hello this is a test")
+
+    print("moderation response: ", response)
+
+
 @pytest.mark.parametrize("sync_mode", [True, False])
 @pytest.mark.asyncio
 async def test_aaaaatext_completion_endpoint(model_list, sync_mode):