Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 19:24:27 +00:00)
fix(vertex_httpx.py): Return empty model response for content filter violations
This commit is contained in:
parent 1ff0129a94 · commit 8e6e5a6d37

2 changed files with 33 additions and 14 deletions
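At a glance: when Gemini flags a request (promptFeedback.blockReason is set on the API response), the Vertex AI handler now returns a normal completion whose first choice carries finish_reason "content_filter", rather than the litellm.ContentPolicyViolationError the old test expected, and the model name is populated before the flag check so the returned object still reports which model was called. A caller-side sketch of the new behavior (not code from this commit; the model and message are placeholders, and the check mirrors the assertion added in the test below):

# Sketch only: how a caller can detect a blocked prompt after this change.
import litellm

response = litellm.completion(
    model="vertex_ai_beta/gemini-1.5-flash",
    messages=[{"role": "user", "content": "..."}],
)
if response.choices and response.choices[0].finish_reason == "content_filter":
    # The prompt (or generated response) was blocked by Gemini's safety
    # filters; there is no usable completion text to read.
    ...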
vertex_httpx.py

@@ -562,6 +562,9 @@ class VertexLLM(BaseLLM):
                 status_code=422,
             )
 
+        ## GET MODEL ##
+        model_response.model = model
+
         ## CHECK IF RESPONSE FLAGGED
         if "promptFeedback" in completion_response:
             if "blockReason" in completion_response["promptFeedback"]:
@@ -646,9 +649,6 @@ class VertexLLM(BaseLLM):
 
         model_response.choices = []  # type: ignore
 
-        ## GET MODEL ##
-        model_response.model = model
-
         try:
             ## GET TEXT ##
             chat_completion_message = {"role": "assistant"}
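For context, the test below mocks the shape Gemini's generateContent endpoint returns when it blocks the prompt itself: a promptFeedback.blockReason plus usage metadata, and no candidates. A small sketch of that payload and the guard shown in the first hunk (the branch body is outside the hunks above, so the final line is an assumption based on the test's assertion):

# Illustrative payload, copied from the mock added in the test file below.
blocked_prompt_response = {
    "promptFeedback": {"blockReason": "OTHER"},
    "usageMetadata": {"promptTokenCount": 6285, "totalTokenCount": 6285},
}

# Guard from the first hunk above; what happens inside it is not shown here.
if "promptFeedback" in blocked_prompt_response:
    if "blockReason" in blocked_prompt_response["promptFeedback"]:
        # Assumption: the handler maps this case to an OpenAI-style response
        # whose first choice has finish_reason == "content_filter".
        finish_reason = "content_filter"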
Gemini / Vertex AI tests

@@ -696,6 +696,18 @@ async def test_gemini_pro_function_calling_httpx(provider, sync_mode):
         pytest.fail("An unexpected exception occurred - {}".format(str(e)))
 
 
+def vertex_httpx_mock_reject_prompt_post(*args, **kwargs):
+    mock_response = MagicMock()
+    mock_response.status_code = 200
+    mock_response.headers = {"Content-Type": "application/json"}
+    mock_response.json.return_value = {
+        "promptFeedback": {"blockReason": "OTHER"},
+        "usageMetadata": {"promptTokenCount": 6285, "totalTokenCount": 6285},
+    }
+
+    return mock_response
+
+
 # @pytest.mark.skip(reason="exhausted vertex quota. need to refactor to mock the call")
 def vertex_httpx_mock_post(url, data=None, json=None, headers=None):
     mock_response = MagicMock()
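The new helper works because unittest.mock invokes a callable side_effect and returns its result, so patching client.post with it swaps every outgoing request for the canned blocked-prompt payload. A self-contained sketch of that mechanism (FakeClient and fake_post are hypothetical names, not from the repository):

from unittest.mock import MagicMock, patch


class FakeClient:
    def post(self, url, data=None, json=None, headers=None):
        raise RuntimeError("would hit the network")


def fake_post(*args, **kwargs):
    # Mirrors vertex_httpx_mock_reject_prompt_post: build a canned response.
    resp = MagicMock()
    resp.status_code = 200
    resp.json.return_value = {"promptFeedback": {"blockReason": "OTHER"}}
    return resp


client = FakeClient()
with patch.object(client, "post", side_effect=fake_post) as mock_call:
    resp = client.post("https://example.invalid")  # returns fake_post's value
    assert resp.json()["promptFeedback"]["blockReason"] == "OTHER"

mock_call.assert_called_once()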
@@ -817,8 +829,11 @@ def vertex_httpx_mock_post(url, data=None, json=None, headers=None):
 
 
 @pytest.mark.parametrize("provider", ["vertex_ai_beta"])  # "vertex_ai",
+@pytest.mark.parametrize("content_filter_type", ["prompt", "response"])  # "vertex_ai",
 @pytest.mark.asyncio
-async def test_gemini_pro_json_schema_httpx_content_policy_error(provider):
+async def test_gemini_pro_json_schema_httpx_content_policy_error(
+    provider, content_filter_type
+):
     load_vertex_ai_credentials()
     litellm.set_verbose = True
     messages = [
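Stacking a second parametrize decorator makes pytest run the cross product of the two parameter lists, so the one test body now covers both the prompt-blocked and response-blocked mocks. A hypothetical minimal example of that mechanic (the test name and asserts are illustrative, not from the diff):

import pytest


@pytest.mark.parametrize("provider", ["vertex_ai_beta"])
@pytest.mark.parametrize("content_filter_type", ["prompt", "response"])
def test_parametrize_cross_product(provider, content_filter_type):
    # Runs once per (provider, content_filter_type) pair: 1 x 2 = 2 cases.
    assert provider == "vertex_ai_beta"
    assert content_filter_type in ("prompt", "response")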
@@ -839,16 +854,20 @@ Using this JSON schema:
 
 
     client = HTTPHandler()
 
-    with patch.object(client, "post", side_effect=vertex_httpx_mock_post) as mock_call:
-        try:
-            response = completion(
-                model="vertex_ai_beta/gemini-1.5-flash",
-                messages=messages,
-                response_format={"type": "json_object"},
-                client=client,
-            )
-        except litellm.ContentPolicyViolationError as e:
-            pass
+    if content_filter_type == "prompt":
+        _side_effect = vertex_httpx_mock_reject_prompt_post
+    else:
+        _side_effect = vertex_httpx_mock_post
+
+    with patch.object(client, "post", side_effect=_side_effect) as mock_call:
+        response = completion(
+            model="vertex_ai_beta/gemini-1.5-flash",
+            messages=messages,
+            response_format={"type": "json_object"},
+            client=client,
+        )
+
+        assert response.choices[0].finish_reason == "content_filter"
 
     mock_call.assert_called_once()