diff --git a/litellm/llms/vertex_httpx.py b/litellm/llms/vertex_httpx.py
index 028c3f7217..856b05f61c 100644
--- a/litellm/llms/vertex_httpx.py
+++ b/litellm/llms/vertex_httpx.py
@@ -562,6 +562,9 @@ class VertexLLM(BaseLLM):
                 status_code=422,
             )
 
+        ## GET MODEL ##
+        model_response.model = model
+
         ## CHECK IF RESPONSE FLAGGED
         if "promptFeedback" in completion_response:
             if "blockReason" in completion_response["promptFeedback"]:
@@ -646,9 +649,6 @@ class VertexLLM(BaseLLM):
 
         model_response.choices = []  # type: ignore
 
-        ## GET MODEL ##
-        model_response.model = model
-
         try:
             ## GET TEXT ##
             chat_completion_message = {"role": "assistant"}
diff --git a/litellm/tests/test_amazing_vertex_completion.py b/litellm/tests/test_amazing_vertex_completion.py
index fb28912493..c9e5501a8c 100644
--- a/litellm/tests/test_amazing_vertex_completion.py
+++ b/litellm/tests/test_amazing_vertex_completion.py
@@ -696,6 +696,18 @@ async def test_gemini_pro_function_calling_httpx(provider, sync_mode):
         pytest.fail("An unexpected exception occurred - {}".format(str(e)))
 
 
+def vertex_httpx_mock_reject_prompt_post(*args, **kwargs):
+    mock_response = MagicMock()
+    mock_response.status_code = 200
+    mock_response.headers = {"Content-Type": "application/json"}
+    mock_response.json.return_value = {
+        "promptFeedback": {"blockReason": "OTHER"},
+        "usageMetadata": {"promptTokenCount": 6285, "totalTokenCount": 6285},
+    }
+
+    return mock_response
+
+
 # @pytest.mark.skip(reason="exhausted vertex quota. need to refactor to mock the call")
 def vertex_httpx_mock_post(url, data=None, json=None, headers=None):
     mock_response = MagicMock()
@@ -817,8 +829,11 @@ def vertex_httpx_mock_post(url, data=None, json=None, headers=None):
 
 
 @pytest.mark.parametrize("provider", ["vertex_ai_beta"])  # "vertex_ai",
+@pytest.mark.parametrize("content_filter_type", ["prompt", "response"])  # "vertex_ai",
 @pytest.mark.asyncio
-async def test_gemini_pro_json_schema_httpx_content_policy_error(provider):
+async def test_gemini_pro_json_schema_httpx_content_policy_error(
+    provider, content_filter_type
+):
     load_vertex_ai_credentials()
     litellm.set_verbose = True
     messages = [
@@ -839,16 +854,20 @@ Using this JSON schema:
 
     client = HTTPHandler()
 
-    with patch.object(client, "post", side_effect=vertex_httpx_mock_post) as mock_call:
-        try:
-            response = completion(
-                model="vertex_ai_beta/gemini-1.5-flash",
-                messages=messages,
-                response_format={"type": "json_object"},
-                client=client,
-            )
-        except litellm.ContentPolicyViolationError as e:
-            pass
+    if content_filter_type == "prompt":
+        _side_effect = vertex_httpx_mock_reject_prompt_post
+    else:
+        _side_effect = vertex_httpx_mock_post
+
+    with patch.object(client, "post", side_effect=_side_effect) as mock_call:
+        response = completion(
+            model="vertex_ai_beta/gemini-1.5-flash",
+            messages=messages,
+            response_format={"type": "json_object"},
+            client=client,
+        )
+
+        assert response.choices[0].finish_reason == "content_filter"
 
     mock_call.assert_called_once()
 