diff --git a/tests/local_testing/test_amazing_vertex_completion.py b/tests/local_testing/test_amazing_vertex_completion.py index 25993d6d5b..7134c3536e 100644 --- a/tests/local_testing/test_amazing_vertex_completion.py +++ b/tests/local_testing/test_amazing_vertex_completion.py @@ -3327,3 +3327,96 @@ def test_signed_s3_url_with_format(): json_str = json.dumps(mock_client.call_args.kwargs["json"]) assert "image/jpeg" in json_str assert "image/png" not in json_str + + +def test_gemini_fine_tuned_model_request_consistency(): + """ + Assert the same transformation is applied to Fine tuned gemini 2.0 flash and gemini 2.0 flash + + - Request 1: Fine tuned: vertex_ai/ft-uuid with base_model: vertex_ai/gemini-2.0-flash-001 + - Request 2: vertex_ai/gemini-2.0-flash-001 + """ + litellm.set_verbose = True + load_vertex_ai_credentials() + from litellm.llms.custom_httpx.http_handler import HTTPHandler + from unittest.mock import patch, MagicMock + + # Set up the messages + messages = [ + { + "role": "system", + "content": "Your name is Litellm Bot, you are a helpful assistant", + }, + { + "role": "user", + "content": "Hello, what is your name and can you tell me the weather?", + }, + ] + + # Define tools + tools = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + } + }, + "required": ["location"], + }, + }, + } + ] + + client = HTTPHandler(concurrent_limit=1) + + # First request + with patch.object(client, "post", new=MagicMock()) as mock_post_1: + try: + response_1 = completion( + model="vertex_ai/ft-uuid", + base_model="vertex_ai/gemini-2.0-flash-001", + messages=messages, + tools=tools, + tool_choice="auto", + client=client, + ) + + except Exception as e: + print(e) + + # Store the request body from the first call + first_request_body = mock_post_1.call_args.kwargs["json"] + print("first_request_body", first_request_body) + + # Second request + with patch.object(client, "post", new=MagicMock()) as mock_post_2: + try: + response_2 = completion( + model="vertex_ai/gemini-2.0-flash-001", + messages=messages, + tools=tools, + tool_choice="auto", + client=client, + ) + except Exception as e: + print(e) + + # Store the request body from the second call + second_request_body = mock_post_2.call_args.kwargs["json"] + print("second_request_body", second_request_body) + + # Get the diff between the two request bodies + # Convert dictionaries to formatted JSON strings + import json + + first_json = json.dumps(first_request_body, indent=2).splitlines() + second_json = json.dumps(second_request_body, indent=2).splitlines() + # Assert there is no difference between the request bodies + assert first_json == second_json, "Request bodies should be identical"