Merge pull request #1974 from BerriAI/litellm_proxy_add_moderations_endpoint
[FEAT] Proxy Add /moderations endpoint
Commit ed8f507536
9 changed files with 387 additions and 12 deletions
@@ -197,7 +197,7 @@ from openai import OpenAI
# set api_key to send to proxy server
client = OpenAI(api_key="<proxy-api-key>", base_url="http://0.0.0.0:8000")

response = openai.embeddings.create(
response = client.embeddings.create(
    input=["hello from litellm"],
    model="text-embedding-ada-002"
)
@@ -281,6 +281,84 @@ print(query_result[:5])

```

## `/moderations`


### Request Format
Input, Output and Exceptions are mapped to the OpenAI format for all supported models

<Tabs>
<TabItem value="openai" label="OpenAI Python v1.0.0+">

```python
import openai
from openai import OpenAI

# set base_url to your proxy server
# set api_key to send to proxy server
client = OpenAI(api_key="<proxy-api-key>", base_url="http://0.0.0.0:8000")

response = client.moderations.create(
    input="hello from litellm",
    model="text-moderation-stable"
)

print(response)

```
</TabItem>
<TabItem value="Curl" label="Curl Request">

```shell
curl --location 'http://0.0.0.0:8000/moderations' \
--header 'Content-Type: application/json' \
--header 'Authorization: Bearer sk-1234' \
--data '{"input": "Sample text goes here", "model": "text-moderation-stable"}'
```
</TabItem>
</Tabs>


### Response Format

```json
{
  "id": "modr-8sFEN22QCziALOfWTa77TodNLgHwA",
  "model": "text-moderation-007",
  "results": [
    {
      "categories": {
        "harassment": false,
        "harassment/threatening": false,
        "hate": false,
        "hate/threatening": false,
        "self-harm": false,
        "self-harm/instructions": false,
        "self-harm/intent": false,
        "sexual": false,
        "sexual/minors": false,
        "violence": false,
        "violence/graphic": false
      },
      "category_scores": {
        "harassment": 0.000019947197870351374,
        "harassment/threatening": 5.5971017900446896e-6,
        "hate": 0.000028560316422954202,
        "hate/threatening": 2.2631787999216613e-8,
        "self-harm": 2.9121162015144364e-7,
        "self-harm/instructions": 9.314219084899378e-8,
        "self-harm/intent": 8.093739012338119e-8,
        "sexual": 0.00004414955765241757,
        "sexual/minors": 0.0000156943697220413,
        "violence": 0.00022354527027346194,
        "violence/graphic": 8.804164281173144e-6
      },
      "flagged": false
    }
  ]
}
```
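Continuing the Python example above, a minimal sketch for acting on the result — it assumes the proxy returns the OpenAI-format body shown here, with field names as exposed by the OpenAI Python v1 client:

```python
# `response` is the object returned by client.moderations.create(...) above
result = response.results[0]

if result.flagged:
    # category_scores holds per-category probabilities between 0 and 1
    print("input was flagged; violence score:", result.category_scores.violence)
else:
    print("input passed moderation")
```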

## Advanced

@@ -2961,16 +2961,39 @@ def text_completion(


##### Moderation #######################
def moderation(input: str, api_key: Optional[str] = None):


def moderation(
    input: str, model: Optional[str] = None, api_key: Optional[str] = None, **kwargs
):
    # only supports open ai for now
    api_key = (
        api_key or litellm.api_key or litellm.openai_key or get_secret("OPENAI_API_KEY")
    )
    openai.api_key = api_key
    openai.api_type = "open_ai"  # type: ignore
    openai.api_version = None
    openai.base_url = "https://api.openai.com/v1/"
    response = openai.moderations.create(input=input)

    openai_client = kwargs.get("client", None)
    if openai_client is None:
        openai_client = openai.OpenAI(
            api_key=api_key,
        )

    response = openai_client.moderations.create(input=input, model=model)
    return response


##### Moderation #######################
@client
async def amoderation(input: str, model: str, api_key: Optional[str] = None, **kwargs):
    # only supports open ai for now
    api_key = (
        api_key or litellm.api_key or litellm.openai_key or get_secret("OPENAI_API_KEY")
    )
    openai_client = kwargs.get("client", None)
    if openai_client is None:
        openai_client = openai.AsyncOpenAI(
            api_key=api_key,
        )
    response = await openai_client.moderations.create(input=input, model=model)
    return response
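For quick reference, these helpers can also be called directly from the SDK, outside the proxy. A minimal sketch, assuming `OPENAI_API_KEY` is set in the environment (moderation only supports OpenAI for now, per the comment above):

```python
import asyncio
import litellm

# sync helper: api_key falls back to litellm.api_key / OPENAI_API_KEY when not passed
response = litellm.moderation(input="hello from litellm", model="text-moderation-stable")
print(response.results[0].flagged)

# async helper added alongside it
async def main():
    response = await litellm.amoderation(
        input="hello from litellm", model="text-moderation-stable"
    )
    print(response.results[0].flagged)

asyncio.run(main())
```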
@@ -32,6 +32,10 @@ model_list:
      api_key: os.environ/AZURE_API_KEY # The `os.environ/` prefix tells litellm to read this from the env. See https://docs.litellm.ai/docs/simple_proxy#load-api-keys-from-vault
    model_info:
      base_model: azure/gpt-4
  - model_name: text-moderation-stable
    litellm_params:
      model: text-moderation-stable
      api_key: os.environ/OPENAI_API_KEY
litellm_settings:
  fallbacks: [{"openai-gpt-3.5": ["azure-gpt-3.5"]}]
  success_callback: ['langfuse']
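With this entry in the config, `text-moderation-stable` becomes a routable model name on the proxy. A minimal sketch of exercising it once the proxy is running — it assumes the proxy listens on port 8000 and that `sk-1234` is a valid proxy key, as in the docs example above:

```python
import requests

# hypothetical local deployment; adjust host, port, and key to your setup
resp = requests.post(
    "http://0.0.0.0:8000/moderations",
    headers={"Authorization": "Bearer sk-1234", "Content-Type": "application/json"},
    json={"input": "Sample text goes here", "model": "text-moderation-stable"},
)
print(resp.json()["results"][0]["flagged"])
```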
@@ -2798,6 +2798,159 @@ async def image_generation(
    )


@router.post(
    "/v1/moderations",
    dependencies=[Depends(user_api_key_auth)],
    response_class=ORJSONResponse,
    tags=["moderations"],
)
@router.post(
    "/moderations",
    dependencies=[Depends(user_api_key_auth)],
    response_class=ORJSONResponse,
    tags=["moderations"],
)
async def moderations(
    request: Request,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    The moderations endpoint is a tool you can use to check whether content complies with an LLM provider's policies.

    Quick Start
    ```
    curl --location 'http://0.0.0.0:4000/moderations' \
    --header 'Content-Type: application/json' \
    --header 'Authorization: Bearer sk-1234' \
    --data '{"input": "Sample text goes here", "model": "text-moderation-stable"}'
    ```
    """
    global proxy_logging_obj
    try:
        # Use orjson to parse JSON data, orjson speeds up requests significantly
        body = await request.body()
        data = orjson.loads(body)

        # Include original request and headers in the data
        data["proxy_server_request"] = {
            "url": str(request.url),
            "method": request.method,
            "headers": dict(request.headers),
            "body": copy.copy(data),  # use copy instead of deepcopy
        }

        if data.get("user", None) is None and user_api_key_dict.user_id is not None:
            data["user"] = user_api_key_dict.user_id

        data["model"] = (
            general_settings.get("moderation_model", None)  # server default
            or user_model  # model name passed via cli args
            or data["model"]  # default passed in http request
        )
        if user_model:
            data["model"] = user_model

        if "metadata" not in data:
            data["metadata"] = {}
        data["metadata"]["user_api_key"] = user_api_key_dict.api_key
        data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata
        _headers = dict(request.headers)
        _headers.pop(
            "authorization", None
        )  # do not store the original `sk-..` api key in the db
        data["metadata"]["headers"] = _headers
        data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
        data["metadata"]["endpoint"] = str(request.url)

        ### TEAM-SPECIFIC PARAMS ###
        if user_api_key_dict.team_id is not None:
            team_config = await proxy_config.load_team_config(
                team_id=user_api_key_dict.team_id
            )
            if len(team_config) == 0:
                pass
            else:
                team_id = team_config.pop("team_id", None)
                data["metadata"]["team_id"] = team_id
                data = {
                    **team_config,
                    **data,
                }  # add the team-specific configs to the completion call

        router_model_names = (
            [m["model_name"] for m in llm_model_list]
            if llm_model_list is not None
            else []
        )

        ### CALL HOOKS ### - modify incoming data / reject request before calling the model
        data = await proxy_logging_obj.pre_call_hook(
            user_api_key_dict=user_api_key_dict, data=data, call_type="moderation"
        )

        start_time = time.time()

        ## ROUTE TO CORRECT ENDPOINT ##
        # skip router if user passed their key
        if "api_key" in data:
            response = await litellm.amoderation(**data)
        elif (
            llm_router is not None and data["model"] in router_model_names
        ):  # model in router model list
            response = await llm_router.amoderation(**data)
        elif (
            llm_router is not None and data["model"] in llm_router.deployment_names
        ):  # model in router deployments, calling a specific deployment on the router
            response = await llm_router.amoderation(**data, specific_deployment=True)
        elif (
            llm_router is not None
            and llm_router.model_group_alias is not None
            and data["model"] in llm_router.model_group_alias
        ):  # model set in model_group_alias
            response = await llm_router.amoderation(
                **data
            )  # ensure this goes to the llm_router, router will do the correct alias mapping
        elif user_model is not None:  # `litellm --model <your-model-name>`
            response = await litellm.amoderation(**data)
        else:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail={"error": "Invalid model name passed in"},
            )

        ### ALERTING ###
        data["litellm_status"] = "success"  # used for alerting
        end_time = time.time()
        asyncio.create_task(
            proxy_logging_obj.response_taking_too_long(
                start_time=start_time, end_time=end_time, type="slow_response"
            )
        )

        return response
    except Exception as e:
        await proxy_logging_obj.post_call_failure_hook(
            user_api_key_dict=user_api_key_dict, original_exception=e
        )
        traceback.print_exc()
        if isinstance(e, HTTPException):
            raise ProxyException(
                message=getattr(e, "message", str(e)),
                type=getattr(e, "type", "None"),
                param=getattr(e, "param", "None"),
                code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
            )
        else:
            error_traceback = traceback.format_exc()
            error_msg = f"{str(e)}\n\n{error_traceback}"
            raise ProxyException(
                message=getattr(e, "message", error_msg),
                type=getattr(e, "type", "None"),
                param=getattr(e, "param", "None"),
                code=getattr(e, "status_code", 500),
            )
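One detail worth calling out in the handler above: the trailing `if user_model:` check means a model passed on the CLI (`litellm --model <name>`) overrides both the `moderation_model` server default and the `model` field in the request body. A standalone sketch of the effective precedence, with names borrowed from the handler for illustration only:

```python
from typing import Optional

def resolve_moderation_model(
    request_model: Optional[str],     # "model" from the request body
    moderation_model: Optional[str],  # general_settings["moderation_model"]
    user_model: Optional[str],        # model passed via `litellm --model ...`
) -> Optional[str]:
    # CLI model wins, then the server-side default, then the request body
    if user_model:
        return user_model
    return moderation_model or request_model
```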


#### KEY MANAGEMENT ####
@@ -93,7 +93,9 @@ class ProxyLogging:
        self,
        user_api_key_dict: UserAPIKeyAuth,
        data: dict,
        call_type: Literal["completion", "embeddings", "image_generation"],
        call_type: Literal[
            "completion", "embeddings", "image_generation", "moderation"
        ],
    ):
        """
        Allows users to modify/reject the incoming request to the proxy, without having to deal with parsing Request body.
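Because `call_type` now includes `"moderation"`, request-shaping logic dispatched from `pre_call_hook` can branch on it. An illustrative sketch only — the function name and the truncation rule are made up for the example, and this is not the real hook plumbing:

```python
from typing import Literal

async def example_pre_call_hook(
    data: dict,
    call_type: Literal["completion", "embeddings", "image_generation", "moderation"],
) -> dict:
    # e.g. cap the size of moderation inputs before they reach the provider
    if call_type == "moderation" and isinstance(data.get("input"), str):
        data["input"] = data["input"][:10_000]
    return data
```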
@@ -599,6 +599,98 @@ class Router:
            self.fail_calls[model_name] += 1
            raise e

    async def amoderation(self, model: str, input: str, **kwargs):
        try:
            kwargs["model"] = model
            kwargs["input"] = input
            kwargs["original_function"] = self._amoderation
            kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
            timeout = kwargs.get("request_timeout", self.timeout)
            kwargs.setdefault("metadata", {}).update({"model_group": model})

            response = await self.async_function_with_fallbacks(**kwargs)

            return response
        except Exception as e:
            raise e

    async def _amoderation(self, model: str, input: str, **kwargs):
        model_name = None
        try:
            verbose_router_logger.debug(
                f"Inside _moderation()- model: {model}; kwargs: {kwargs}"
            )
            deployment = self.get_available_deployment(
                model=model,
                input=input,
                specific_deployment=kwargs.pop("specific_deployment", None),
            )
            kwargs.setdefault("metadata", {}).update(
                {
                    "deployment": deployment["litellm_params"]["model"],
                    "model_info": deployment.get("model_info", {}),
                }
            )
            kwargs["model_info"] = deployment.get("model_info", {})
            data = deployment["litellm_params"].copy()
            model_name = data["model"]
            for k, v in self.default_litellm_params.items():
                if (
                    k not in kwargs and v is not None
                ):  # prioritize model-specific params > default router params
                    kwargs[k] = v
                elif k == "metadata":
                    kwargs[k].update(v)

            potential_model_client = self._get_client(
                deployment=deployment, kwargs=kwargs, client_type="async"
            )
            # check if provided keys == client keys #
            dynamic_api_key = kwargs.get("api_key", None)
            if (
                dynamic_api_key is not None
                and potential_model_client is not None
                and dynamic_api_key != potential_model_client.api_key
            ):
                model_client = None
            else:
                model_client = potential_model_client
            self.total_calls[model_name] += 1

            timeout = (
                data.get(
                    "timeout", None
                )  # timeout set on litellm_params for this deployment
                or self.timeout  # timeout set on router
                or kwargs.get(
                    "timeout", None
                )  # this uses default_litellm_params when nothing is set
            )

            response = await litellm.amoderation(
                **{
                    **data,
                    "input": input,
                    "caching": self.cache_responses,
                    "client": model_client,
                    "timeout": timeout,
                    **kwargs,
                }
            )

            self.success_calls[model_name] += 1
            verbose_router_logger.info(
                f"litellm.amoderation(model={model_name})\033[32m 200 OK\033[0m"
            )
            return response
        except Exception as e:
            verbose_router_logger.info(
                f"litellm.amoderation(model={model_name})\033[31m Exception {str(e)}\033[0m"
            )
            if model_name is not None:
                self.fail_calls[model_name] += 1
            raise e

    def text_completion(
        self,
        model: str,
@@ -2093,10 +2093,6 @@ def test_completion_cloudflare():


def test_moderation():
    import openai

    openai.api_type = "azure"
    openai.api_version = "GM"
    response = litellm.moderation(input="i'm ishaan cto of litellm")
    print(response)
    output = response.results[0]
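Two follow-up assertions could make this test stricter; a possible continuation of `test_moderation`, not part of the commit, assuming the benign sample input should not be flagged:

```python
    # hypothetical extra checks on the first moderation result
    assert output.flagged is False
    assert isinstance(output.categories.violence, bool)
```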
@@ -991,3 +991,23 @@ def test_router_timeout():
        print(e)
        print(vars(e))
        pass


@pytest.mark.asyncio
async def test_router_amoderation():
    model_list = [
        {
            "model_name": "openai-moderations",
            "litellm_params": {
                "model": "text-moderation-stable",
                "api_key": os.getenv("OPENAI_API_KEY", None),
            },
        }
    ]

    router = Router(model_list=model_list)
    result = await router.amoderation(
        model="openai-moderations", input="this is valid good text"
    )

    print("moderation result", result)
@@ -738,6 +738,8 @@ class CallTypes(Enum):
    text_completion = "text_completion"
    image_generation = "image_generation"
    aimage_generation = "aimage_generation"
    moderation = "moderation"
    amoderation = "amoderation"


# Logging function -> log the exact model details + what's being sent | Non-Blocking
@@ -2100,6 +2102,11 @@ def client(original_function):
            or call_type == CallTypes.aimage_generation.value
        ):
            messages = args[0] if len(args) > 0 else kwargs["prompt"]
        elif (
            call_type == CallTypes.moderation.value
            or call_type == CallTypes.amoderation.value
        ):
            messages = args[1] if len(args) > 1 else kwargs["input"]
        elif (
            call_type == CallTypes.atext_completion.value
            or call_type == CallTypes.text_completion.value