Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 18:54:30 +00:00)
_is_model_gemini_gemini_spec_model

Commit bbe69a47a9 (parent fb31006cd8)
4 changed files with 19 additions and 7 deletions
@@ -3,6 +3,7 @@ from typing import Dict, List, Literal, Optional, Tuple, Union
 
 import httpx
 
+import litellm
 from litellm import supports_response_schema, supports_system_messages, verbose_logger
 from litellm.llms.base_llm.chat.transformation import BaseLLMException
 from litellm.types.llms.vertex_ai import PartType
@@ -28,6 +29,10 @@ def get_supports_system_message(
         supports_system_message = supports_system_messages(
             model=model, custom_llm_provider=_custom_llm_provider
         )
+
+        # Vertex Models called in the `/gemini` request/response format also support system messages
+        if litellm.VertexGeminiConfig._is_model_gemini_gemini_spec_model(model):
+            supports_system_message = True
     except Exception as e:
         verbose_logger.warning(
             "Unable to identify if system message supported. Defaulting to 'False'. Received error message - {}\nAdd it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json".format(
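The hunk above is the consumer of the new predicate: the capability lookup can miss fine-tuned Vertex models (their IDs are not in litellm's model map), so any model addressed in the `/gemini` request/response format is force-flagged as supporting system messages. A minimal standalone sketch of that flow; the lookup table below is an illustrative stand-in, not litellm's real model map:

from typing import Optional

def _is_model_gemini_gemini_spec_model(model: Optional[str]) -> bool:
    # Local mirror of the predicate added by this commit.
    return model is not None and "gemini/" in model

# Illustrative stand-in for litellm's model-capability lookup.
_SUPPORTS_SYSTEM_MESSAGES = {"gemini-2.0-flash-001": True}

def get_supports_system_message(model: str) -> bool:
    try:
        # Unknown (e.g. fine-tuned) model IDs miss the lookup -> False.
        supports = _SUPPORTS_SYSTEM_MESSAGES.get(model, False)
        # Models called in the `/gemini` spec also support system messages.
        if _is_model_gemini_gemini_spec_model(model):
            supports = True
    except Exception:
        # Mirrors the diff's fallback: default to False on lookup errors.
        supports = False
    return supports

assert get_supports_system_message("gemini-2.0-flash-001") is True
assert get_supports_system_message("vertex_ai/gemini/ft-uuid") is True
assert get_supports_system_message("ft-uuid") is False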
@@ -207,7 +207,7 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
             "extra_headers",
             "seed",
             "logprobs",
-            "top_logprobs" # Added this to list of supported openAI params
+            "top_logprobs", # Added this to list of supported openAI params
         ]
 
     def map_tool_choice_values(
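A note on the one-character fix above: the missing comma sat immediately before the closing bracket, so the old list still parsed as intended; the trailing comma mainly protects against Python's implicit string-literal concatenation if another item is appended after it later. A quick illustration of that failure mode (hypothetical list, not litellm's):

# "top_logprobs" lacks a trailing comma, so it silently fuses with the
# next literal via implicit string concatenation.
params = [
    "logprobs",
    "top_logprobs"
    "seed",
]
assert params == ["logprobs", "top_logprobsseed"]  # 2 items, not 3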
@@ -419,6 +419,17 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
             "europe-west9",
         ]
 
+    @staticmethod
+    def _is_model_gemini_gemini_spec_model(model: Optional[str]) -> bool:
+        """
+        Returns true if user is trying to call custom model in `/gemini` request/response format
+        """
+        if model is None:
+            return False
+        if "gemini/" in model:
+            return True
+        return False
+
     def get_flagged_finish_reasons(self) -> Dict[str, str]:
         """
         Return Dictionary of finish reasons which indicate response was flagged
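The check is a plain substring match, so it fires for any model string containing a `gemini/` segment, regardless of provider prefix. Illustrative calls, using the same access path the diff itself uses (the ft-uuid model IDs are placeholders):

import litellm

cfg = litellm.VertexGeminiConfig
# True: `/gemini`-spec routes
assert cfg._is_model_gemini_gemini_spec_model("vertex_ai/gemini/ft-uuid") is True
assert cfg._is_model_gemini_gemini_spec_model("gemini/ft-uuid") is True
# False: plain model IDs and missing values
assert cfg._is_model_gemini_gemini_spec_model("gemini-2.0-flash-001") is False
assert cfg._is_model_gemini_gemini_spec_model(None) is False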
@@ -1097,10 +1097,7 @@ def completion( # type: ignore # noqa: PLR0915
         logit_bias=logit_bias,
         user=user,
         # params to identify the model
-        model=LitellmCoreRequestUtils.select_model_for_request_transformation(
-            model=model,
-            base_model=base_model,
-        ),
+        model=model,
         custom_llm_provider=custom_llm_provider,
         response_format=response_format,
         seed=seed,
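Net effect in `completion()`: the removed helper chose which name the request transformation saw, based on a separate `base_model` parameter; after this commit the `/gemini` spec travels inside the single model string, where the new predicate can see it. A hedged sketch of the removed helper's plausible behavior (only its name and parameters appear in the diff; the body is an assumption):

from typing import Optional

def select_model_for_request_transformation(
    model: str, base_model: Optional[str]
) -> str:
    # Assumed behavior: prefer the explicit base_model when given,
    # otherwise fall back to the requested model.
    return base_model if base_model else model

# New convention: no base_model needed; the marker lives in the model string.
# model="vertex_ai/gemini/ft-uuid" replaces model="vertex_ai/ft-uuid" plus
# base_model="vertex_ai/gemini-2.0-flash-001" (see the test hunk below).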
@@ -3380,8 +3380,7 @@ def test_gemini_fine_tuned_model_request_consistency():
     with patch.object(client, "post", new=MagicMock()) as mock_post_1:
         try:
             response_1 = completion(
-                model="vertex_ai/ft-uuid",
-                base_model="vertex_ai/gemini-2.0-flash-001",
+                model="vertex_ai/gemini/ft-uuid",
                 messages=messages,
                 tools=tools,
                 tool_choice="auto",
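For callers, the new addressing looks like this (a minimal sketch; `ft-uuid` is a placeholder for a real Vertex fine-tuned model ID, and credential/project setup is omitted):

import litellm

# Fine-tuned Vertex model invoked in the `/gemini` request/response format.
response = litellm.completion(
    model="vertex_ai/gemini/ft-uuid",
    messages=[{"role": "user", "content": "Hello from a fine-tuned Gemini."}],
)
print(response.choices[0].message.content)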