Merge pull request #4342 from BerriAI/litellm_fix_moderations

fix - LiteLLM proxy /moderations endpoint returns a 500 error when no model is specified
Ishaan Jaff 2024-06-21 16:19:22 -07:00 committed by GitHub
commit c0540e764d
4 changed files with 57 additions and 15 deletions
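With this fix, a /moderations request to the proxy no longer needs a "model" field. A minimal sketch of the request shape this enables, assuming a local proxy on port 4000 and the `sk-1234` test key used in the test added by this PR:

```python
# Minimal sketch: POST /moderations with no "model" field.
# Assumes a LiteLLM proxy running locally on port 4000 and the
# "sk-1234" test key used elsewhere in this PR.
import asyncio

import aiohttp


async def main():
    async with aiohttp.ClientSession() as session:
        async with session.post(
            "http://0.0.0.0:4000/moderations",
            headers={"Authorization": "Bearer sk-1234"},
            json={"input": "I want to kill the cat."},  # note: no "model" key
        ) as resp:
            print(resp.status)
            print(await resp.json())


asyncio.run(main())
```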


@@ -600,7 +600,7 @@ class OpenAIChatCompletion(BaseLLM):
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
         timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
-        max_retries: Optional[int] = None,
+        max_retries: Optional[int] = 2,
         organization: Optional[str] = None,
         client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
     ):


@@ -3852,14 +3852,20 @@ def moderation(
 @client
-async def amoderation(input: str, model: str, api_key: Optional[str] = None, **kwargs):
+async def amoderation(
+    input: str, model: Optional[str] = None, api_key: Optional[str] = None, **kwargs
+):
     # only supports open ai for now
     api_key = (
         api_key or litellm.api_key or litellm.openai_key or get_secret("OPENAI_API_KEY")
     )
     openai_client = kwargs.get("client", None)
     if openai_client is None:
-        openai_client = openai.AsyncOpenAI(
+        # call helper to get OpenAI client
+        # _get_openai_client maintains in-memory caching logic for OpenAI clients
+        openai_client = openai_chat_completions._get_openai_client(
+            is_async=True,
             api_key=api_key,
         )
     response = await openai_client.moderations.create(input=input, model=model)
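The new comments refer to `_get_openai_client`, which the diff describes as maintaining in-memory caching for OpenAI clients. A rough sketch of that caching pattern, with illustrative names and parameters (this is an assumption, not LiteLLM's actual helper):

```python
# Illustrative sketch of an in-memory client cache keyed on connection
# settings, in the spirit of the `_get_openai_client` helper referenced
# above. Parameter list and cache key are assumptions, not LiteLLM's code.
from typing import Optional, Union

import openai

_client_cache: dict = {}


def get_cached_openai_client(
    is_async: bool,
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
    timeout: float = 600.0,
    max_retries: int = 2,
) -> Union[openai.OpenAI, openai.AsyncOpenAI]:
    # Reuse a client when the same settings were seen before; otherwise
    # construct a new one and remember it.
    cache_key = (is_async, api_key, api_base, timeout, max_retries)
    if cache_key not in _client_cache:
        client_cls = openai.AsyncOpenAI if is_async else openai.OpenAI
        _client_cache[cache_key] = client_cls(
            api_key=api_key,
            base_url=api_base,
            timeout=timeout,
            max_retries=max_retries,
        )
    return _client_cache[cache_key]
```

This also gives some context for the max_retries default change in the first hunk: presumably, a concrete integer default lets the helper always hand the OpenAI client constructor a usable value rather than None.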


@@ -4947,7 +4947,7 @@ async def moderations(
         data["model"] = (
             general_settings.get("moderation_model", None)  # server default
             or user_model  # model name passed via cli args
-            or data["model"]  # default passed in http request
+            or data.get("model")  # default passed in http request
         )
         if user_model:
             data["model"] = user_model
@@ -4966,37 +4966,33 @@ async def moderations(
         if "api_key" in data:
             response = await litellm.amoderation(**data)
         elif (
-            llm_router is not None and data["model"] in router_model_names
+            llm_router is not None and data.get("model") in router_model_names
         ):  # model in router model list
             response = await llm_router.amoderation(**data)
         elif (
-            llm_router is not None and data["model"] in llm_router.deployment_names
+            llm_router is not None and data.get("model") in llm_router.deployment_names
         ):  # model in router deployments, calling a specific deployment on the router
             response = await llm_router.amoderation(**data, specific_deployment=True)
         elif (
             llm_router is not None
             and llm_router.model_group_alias is not None
-            and data["model"] in llm_router.model_group_alias
+            and data.get("model") in llm_router.model_group_alias
         ):  # model set in model_group_alias
             response = await llm_router.amoderation(
                 **data
             )  # ensure this goes the llm_router, router will do the correct alias mapping
         elif (
             llm_router is not None
-            and data["model"] not in router_model_names
+            and data.get("model") not in router_model_names
             and llm_router.default_deployment is not None
         ):  # model in router deployments, calling a specific deployment on the router
             response = await llm_router.amoderation(**data)
         elif user_model is not None:  # `litellm --model <your-model-name>`
             response = await litellm.amoderation(**data)
         else:
-            raise HTTPException(
-                status_code=status.HTTP_400_BAD_REQUEST,
-                detail={
-                    "error": "moderations: Invalid model name passed in model="
-                    + data.get("model", "")
-                },
-            )
+            # /moderations does not need a "model" passed
+            # see https://platform.openai.com/docs/api-reference/moderations
+            response = await litellm.amoderation(**data)
         ### ALERTING ###
         data["litellm_status"] = "success"  # used for alerting


@@ -73,6 +73,27 @@ async def new_user(session):
         return await response.json()
 
 
+async def moderation(session, key):
+    url = "http://0.0.0.0:4000/moderations"
+    headers = {
+        "Authorization": f"Bearer {key}",
+        "Content-Type": "application/json",
+    }
+    data = {"input": "I want to kill the cat."}
+
+    async with session.post(url, headers=headers, json=data) as response:
+        status = response.status
+        response_text = await response.text()
+
+        print(response_text)
+        print()
+
+        if status != 200:
+            raise Exception(f"Request did not return a 200 status code: {status}")
+
+        return await response.json()
+
+
 async def chat_completion(session, key, model: Union[str, List] = "gpt-4"):
     url = "http://0.0.0.0:4000/chat/completions"
     headers = {
@@ -465,3 +486,22 @@ async def test_batch_chat_completions():
 
     assert len(response) == 2
     assert isinstance(response, list)
+
+
+@pytest.mark.asyncio
+async def test_moderations_endpoint():
+    """
+    - Make a /moderations request through the proxy with no model specified
+    """
+    async with aiohttp.ClientSession() as session:
+        # call /moderations with no model in the request body and no
+        # moderation_model set on the config.yaml
+        response = await moderation(
+            session=session,
+            key="sk-1234",
+        )
+
+        print(f"response: {response}")
+
+        assert "results" in response