Merge pull request #1381 from BerriAI/litellm_content_policy_violation_exception

[Feat] Add litellm.ContentPolicyViolationError
Ishaan Jaff 2024-01-09 17:18:29 +05:30 committed by GitHub
commit 4cfa010dbd
7 changed files with 252 additions and 162 deletions

View file

@@ -12,6 +12,7 @@ LiteLLM maps exceptions across all providers to their OpenAI counterparts.
| 429 | RateLimitError |
| >=500 | InternalServerError |
| N/A | ContextWindowExceededError|
| 400 | ContentPolicyViolationError|
| N/A | APIConnectionError |
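For reference, a minimal sketch of how a caller could handle the new 400-level exception. The model name and prompt here are illustrative, and an OpenAI API key is assumed to be configured:

```python
import litellm

# Sketch only: the prompt and model are placeholders.
try:
    response = litellm.image_generation(
        prompt="a prompt the provider's safety system rejects",
        model="dall-e-3",
    )
    print(response.data)
except litellm.ContentPolicyViolationError as e:
    # 400-level error; retrying the same prompt is unlikely to help,
    # so surface it to the user or rewrite the prompt instead.
    print(f"Content policy violation: {e.message}")
```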

View file

@@ -543,6 +543,7 @@ from .exceptions import (
ServiceUnavailableError,
OpenAIError,
ContextWindowExceededError,
ContentPolicyViolationError,
BudgetExceededError,
APIError,
Timeout,

View file

@@ -108,6 +108,21 @@ class ContextWindowExceededError(BadRequestError): # type: ignore
) # Call the base class constructor with the parameters it needs
class ContentPolicyViolationError(BadRequestError): # type: ignore
# Error code: 400 - {'error': {'code': 'content_policy_violation', 'message': 'Your request was rejected as a result of our safety system. Image descriptions generated from your prompt may contain text that is not allowed by our safety system. If you believe this was done in error, your request may succeed if retried, or by adjusting your prompt.', 'param': None, 'type': 'invalid_request_error'}}
def __init__(self, message, model, llm_provider, response: httpx.Response):
self.status_code = 400
self.message = message
self.model = model
self.llm_provider = llm_provider
super().__init__(
message=self.message,
model=self.model, # type: ignore
llm_provider=self.llm_provider, # type: ignore
response=response,
) # Call the base class constructor with the parameters it needs
class ServiceUnavailableError(APIStatusError): # type: ignore
def __init__(self, message, llm_provider, model, response: httpx.Response):
self.status_code = 503
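As a rough illustration of the class added above, the sketch below constructs the error directly, with a synthetic httpx.Response standing in for the provider reply. In the library itself exception_type() raises it, but the constructor arguments follow the signature shown in this hunk:

```python
import httpx
import litellm

# Synthetic stand-in for the provider's HTTP reply (illustrative only).
fake_response = httpx.Response(
    400,
    request=httpx.Request("POST", "https://api.openai.com/v1/images/generations"),
)

err = litellm.ContentPolicyViolationError(
    message="OpenAIException - request rejected by the safety system",
    model="dall-e-3",
    llm_provider="openai",
    response=fake_response,
)

assert err.status_code == 400                    # hard-coded in __init__
assert isinstance(err, litellm.BadRequestError)  # inherits the 400 mapping
```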

View file

@@ -1117,7 +1117,7 @@ def completion(
acompletion=acompletion,
logging_obj=logging,
custom_prompt_dict=custom_prompt_dict,
timeout=timeout
timeout=timeout,
)
if (
"stream" in optional_params
@@ -2838,158 +2838,167 @@ def image_generation(
Currently supports just Azure + OpenAI.
"""
aimg_generation = kwargs.get("aimg_generation", False)
litellm_call_id = kwargs.get("litellm_call_id", None)
logger_fn = kwargs.get("logger_fn", None)
proxy_server_request = kwargs.get("proxy_server_request", None)
model_info = kwargs.get("model_info", None)
metadata = kwargs.get("metadata", {})
try:
aimg_generation = kwargs.get("aimg_generation", False)
litellm_call_id = kwargs.get("litellm_call_id", None)
logger_fn = kwargs.get("logger_fn", None)
proxy_server_request = kwargs.get("proxy_server_request", None)
model_info = kwargs.get("model_info", None)
metadata = kwargs.get("metadata", {})
model_response = litellm.utils.ImageResponse()
if model is not None or custom_llm_provider is not None:
model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base) # type: ignore
else:
model = "dall-e-2"
custom_llm_provider = "openai" # default to dall-e-2 on openai
openai_params = [
"user",
"request_timeout",
"api_base",
"api_version",
"api_key",
"deployment_id",
"organization",
"base_url",
"default_headers",
"timeout",
"max_retries",
"n",
"quality",
"size",
"style",
]
litellm_params = [
"metadata",
"aimg_generation",
"caching",
"mock_response",
"api_key",
"api_version",
"api_base",
"force_timeout",
"logger_fn",
"verbose",
"custom_llm_provider",
"litellm_logging_obj",
"litellm_call_id",
"use_client",
"id",
"fallbacks",
"azure",
"headers",
"model_list",
"num_retries",
"context_window_fallback_dict",
"roles",
"final_prompt_value",
"bos_token",
"eos_token",
"request_timeout",
"complete_response",
"self",
"client",
"rpm",
"tpm",
"input_cost_per_token",
"output_cost_per_token",
"hf_model_name",
"proxy_server_request",
"model_info",
"preset_cache_key",
"caching_groups",
"ttl",
"cache",
]
default_params = openai_params + litellm_params
non_default_params = {
k: v for k, v in kwargs.items() if k not in default_params
} # model-specific params - pass them straight to the model/provider
optional_params = get_optional_params_image_gen(
n=n,
quality=quality,
response_format=response_format,
size=size,
style=style,
user=user,
custom_llm_provider=custom_llm_provider,
**non_default_params,
)
logging = litellm_logging_obj
logging.update_environment_variables(
model=model,
user=user,
optional_params=optional_params,
litellm_params={
"timeout": timeout,
"azure": False,
"litellm_call_id": litellm_call_id,
"logger_fn": logger_fn,
"proxy_server_request": proxy_server_request,
"model_info": model_info,
"metadata": metadata,
"preset_cache_key": None,
"stream_response": {},
},
)
if custom_llm_provider == "azure":
# azure configs
api_type = get_secret("AZURE_API_TYPE") or "azure"
api_base = api_base or litellm.api_base or get_secret("AZURE_API_BASE")
api_version = (
api_version or litellm.api_version or get_secret("AZURE_API_VERSION")
model_response = litellm.utils.ImageResponse()
if model is not None or custom_llm_provider is not None:
model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base) # type: ignore
else:
model = "dall-e-2"
custom_llm_provider = "openai" # default to dall-e-2 on openai
openai_params = [
"user",
"request_timeout",
"api_base",
"api_version",
"api_key",
"deployment_id",
"organization",
"base_url",
"default_headers",
"timeout",
"max_retries",
"n",
"quality",
"size",
"style",
]
litellm_params = [
"metadata",
"aimg_generation",
"caching",
"mock_response",
"api_key",
"api_version",
"api_base",
"force_timeout",
"logger_fn",
"verbose",
"custom_llm_provider",
"litellm_logging_obj",
"litellm_call_id",
"use_client",
"id",
"fallbacks",
"azure",
"headers",
"model_list",
"num_retries",
"context_window_fallback_dict",
"roles",
"final_prompt_value",
"bos_token",
"eos_token",
"request_timeout",
"complete_response",
"self",
"client",
"rpm",
"tpm",
"input_cost_per_token",
"output_cost_per_token",
"hf_model_name",
"proxy_server_request",
"model_info",
"preset_cache_key",
"caching_groups",
"ttl",
"cache",
]
default_params = openai_params + litellm_params
non_default_params = {
k: v for k, v in kwargs.items() if k not in default_params
} # model-specific params - pass them straight to the model/provider
optional_params = get_optional_params_image_gen(
n=n,
quality=quality,
response_format=response_format,
size=size,
style=style,
user=user,
custom_llm_provider=custom_llm_provider,
**non_default_params,
)
api_key = (
api_key
or litellm.api_key
or litellm.azure_key
or get_secret("AZURE_OPENAI_API_KEY")
or get_secret("AZURE_API_KEY")
)
azure_ad_token = optional_params.pop("azure_ad_token", None) or get_secret(
"AZURE_AD_TOKEN"
)
model_response = azure_chat_completions.image_generation(
logging = litellm_logging_obj
logging.update_environment_variables(
model=model,
prompt=prompt,
timeout=timeout,
api_key=api_key,
api_base=api_base,
logging_obj=litellm_logging_obj,
user=user,
optional_params=optional_params,
model_response=model_response,
api_version=api_version,
aimg_generation=aimg_generation,
)
elif custom_llm_provider == "openai":
model_response = openai_chat_completions.image_generation(
model=model,
prompt=prompt,
timeout=timeout,
api_key=api_key,
api_base=api_base,
logging_obj=litellm_logging_obj,
optional_params=optional_params,
model_response=model_response,
aimg_generation=aimg_generation,
litellm_params={
"timeout": timeout,
"azure": False,
"litellm_call_id": litellm_call_id,
"logger_fn": logger_fn,
"proxy_server_request": proxy_server_request,
"model_info": model_info,
"metadata": metadata,
"preset_cache_key": None,
"stream_response": {},
},
)
return model_response
if custom_llm_provider == "azure":
# azure configs
api_type = get_secret("AZURE_API_TYPE") or "azure"
api_base = api_base or litellm.api_base or get_secret("AZURE_API_BASE")
api_version = (
api_version or litellm.api_version or get_secret("AZURE_API_VERSION")
)
api_key = (
api_key
or litellm.api_key
or litellm.azure_key
or get_secret("AZURE_OPENAI_API_KEY")
or get_secret("AZURE_API_KEY")
)
azure_ad_token = optional_params.pop("azure_ad_token", None) or get_secret(
"AZURE_AD_TOKEN"
)
model_response = azure_chat_completions.image_generation(
model=model,
prompt=prompt,
timeout=timeout,
api_key=api_key,
api_base=api_base,
logging_obj=litellm_logging_obj,
optional_params=optional_params,
model_response=model_response,
api_version=api_version,
aimg_generation=aimg_generation,
)
elif custom_llm_provider == "openai":
model_response = openai_chat_completions.image_generation(
model=model,
prompt=prompt,
timeout=timeout,
api_key=api_key,
api_base=api_base,
logging_obj=litellm_logging_obj,
optional_params=optional_params,
model_response=model_response,
aimg_generation=aimg_generation,
)
return model_response
except Exception as e:
## Map to OpenAI Exception
raise exception_type(
model=model,
custom_llm_provider=custom_llm_provider,
original_exception=e,
completion_kwargs=locals(),
)
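The practical effect of wrapping the whole body in try/except and routing errors through exception_type() is that callers of image_generation() now see mapped litellm exceptions (including the new ContentPolicyViolationError) rather than raw provider SDK errors. A hedged sketch of how a caller might react, with a purely illustrative prompt-softening fallback:

```python
import litellm

def generate_or_fallback(prompt: str):
    """Illustrative only: soften the prompt on a policy violation."""
    try:
        return litellm.image_generation(prompt=prompt, model="dall-e-3")
    except litellm.ContentPolicyViolationError:
        # Not retryable as-is; a rewritten prompt (hypothetical policy) may pass.
        return litellm.image_generation(
            prompt=f"A family-friendly illustration of: {prompt}",
            model="dall-e-2",
        )
    except litellm.RateLimitError:
        # Retryable; a real caller would back off before trying again.
        raise
```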
##### Health Endpoints #######################
@@ -3114,7 +3123,8 @@ def config_completion(**kwargs):
"No config path set, please set a config path using `litellm.config_path = 'path/to/config.json'`"
)
def stream_chunk_builder_text_completion(chunks: list, messages: Optional[List]=None):
def stream_chunk_builder_text_completion(chunks: list, messages: Optional[List] = None):
id = chunks[0]["id"]
object = chunks[0]["object"]
created = chunks[0]["created"]
@@ -3131,23 +3141,27 @@ def stream_chunk_builder_text_completion(chunks: list, messages: Optional[List]=
"system_fingerprint": system_fingerprint,
"choices": [
{
"text": None,
"index": 0,
"logprobs": logprobs,
"finish_reason": finish_reason
"text": None,
"index": 0,
"logprobs": logprobs,
"finish_reason": finish_reason,
}
],
"usage": {
"prompt_tokens": None,
"completion_tokens": None,
"total_tokens": None
}
"total_tokens": None,
},
}
content_list = []
for chunk in chunks:
choices = chunk["choices"]
for choice in choices:
if choice is not None and hasattr(choice, "text") and choice.get("text") is not None:
if (
choice is not None
and hasattr(choice, "text")
and choice.get("text") is not None
):
_choice = choice.get("text")
content_list.append(_choice)
@@ -3179,13 +3193,16 @@ def stream_chunk_builder_text_completion(chunks: list, messages: Optional[List]=
)
return response
def stream_chunk_builder(chunks: list, messages: Optional[list] = None):
id = chunks[0]["id"]
object = chunks[0]["object"]
created = chunks[0]["created"]
model = chunks[0]["model"]
system_fingerprint = chunks[0].get("system_fingerprint", None)
if isinstance(chunks[0]["choices"][0], litellm.utils.TextChoices): # route to the text completion logic
if isinstance(
chunks[0]["choices"][0], litellm.utils.TextChoices
): # route to the text completion logic
return stream_chunk_builder_text_completion(chunks=chunks, messages=messages)
role = chunks[0]["choices"][0]["delta"]["role"]
finish_reason = chunks[-1]["choices"][0]["finish_reason"]
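Since this hunk mostly reformats the stream-chunk builders, a short usage sketch may help place them: chunks collected from a streaming call can be recombined into a single response with stream_chunk_builder(). The model and prompt below are illustrative and an OpenAI key is assumed:

```python
import litellm

messages = [{"role": "user", "content": "Say hi in one word."}]

# Collect the streamed deltas...
chunks = []
for chunk in litellm.completion(model="gpt-3.5-turbo", messages=messages, stream=True):
    chunks.append(chunk)

# ...then rebuild a single non-streaming style response.
full_response = litellm.stream_chunk_builder(chunks, messages=messages)
print(full_response.choices[0].message.content)
```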

View file

@@ -352,6 +352,25 @@ def test_completion_mistral_exception():
# test_completion_mistral_exception()
def test_content_policy_exceptionimage_generation_openai():
try:
# this is only a test - we needed some way to invoke the exception :(
litellm.set_verbose = True
response = litellm.image_generation(
prompt="where do i buy lethal drugs from", model="dall-e-3"
)
print(f"response: {response}")
assert len(response.data) > 0
except litellm.ContentPolicyViolationError as e:
print("caught a content policy violation error! Passed")
pass
except Exception as e:
pytest.fail(f"An exception occurred - {str(e)}")
# test_content_policy_exceptionimage_generation_openai()
# # test_invalid_request_error(model="command-nightly")
# # Test 3: Rate Limit Errors
# def test_model_call(model):

View file

@@ -28,6 +28,8 @@ def test_image_generation_openai():
assert len(response.data) > 0
except litellm.RateLimitError as e:
pass
except litellm.ContentPolicyViolationError:
pass # OpenAI randomly raises these errors - skip when they occur
except Exception as e:
pytest.fail(f"An exception occurred - {str(e)}")
@@ -38,15 +40,20 @@ def test_image_generation_openai():
def test_image_generation_azure():
try:
response = litellm.image_generation(
prompt="A cute baby sea otter", model="azure/", api_version="2023-06-01-preview"
prompt="A cute baby sea otter",
model="azure/",
api_version="2023-06-01-preview",
)
print(f"response: {response}")
assert len(response.data) > 0
except litellm.RateLimitError as e:
pass
except litellm.ContentPolicyViolationError:
pass # Azure randomly raises these errors - skip when they occur
except Exception as e:
pytest.fail(f"An exception occurred - {str(e)}")
# test_image_generation_azure()
@@ -64,6 +71,8 @@ def test_image_generation_azure_dall_e_3():
assert len(response.data) > 0
except litellm.RateLimitError as e:
pass
except litellm.ContentPolicyViolationError:
pass # OpenAI randomly raises these errors - skip when they occur
except Exception as e:
pytest.fail(f"An exception occurred - {str(e)}")
@@ -79,9 +88,12 @@ async def test_async_image_generation_openai():
assert len(response.data) > 0
except litellm.RateLimitError as e:
pass
except litellm.ContentPolicyViolationError:
pass # openai randomly raises these errors - skip when they occur
except Exception as e:
pytest.fail(f"An exception occurred - {str(e)}")
# asyncio.run(test_async_image_generation_openai())
@@ -94,5 +106,7 @@ async def test_async_image_generation_azure():
print(f"response: {response}")
except litellm.RateLimitError as e:
pass
except litellm.ContentPolicyViolationError:
pass # Azure randomly raises these errors - skip when they occur
except Exception as e:
pytest.fail(f"An exception occurred - {str(e)}")

View file

@@ -59,6 +59,7 @@ from .exceptions import (
ServiceUnavailableError,
OpenAIError,
ContextWindowExceededError,
ContentPolicyViolationError,
Timeout,
APIConnectionError,
APIError,
@@ -5548,6 +5549,17 @@ def exception_type(
model=model,
response=original_exception.response,
)
elif (
"invalid_request_error" in error_str
and "content_policy_violation" in error_str
):
exception_mapping_worked = True
raise ContentPolicyViolationError(
message=f"OpenAIException - {original_exception.message}",
llm_provider="openai",
model=model,
response=original_exception.response,
)
elif (
"invalid_request_error" in error_str
and "Incorrect API key provided" not in error_str
@@ -6497,6 +6509,17 @@ def exception_type(
model=model,
response=original_exception.response,
)
elif (
"invalid_request_error" in error_str
and "content_policy_violation" in error_str
):
exception_mapping_worked = True
raise ContentPolicyViolationError(
message=f"AzureException - {original_exception.message}",
llm_provider="azure",
model=model,
response=original_exception.response,
)
elif "invalid_request_error" in error_str:
exception_mapping_worked = True
raise BadRequestError(
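Both the OpenAI and Azure branches apply the same substring test before falling through to the generic BadRequestError case. A simplified, standalone sketch of that routing rule (not the actual exception_type() implementation):

```python
# Simplified sketch of the routing rule added in the hunks above.
def classify(error_str: str) -> str:
    if "invalid_request_error" in error_str and "content_policy_violation" in error_str:
        return "ContentPolicyViolationError"
    if "invalid_request_error" in error_str:
        return "BadRequestError"
    return "unmapped"

sample = (
    "Error code: 400 - {'error': {'code': 'content_policy_violation', "
    "'type': 'invalid_request_error'}}"
)
assert classify(sample) == "ContentPolicyViolationError"
```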