test_gemini_fine_tuned_model_request_consistency

This commit is contained in:
Ishaan Jaff 2025-03-26 14:18:11 -07:00
parent 13391f9d7f
commit 826deea6fb
3 changed files with 29 additions and 2 deletions

View file

@ -75,8 +75,8 @@ def _get_vertex_url(
) -> Tuple[str, str]: ) -> Tuple[str, str]:
url: Optional[str] = None url: Optional[str] = None
endpoint: Optional[str] = None endpoint: Optional[str] = None
if litellm.VertexGeminiConfig._is_model_gemini_gemini_spec_model(model):
model = litellm.VertexGeminiConfig._get_model_name_from_gemini_spec_model(model) model = litellm.VertexGeminiConfig.get_model_for_vertex_ai_url(model=model)
if mode == "chat": if mode == "chat":
### SET RUNTIME ENDPOINT ### ### SET RUNTIME ENDPOINT ###
endpoint = "generateContent" endpoint = "generateContent"

View file

@ -419,6 +419,25 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
"europe-west9", "europe-west9",
] ]
@staticmethod
def get_model_for_vertex_ai_url(model: str) -> str:
"""
Returns the model name to use in the request to Vertex AI
Handles 2 cases:
1. User passed `model="vertex_ai/gemini/ft-uuid"`, we need to return `ft-uuid` for the request to Vertex AI
2. User passed `model="vertex_ai/gemini-2.0-flash-001"`, we need to return `gemini-2.0-flash-001` for the request to Vertex AI
Args:
model (str): The model name to use in the request to Vertex AI
Returns:
str: The model name to use in the request to Vertex AI
"""
if VertexGeminiConfig._is_model_gemini_gemini_spec_model(model):
return VertexGeminiConfig._get_model_name_from_gemini_spec_model(model)
return model
@staticmethod @staticmethod
def _is_model_gemini_gemini_spec_model(model: Optional[str]) -> bool: def _is_model_gemini_gemini_spec_model(model: Optional[str]) -> bool:
""" """

View file

@ -3395,6 +3395,14 @@ def test_gemini_fine_tuned_model_request_consistency():
first_request_body = mock_post_1.call_args.kwargs["json"] first_request_body = mock_post_1.call_args.kwargs["json"]
print("first_request_body", first_request_body) print("first_request_body", first_request_body)
# Validate correct `model` is added to the request to Vertex AI
print("final URL=", mock_post_1.call_args.kwargs["url"])
# Validate the request url
assert (
"publishers/google/models/ft-uuid:generateContent"
in mock_post_1.call_args.kwargs["url"]
)
# Second request # Second request
with patch.object(client, "post", new=MagicMock()) as mock_post_2: with patch.object(client, "post", new=MagicMock()) as mock_post_2:
try: try: