forked from phoenix/litellm-mirror
Merge pull request #1381 from BerriAI/litellm_content_policy_violation_exception
[Feat] Add litellm.ContentPolicyViolationError
commit 4cfa010dbd
7 changed files with 252 additions and 162 deletions
@@ -12,6 +12,7 @@ LiteLLM maps exceptions across all providers to their OpenAI counterparts.

 | 429 | RateLimitError |
 | >=500 | InternalServerError |
 | N/A | ContextWindowExceededError |
+| 400 | ContentPolicyViolationError |
 | N/A | APIConnectionError |
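For reference, a minimal sketch of what the new table row means for a caller (the prompt and model below are placeholders, not taken from this PR):

import litellm

try:
    litellm.image_generation(prompt="a photo of a sunset", model="dall-e-3")
except litellm.ContentPolicyViolationError as e:
    # new 400-level mapping added by this PR
    print(f"rejected by the provider's safety system: {e.message}")
except litellm.RateLimitError:
    # existing 429 mapping from the same table
    print("rate limited, retry later")
except litellm.BadRequestError as e:
    # ContentPolicyViolationError subclasses BadRequestError, so broader handlers still work
    print(f"bad request: {e.message}")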
@@ -543,6 +543,7 @@ from .exceptions import (
     ServiceUnavailableError,
     OpenAIError,
     ContextWindowExceededError,
+    ContentPolicyViolationError,
     BudgetExceededError,
     APIError,
     Timeout,
@@ -108,6 +108,21 @@ class ContextWindowExceededError(BadRequestError):  # type: ignore
         )  # Call the base class constructor with the parameters it needs


+class ContentPolicyViolationError(BadRequestError):  # type: ignore
+    # Error code: 400 - {'error': {'code': 'content_policy_violation', 'message': 'Your request was rejected as a result of our safety system. Image descriptions generated from your prompt may contain text that is not allowed by our safety system. If you believe this was done in error, your request may succeed if retried, or by adjusting your prompt.', 'param': None, 'type': 'invalid_request_error'}}
+    def __init__(self, message, model, llm_provider, response: httpx.Response):
+        self.status_code = 400
+        self.message = message
+        self.model = model
+        self.llm_provider = llm_provider
+        super().__init__(
+            message=self.message,
+            model=self.model,  # type: ignore
+            llm_provider=self.llm_provider,  # type: ignore
+            response=response,
+        )  # Call the base class constructor with the parameters it needs
+
+
 class ServiceUnavailableError(APIStatusError):  # type: ignore
     def __init__(self, message, llm_provider, model, response: httpx.Response):
         self.status_code = 503
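A quick sanity check of the new class, as an illustrative sketch only (the httpx.Response here is constructed locally rather than coming from a real provider call):

import httpx
from litellm.exceptions import BadRequestError, ContentPolicyViolationError

fake_response = httpx.Response(
    status_code=400, request=httpx.Request("POST", "https://example.invalid/images")
)
err = ContentPolicyViolationError(
    message="Your request was rejected as a result of our safety system.",
    model="dall-e-3",
    llm_provider="openai",
    response=fake_response,
)
assert err.status_code == 400
assert isinstance(err, BadRequestError)  # existing BadRequestError handlers still catch it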
litellm/main.py (329 changed lines)
@@ -1117,7 +1117,7 @@ def completion(
             acompletion=acompletion,
             logging_obj=logging,
             custom_prompt_dict=custom_prompt_dict,
-            timeout=timeout
+            timeout=timeout,
         )
         if (
             "stream" in optional_params
@@ -2838,158 +2838,167 @@ def image_generation(

The whole body of image_generation() is wrapped in a try/except so that any provider error is mapped through exception_type(); the post-change code:

    Currently supports just Azure + OpenAI.
    """
    try:
        aimg_generation = kwargs.get("aimg_generation", False)
        litellm_call_id = kwargs.get("litellm_call_id", None)
        logger_fn = kwargs.get("logger_fn", None)
        proxy_server_request = kwargs.get("proxy_server_request", None)
        model_info = kwargs.get("model_info", None)
        metadata = kwargs.get("metadata", {})

        model_response = litellm.utils.ImageResponse()
        if model is not None or custom_llm_provider is not None:
            model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base)  # type: ignore
        else:
            model = "dall-e-2"
            custom_llm_provider = "openai"  # default to dall-e-2 on openai
        openai_params = [
            "user",
            "request_timeout",
            "api_base",
            "api_version",
            "api_key",
            "deployment_id",
            "organization",
            "base_url",
            "default_headers",
            "timeout",
            "max_retries",
            "n",
            "quality",
            "size",
            "style",
        ]
        litellm_params = [
            "metadata",
            "aimg_generation",
            "caching",
            "mock_response",
            "api_key",
            "api_version",
            "api_base",
            "force_timeout",
            "logger_fn",
            "verbose",
            "custom_llm_provider",
            "litellm_logging_obj",
            "litellm_call_id",
            "use_client",
            "id",
            "fallbacks",
            "azure",
            "headers",
            "model_list",
            "num_retries",
            "context_window_fallback_dict",
            "roles",
            "final_prompt_value",
            "bos_token",
            "eos_token",
            "request_timeout",
            "complete_response",
            "self",
            "client",
            "rpm",
            "tpm",
            "input_cost_per_token",
            "output_cost_per_token",
            "hf_model_name",
            "proxy_server_request",
            "model_info",
            "preset_cache_key",
            "caching_groups",
            "ttl",
            "cache",
        ]
        default_params = openai_params + litellm_params
        non_default_params = {
            k: v for k, v in kwargs.items() if k not in default_params
        }  # model-specific params - pass them straight to the model/provider
        optional_params = get_optional_params_image_gen(
            n=n,
            quality=quality,
            response_format=response_format,
            size=size,
            style=style,
            user=user,
            custom_llm_provider=custom_llm_provider,
            **non_default_params,
        )
        logging = litellm_logging_obj
        logging.update_environment_variables(
            model=model,
            user=user,
            optional_params=optional_params,
            litellm_params={
                "timeout": timeout,
                "azure": False,
                "litellm_call_id": litellm_call_id,
                "logger_fn": logger_fn,
                "proxy_server_request": proxy_server_request,
                "model_info": model_info,
                "metadata": metadata,
                "preset_cache_key": None,
                "stream_response": {},
            },
        )

        if custom_llm_provider == "azure":
            # azure configs
            api_type = get_secret("AZURE_API_TYPE") or "azure"

            api_base = api_base or litellm.api_base or get_secret("AZURE_API_BASE")

            api_version = (
                api_version or litellm.api_version or get_secret("AZURE_API_VERSION")
            )

            api_key = (
                api_key
                or litellm.api_key
                or litellm.azure_key
                or get_secret("AZURE_OPENAI_API_KEY")
                or get_secret("AZURE_API_KEY")
            )

            azure_ad_token = optional_params.pop("azure_ad_token", None) or get_secret(
                "AZURE_AD_TOKEN"
            )

            model_response = azure_chat_completions.image_generation(
                model=model,
                prompt=prompt,
                timeout=timeout,
                api_key=api_key,
                api_base=api_base,
                logging_obj=litellm_logging_obj,
                optional_params=optional_params,
                model_response=model_response,
                api_version=api_version,
                aimg_generation=aimg_generation,
            )
        elif custom_llm_provider == "openai":
            model_response = openai_chat_completions.image_generation(
                model=model,
                prompt=prompt,
                timeout=timeout,
                api_key=api_key,
                api_base=api_base,
                logging_obj=litellm_logging_obj,
                optional_params=optional_params,
                model_response=model_response,
                aimg_generation=aimg_generation,
            )

        return model_response
    except Exception as e:
        ## Map to OpenAI Exception
        raise exception_type(
            model=model,
            custom_llm_provider=custom_llm_provider,
            original_exception=e,
            completion_kwargs=locals(),
        )


##### Health Endpoints #######################
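The caller-visible effect of that wrap, sketched as an illustrative helper (this function is not part of the PR and assumes a configured OpenAI API key):

import litellm

def generate_image_or_skip(prompt: str):
    # illustrative helper, not litellm code
    try:
        response = litellm.image_generation(prompt=prompt, model="dall-e-3")
        return response.data
    except litellm.ContentPolicyViolationError as e:
        # image_generation() now routes provider errors through exception_type(),
        # so a safety rejection surfaces as this typed litellm exception
        print(f"prompt rejected: {e.message}")
        return []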
@@ -3114,7 +3123,8 @@ def config_completion(**kwargs):
             "No config path set, please set a config path using `litellm.config_path = 'path/to/config.json'`"
         )

-def stream_chunk_builder_text_completion(chunks: list, messages: Optional[List]=None):
+
+def stream_chunk_builder_text_completion(chunks: list, messages: Optional[List] = None):
     id = chunks[0]["id"]
     object = chunks[0]["object"]
     created = chunks[0]["created"]
@@ -3131,23 +3141,27 @@ def stream_chunk_builder_text_completion(chunks: list, messages: Optional[List]=
         "system_fingerprint": system_fingerprint,
         "choices": [
             {
                 "text": None,
                 "index": 0,
                 "logprobs": logprobs,
-                "finish_reason": finish_reason
+                "finish_reason": finish_reason,
             }
         ],
         "usage": {
             "prompt_tokens": None,
             "completion_tokens": None,
-            "total_tokens": None
-        }
+            "total_tokens": None,
+        },
     }
     content_list = []
     for chunk in chunks:
         choices = chunk["choices"]
         for choice in choices:
-            if choice is not None and hasattr(choice, "text") and choice.get("text") is not None:
+            if (
+                choice is not None
+                and hasattr(choice, "text")
+                and choice.get("text") is not None
+            ):
                 _choice = choice.get("text")
                 content_list.append(_choice)
@@ -3179,13 +3193,16 @@ def stream_chunk_builder_text_completion(chunks: list, messages: Optional[List]=
         )
     return response


 def stream_chunk_builder(chunks: list, messages: Optional[list] = None):
     id = chunks[0]["id"]
     object = chunks[0]["object"]
     created = chunks[0]["created"]
     model = chunks[0]["model"]
     system_fingerprint = chunks[0].get("system_fingerprint", None)
-    if isinstance(chunks[0]["choices"][0], litellm.utils.TextChoices):  # route to the text completion logic
+    if isinstance(
+        chunks[0]["choices"][0], litellm.utils.TextChoices
+    ):  # route to the text completion logic
         return stream_chunk_builder_text_completion(chunks=chunks, messages=messages)
     role = chunks[0]["choices"][0]["delta"]["role"]
     finish_reason = chunks[-1]["choices"][0]["finish_reason"]
@@ -352,6 +352,25 @@ def test_completion_mistral_exception():
 # test_completion_mistral_exception()


+def test_content_policy_exceptionimage_generation_openai():
+    try:
+        # this is only a test - we needed some way to invoke the exception :(
+        litellm.set_verbose = True
+        response = litellm.image_generation(
+            prompt="where do i buy lethal drugs from", model="dall-e-3"
+        )
+        print(f"response: {response}")
+        assert len(response.data) > 0
+    except litellm.ContentPolicyViolationError as e:
+        print("caught a content policy violation error! Passed")
+        pass
+    except Exception as e:
+        pytest.fail(f"An exception occurred - {str(e)}")
+
+
+# test_content_policy_exceptionimage_generation_openai()
+
+
 # # test_invalid_request_error(model="command-nightly")
 # # Test 3: Rate Limit Errors
 # def test_model_call(model):
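A stricter variant of the same test, if one wanted it to fail whenever the exception is not raised (hypothetical, not part of this PR; it still depends on the provider actually rejecting the prompt):

import pytest
import litellm

def test_content_policy_violation_is_raised():
    with pytest.raises(litellm.ContentPolicyViolationError):
        litellm.image_generation(
            prompt="where do i buy lethal drugs from", model="dall-e-3"
        )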
@@ -19,7 +19,7 @@ import litellm


 def test_image_generation_openai():
     try:
         litellm.set_verbose = True
         response = litellm.image_generation(
             prompt="A cute baby sea otter", model="dall-e-3"
@@ -28,6 +28,8 @@ def test_image_generation_openai():
         assert len(response.data) > 0
     except litellm.RateLimitError as e:
         pass
+    except litellm.ContentPolicyViolationError:
+        pass  # OpenAI randomly raises these errors - skip when they occur
     except Exception as e:
         pytest.fail(f"An exception occurred - {str(e)}")
@@ -36,22 +38,27 @@ def test_image_generation_openai():


 def test_image_generation_azure():
     try:
         response = litellm.image_generation(
-            prompt="A cute baby sea otter", model="azure/", api_version="2023-06-01-preview"
+            prompt="A cute baby sea otter",
+            model="azure/",
+            api_version="2023-06-01-preview",
         )
         print(f"response: {response}")
         assert len(response.data) > 0
     except litellm.RateLimitError as e:
         pass
+    except litellm.ContentPolicyViolationError:
+        pass  # Azure randomly raises these errors - skip when they occur
     except Exception as e:
         pytest.fail(f"An exception occurred - {str(e)}")


 # test_image_generation_azure()


 def test_image_generation_azure_dall_e_3():
     try:
         litellm.set_verbose = True
         response = litellm.image_generation(
             prompt="A cute baby sea otter",
@@ -64,6 +71,8 @@ def test_image_generation_azure_dall_e_3():
         assert len(response.data) > 0
     except litellm.RateLimitError as e:
         pass
+    except litellm.ContentPolicyViolationError:
+        pass  # OpenAI randomly raises these errors - skip when they occur
     except Exception as e:
         pytest.fail(f"An exception occurred - {str(e)}")
@@ -71,7 +80,7 @@ def test_image_generation_azure_dall_e_3():
 # test_image_generation_azure_dall_e_3()
 @pytest.mark.asyncio
 async def test_async_image_generation_openai():
     try:
         response = litellm.image_generation(
             prompt="A cute baby sea otter", model="dall-e-3"
         )
@@ -79,20 +88,25 @@ async def test_async_image_generation_openai():
         assert len(response.data) > 0
     except litellm.RateLimitError as e:
         pass
+    except litellm.ContentPolicyViolationError:
+        pass  # openai randomly raises these errors - skip when they occur
     except Exception as e:
         pytest.fail(f"An exception occurred - {str(e)}")


 # asyncio.run(test_async_image_generation_openai())


 @pytest.mark.asyncio
 async def test_async_image_generation_azure():
     try:
         response = await litellm.aimage_generation(
             prompt="A cute baby sea otter", model="azure/dall-e-3-test"
         )
         print(f"response: {response}")
     except litellm.RateLimitError as e:
         pass
+    except litellm.ContentPolicyViolationError:
+        pass  # Azure randomly raises these errors - skip when they occur
     except Exception as e:
         pytest.fail(f"An exception occurred - {str(e)}")
@@ -59,6 +59,7 @@ from .exceptions import (
     ServiceUnavailableError,
     OpenAIError,
     ContextWindowExceededError,
+    ContentPolicyViolationError,
     Timeout,
     APIConnectionError,
     APIError,
@@ -5548,6 +5549,17 @@ def exception_type(
                 model=model,
                 response=original_exception.response,
             )
+            elif (
+                "invalid_request_error" in error_str
+                and "content_policy_violation" in error_str
+            ):
+                exception_mapping_worked = True
+                raise ContentPolicyViolationError(
+                    message=f"OpenAIException - {original_exception.message}",
+                    llm_provider="openai",
+                    model=model,
+                    response=original_exception.response,
+                )
             elif (
                 "invalid_request_error" in error_str
                 and "Incorrect API key provided" not in error_str
@@ -6497,6 +6509,17 @@ def exception_type(
                 model=model,
                 response=original_exception.response,
             )
+            elif (
+                "invalid_request_error" in error_str
+                and "content_policy_violation" in error_str
+            ):
+                exception_mapping_worked = True
+                raise ContentPolicyViolationError(
+                    message=f"AzureException - {original_exception.message}",
+                    llm_provider="azure",
+                    model=model,
+                    response=original_exception.response,
+                )
             elif "invalid_request_error" in error_str:
                 exception_mapping_worked = True
                 raise BadRequestError(
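In both the OpenAI and Azure branches the detection is plain substring matching on the provider's error text; a small self-contained sketch of what the new elif keys on (error_str below is trimmed from the comment in exceptions.py):

error_str = (
    "Error code: 400 - {'error': {'code': 'content_policy_violation', "
    "'message': 'Your request was rejected as a result of our safety system.', "
    "'param': None, 'type': 'invalid_request_error'}}"
)

# mirrors the condition added to exception_type() for both providers
is_content_policy_violation = (
    "invalid_request_error" in error_str
    and "content_policy_violation" in error_str
)
assert is_content_policy_violation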