mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 03:34:10 +00:00
test_gemini_fine_tuned_model_request_consistency
This commit is contained in:
parent
13391f9d7f
commit
826deea6fb
3 changed files with 29 additions and 2 deletions
|
@ -75,8 +75,8 @@ def _get_vertex_url(
|
||||||
) -> Tuple[str, str]:
|
) -> Tuple[str, str]:
|
||||||
url: Optional[str] = None
|
url: Optional[str] = None
|
||||||
endpoint: Optional[str] = None
|
endpoint: Optional[str] = None
|
||||||
if litellm.VertexGeminiConfig._is_model_gemini_gemini_spec_model(model):
|
|
||||||
model = litellm.VertexGeminiConfig._get_model_name_from_gemini_spec_model(model)
|
model = litellm.VertexGeminiConfig.get_model_for_vertex_ai_url(model=model)
|
||||||
if mode == "chat":
|
if mode == "chat":
|
||||||
### SET RUNTIME ENDPOINT ###
|
### SET RUNTIME ENDPOINT ###
|
||||||
endpoint = "generateContent"
|
endpoint = "generateContent"
|
||||||
|
|
|
@ -419,6 +419,25 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
|
||||||
"europe-west9",
|
"europe-west9",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_model_for_vertex_ai_url(model: str) -> str:
|
||||||
|
"""
|
||||||
|
Returns the model name to use in the request to Vertex AI
|
||||||
|
|
||||||
|
Handles 2 cases:
|
||||||
|
1. User passed `model="vertex_ai/gemini/ft-uuid"`, we need to return `ft-uuid` for the request to Vertex AI
|
||||||
|
2. User passed `model="vertex_ai/gemini-2.0-flash-001"`, we need to return `gemini-2.0-flash-001` for the request to Vertex AI
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (str): The model name to use in the request to Vertex AI
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The model name to use in the request to Vertex AI
|
||||||
|
"""
|
||||||
|
if VertexGeminiConfig._is_model_gemini_gemini_spec_model(model):
|
||||||
|
return VertexGeminiConfig._get_model_name_from_gemini_spec_model(model)
|
||||||
|
return model
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _is_model_gemini_gemini_spec_model(model: Optional[str]) -> bool:
|
def _is_model_gemini_gemini_spec_model(model: Optional[str]) -> bool:
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -3395,6 +3395,14 @@ def test_gemini_fine_tuned_model_request_consistency():
|
||||||
first_request_body = mock_post_1.call_args.kwargs["json"]
|
first_request_body = mock_post_1.call_args.kwargs["json"]
|
||||||
print("first_request_body", first_request_body)
|
print("first_request_body", first_request_body)
|
||||||
|
|
||||||
|
# Validate correct `model` is added to the request to Vertex AI
|
||||||
|
print("final URL=", mock_post_1.call_args.kwargs["url"])
|
||||||
|
# Validate the request url
|
||||||
|
assert (
|
||||||
|
"publishers/google/models/ft-uuid:generateContent"
|
||||||
|
in mock_post_1.call_args.kwargs["url"]
|
||||||
|
)
|
||||||
|
|
||||||
# Second request
|
# Second request
|
||||||
with patch.object(client, "post", new=MagicMock()) as mock_post_2:
|
with patch.object(client, "post", new=MagicMock()) as mock_post_2:
|
||||||
try:
|
try:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue