mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 11:43:54 +00:00
test_gemini_fine_tuned_model_request_consistency
This commit is contained in:
parent
13391f9d7f
commit
826deea6fb
3 changed files with 29 additions and 2 deletions
|
@ -75,8 +75,8 @@ def _get_vertex_url(
|
|||
) -> Tuple[str, str]:
|
||||
url: Optional[str] = None
|
||||
endpoint: Optional[str] = None
|
||||
if litellm.VertexGeminiConfig._is_model_gemini_gemini_spec_model(model):
|
||||
model = litellm.VertexGeminiConfig._get_model_name_from_gemini_spec_model(model)
|
||||
|
||||
model = litellm.VertexGeminiConfig.get_model_for_vertex_ai_url(model=model)
|
||||
if mode == "chat":
|
||||
### SET RUNTIME ENDPOINT ###
|
||||
endpoint = "generateContent"
|
||||
|
|
|
@ -419,6 +419,25 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
|
|||
"europe-west9",
|
||||
]
|
||||
|
||||
@staticmethod
def get_model_for_vertex_ai_url(model: str) -> str:
    """
    Resolve the model identifier to place in the Vertex AI request URL.

    Two shapes of input are handled:
    1. Gemini-spec model names (e.g. a fine-tuned model passed by the user as
       `model="vertex_ai/gemini/ft-uuid"`): the name is delegated to
       `_get_model_name_from_gemini_spec_model`, which is expected to yield
       `ft-uuid`.
    2. Any other model name (e.g. `gemini-2.0-flash-001`): returned unchanged.

    Args:
        model (str): The model name as received by the caller building the URL.

    Returns:
        str: The model name to use in the request to Vertex AI.
    """
    # Guard clause: anything that is not a gemini-spec model passes through untouched.
    if not VertexGeminiConfig._is_model_gemini_gemini_spec_model(model):
        return model

    return VertexGeminiConfig._get_model_name_from_gemini_spec_model(model)
|
||||
|
||||
@staticmethod
|
||||
def _is_model_gemini_gemini_spec_model(model: Optional[str]) -> bool:
|
||||
"""
|
||||
|
|
|
@ -3395,6 +3395,14 @@ def test_gemini_fine_tuned_model_request_consistency():
|
|||
first_request_body = mock_post_1.call_args.kwargs["json"]
|
||||
print("first_request_body", first_request_body)
|
||||
|
||||
# Validate correct `model` is added to the request to Vertex AI
|
||||
print("final URL=", mock_post_1.call_args.kwargs["url"])
|
||||
# Validate the request url
|
||||
assert (
|
||||
"publishers/google/models/ft-uuid:generateContent"
|
||||
in mock_post_1.call_args.kwargs["url"]
|
||||
)
|
||||
|
||||
# Second request
|
||||
with patch.object(client, "post", new=MagicMock()) as mock_post_2:
|
||||
try:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue