fix(vertex_httpx.py): Return empty model response for content filter violations

Krrish Dholakia 2024-06-24 19:22:20 -07:00
parent 1ff0129a94
commit 8e6e5a6d37
2 changed files with 33 additions and 14 deletions
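
In short: when Gemini flags a prompt (completion_response["promptFeedback"]["blockReason"] is set), the Vertex handler now returns an empty model response whose single choice carries finish_reason == "content_filter", instead of raising litellm.ContentPolicyViolationError as the old test expected. The ## GET MODEL ## assignment also moves above that check, so the early-return path still reports which model was called. A sketch of the resulting flow follows the first file's hunks below.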

vertex_httpx.py

@@ -562,6 +562,9 @@ class VertexLLM(BaseLLM):
                 status_code=422,
             )
 
+        ## GET MODEL ##
+        model_response.model = model
+
         ## CHECK IF RESPONSE FLAGGED
         if "promptFeedback" in completion_response:
             if "blockReason" in completion_response["promptFeedback"]:
@@ -646,9 +649,6 @@ class VertexLLM(BaseLLM):
 
         model_response.choices = []  # type: ignore
 
-        ## GET MODEL ##
-        model_response.model = model
-
         try:
             ## GET TEXT ##
             chat_completion_message = {"role": "assistant"}

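Taken together, the two hunks above relocate the model assignment ahead of the flagged-response check and add an early return for blocked prompts. A minimal sketch of the resulting flow, assuming litellm's Choices and Message response types; the body of the early return is illustrative rather than the verbatim commit:

## GET MODEL ##
model_response.model = model

## CHECK IF RESPONSE FLAGGED
if "promptFeedback" in completion_response:
    if "blockReason" in completion_response["promptFeedback"]:
        # The prompt was rejected outright, so no candidates exist.
        # Return an empty assistant message marked with
        # finish_reason="content_filter" instead of raising
        # ContentPolicyViolationError (illustrative body).
        choice = litellm.utils.Choices(
            finish_reason="content_filter",
            index=0,
            message=litellm.utils.Message(content=None, role="assistant"),
        )
        model_response.choices = [choice]  # type: ignore
        return model_response

The second changed file extends the content-policy test to cover both the prompt-blocked and response-blocked paths: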

@@ -696,6 +696,18 @@ async def test_gemini_pro_function_calling_httpx(provider, sync_mode):
         pytest.fail("An unexpected exception occurred - {}".format(str(e)))
 
 
+def vertex_httpx_mock_reject_prompt_post(*args, **kwargs):
+    mock_response = MagicMock()
+    mock_response.status_code = 200
+    mock_response.headers = {"Content-Type": "application/json"}
+    mock_response.json.return_value = {
+        "promptFeedback": {"blockReason": "OTHER"},
+        "usageMetadata": {"promptTokenCount": 6285, "totalTokenCount": 6285},
+    }
+
+    return mock_response
+
+
 # @pytest.mark.skip(reason="exhausted vertex quota. need to refactor to mock the call")
 def vertex_httpx_mock_post(url, data=None, json=None, headers=None):
     mock_response = MagicMock()
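
The new mock simulates Gemini rejecting the prompt itself: the mocked generateContent body carries a top-level promptFeedback.blockReason plus usage metadata and no candidates at all, which is exactly the shape the new early-return path consumes. The pre-existing vertex_httpx_mock_post stays in place as the side effect for the "response" case, where the filter acts on the generated candidates instead.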
@@ -817,8 +829,11 @@ def vertex_httpx_mock_post(url, data=None, json=None, headers=None):
 
 @pytest.mark.parametrize("provider", ["vertex_ai_beta"])  # "vertex_ai",
+@pytest.mark.parametrize("content_filter_type", ["prompt", "response"])  # "vertex_ai",
 @pytest.mark.asyncio
-async def test_gemini_pro_json_schema_httpx_content_policy_error(provider):
+async def test_gemini_pro_json_schema_httpx_content_policy_error(
+    provider, content_filter_type
+):
     load_vertex_ai_credentials()
     litellm.set_verbose = True
     messages = [
@@ -839,16 +854,20 @@ Using this JSON schema:
 
     client = HTTPHandler()
 
-    with patch.object(client, "post", side_effect=vertex_httpx_mock_post) as mock_call:
-        try:
-            response = completion(
-                model="vertex_ai_beta/gemini-1.5-flash",
-                messages=messages,
-                response_format={"type": "json_object"},
-                client=client,
-            )
-        except litellm.ContentPolicyViolationError as e:
-            pass
+    if content_filter_type == "prompt":
+        _side_effect = vertex_httpx_mock_reject_prompt_post
+    else:
+        _side_effect = vertex_httpx_mock_post
+
+    with patch.object(client, "post", side_effect=_side_effect) as mock_call:
+        response = completion(
+            model="vertex_ai_beta/gemini-1.5-flash",
+            messages=messages,
+            response_format={"type": "json_object"},
+            client=client,
+        )
+
+        assert response.choices[0].finish_reason == "content_filter"
 
         mock_call.assert_called_once()
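
With the fix in place, callers can branch on the finish reason rather than wrapping every Vertex call in a try/except. A hedged usage sketch (the model name comes from the test above; credentials and the prompt are whatever the application supplies):

import litellm

response = litellm.completion(
    model="vertex_ai_beta/gemini-1.5-flash",
    messages=[{"role": "user", "content": "..."}],
)
if response.choices[0].finish_reason == "content_filter":
    # Google's safety filter blocked the prompt (or the generated text);
    # there is no completion content to read.
    print("request was content-filtered")
else:
    print(response.choices[0].message.content)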