forked from phoenix/litellm-mirror
(feat) add moderation on router
parent e590f47a44
commit 693efc8e84
4 changed files with 145 additions and 6 deletions
@@ -2962,16 +2962,36 @@ def text_completion(


 ##### Moderation #######################
-def moderation(input: str, api_key: Optional[str] = None):
+@client
+def moderation(input: str, model: str, api_key: Optional[str] = None, **kwargs):
     # only supports open ai for now
     api_key = (
         api_key or litellm.api_key or litellm.openai_key or get_secret("OPENAI_API_KEY")
     )
-    openai.api_key = api_key
-    openai.api_type = "open_ai"  # type: ignore
-    openai.api_version = None
-    openai.base_url = "https://api.openai.com/v1/"
-    response = openai.moderations.create(input=input)
+    openai_client = kwargs.get("client", None)
+    if openai_client is None:
+        openai_client = openai.OpenAI(
+            api_key=api_key,
+        )
+
+    response = openai_client.moderations.create(input=input, model=model)
+    return response
+
+
+##### Moderation #######################
+@client
+async def amoderation(input: str, model: str, api_key: Optional[str] = None, **kwargs):
+    # only supports open ai for now
+    api_key = (
+        api_key or litellm.api_key or litellm.openai_key or get_secret("OPENAI_API_KEY")
+    )
+    openai_client = kwargs.get("client", None)
+    if openai_client is None:
+        openai_client = openai.AsyncOpenAI(
+            api_key=api_key,
+        )
+    response = await openai_client.moderations.create(input=input, model=model)
     return response


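For reference, a minimal usage sketch of the two new module-level helpers. This is hedged: it assumes OPENAI_API_KEY is set in the environment, uses "text-moderation-stable" only as an illustrative model name, and the .results[0].flagged access assumes the OpenAI SDK moderation response object, which these functions return unchanged.

    import asyncio
    import litellm

    # sync path: model is now a required argument
    report = litellm.moderation(
        input="this is some text to screen", model="text-moderation-stable"
    )
    print(report.results[0].flagged)

    # async path
    async def main():
        report = await litellm.amoderation(
            input="this is some text to screen", model="text-moderation-stable"
        )
        print(report.results[0].flagged)

    asyncio.run(main())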
@@ -599,6 +599,98 @@ class Router:
             self.fail_calls[model_name] += 1
             raise e

+    async def amoderation(self, model: str, input: str, **kwargs):
+        try:
+            kwargs["model"] = model
+            kwargs["input"] = input
+            kwargs["original_function"] = self._amoderation
+            kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
+            timeout = kwargs.get("request_timeout", self.timeout)
+            kwargs.setdefault("metadata", {}).update({"model_group": model})
+
+            response = await self.async_function_with_fallbacks(**kwargs)
+
+            return response
+        except Exception as e:
+            raise e
+
+    async def _amoderation(self, model: str, input: str, **kwargs):
+        model_name = None
+        try:
+            verbose_router_logger.debug(
+                f"Inside _moderation()- model: {model}; kwargs: {kwargs}"
+            )
+            deployment = self.get_available_deployment(
+                model=model,
+                input=input,
+                specific_deployment=kwargs.pop("specific_deployment", None),
+            )
+            kwargs.setdefault("metadata", {}).update(
+                {
+                    "deployment": deployment["litellm_params"]["model"],
+                    "model_info": deployment.get("model_info", {}),
+                }
+            )
+            kwargs["model_info"] = deployment.get("model_info", {})
+            data = deployment["litellm_params"].copy()
+            model_name = data["model"]
+            for k, v in self.default_litellm_params.items():
+                if (
+                    k not in kwargs and v is not None
+                ):  # prioritize model-specific params > default router params
+                    kwargs[k] = v
+                elif k == "metadata":
+                    kwargs[k].update(v)
+
+            potential_model_client = self._get_client(
+                deployment=deployment, kwargs=kwargs, client_type="async"
+            )
+            # check if provided keys == client keys #
+            dynamic_api_key = kwargs.get("api_key", None)
+            if (
+                dynamic_api_key is not None
+                and potential_model_client is not None
+                and dynamic_api_key != potential_model_client.api_key
+            ):
+                model_client = None
+            else:
+                model_client = potential_model_client
+            self.total_calls[model_name] += 1
+
+            timeout = (
+                data.get(
+                    "timeout", None
+                )  # timeout set on litellm_params for this deployment
+                or self.timeout  # timeout set on router
+                or kwargs.get(
+                    "timeout", None
+                )  # this uses default_litellm_params when nothing is set
+            )
+
+            response = await litellm.amoderation(
+                **{
+                    **data,
+                    "input": input,
+                    "caching": self.cache_responses,
+                    "client": model_client,
+                    "timeout": timeout,
+                    **kwargs,
+                }
+            )
+
+            self.success_calls[model_name] += 1
+            verbose_router_logger.info(
+                f"litellm.amoderation(model={model_name})\033[32m 200 OK\033[0m"
+            )
+            return response
+        except Exception as e:
+            verbose_router_logger.info(
+                f"litellm.amoderation(model={model_name})\033[31m Exception {str(e)}\033[0m"
+            )
+            if model_name is not None:
+                self.fail_calls[model_name] += 1
+            raise e
+
     def text_completion(
         self,
         model: str,
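A hedged sketch of driving the new router method from application code outside pytest. The deployment entry mirrors the test added below; num_retries is optional and falls back to the router-level value, per kwargs.get("num_retries", self.num_retries) above.

    import asyncio
    import os

    from litellm import Router

    router = Router(
        model_list=[
            {
                "model_name": "openai-moderations",
                "litellm_params": {
                    "model": "text-moderation-stable",
                    "api_key": os.getenv("OPENAI_API_KEY"),
                },
            }
        ]
    )

    async def screen(text: str):
        # routed through async_function_with_fallbacks, so retries/fallbacks apply
        return await router.amoderation(
            model="openai-moderations", input=text, num_retries=2
        )

    print(asyncio.run(screen("this is valid good text")))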
@@ -991,3 +991,23 @@ def test_router_timeout():
         print(e)
         print(vars(e))
         pass
+
+
+@pytest.mark.asyncio
+async def test_router_amoderation():
+    model_list = [
+        {
+            "model_name": "openai-moderations",
+            "litellm_params": {
+                "model": "text-moderation-stable",
+                "api_key": os.getenv("OPENAI_API_KEY", None),
+            },
+        }
+    ]
+
+    router = Router(model_list=model_list)
+    result = await router.amoderation(
+        model="openai-moderations", input="this is valid good text"
+    )
+
+    print("moderation result", result)
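The test only prints the response. If an assertion is wanted, a hedged follow-up (assuming the OpenAI moderation response shape and pytest-asyncio installed so the async test runs):

    # e.g. run with: pytest -s -k test_router_amoderation
    assert result.results[0].flagged is False  # benign input should not be flagged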
@@ -738,6 +738,8 @@ class CallTypes(Enum):
     text_completion = "text_completion"
     image_generation = "image_generation"
     aimage_generation = "aimage_generation"
+    moderation = "moderation"
+    amoderation = "amoderation"


 # Logging function -> log the exact model details + what's being sent | Non-Blocking
@@ -2100,6 +2102,11 @@ def client(original_function):
                 or call_type == CallTypes.aimage_generation.value
             ):
                 messages = args[0] if len(args) > 0 else kwargs["prompt"]
+            elif (
+                call_type == CallTypes.moderation.value
+                or call_type == CallTypes.amoderation.value
+            ):
+                messages = args[1] if len(args) > 1 else kwargs["input"]
             elif (
                 call_type == CallTypes.atext_completion.value
                 or call_type == CallTypes.text_completion.value
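For the two new call types, the pre-call setup above records the moderation payload from kwargs["input"] (or args[1] on positional calls) instead of chat messages. A hedged standalone sketch of that dispatch; the helper name is illustrative and not part of the diff:

    def _logged_payload(call_type: str, args: tuple, kwargs: dict):
        # mirrors the elif chain in the client() wrapper: moderation calls log their input text
        if call_type in ("moderation", "amoderation"):
            return args[1] if len(args) > 1 else kwargs["input"]
        if call_type in ("image_generation", "aimage_generation"):
            return args[0] if len(args) > 0 else kwargs["prompt"]
        return kwargs.get("messages")

    print(_logged_payload("amoderation", (), {"input": "some text"}))  # -> "some text"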