Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-25 18:54:30 +00:00
fix(vertex_httpx.py): check if model supports system messages before sending separately
Parent: a80520004e · Commit: 3d9ef689e7
7 changed files with 190 additions and 73 deletions
litellm/utils.py

@@ -1823,6 +1823,32 @@ def supports_httpx_timeout(custom_llm_provider: str) -> bool:
     return False


+def supports_system_messages(model: str, custom_llm_provider: Optional[str]) -> bool:
+    """
+    Check if the given model supports system messages and return a boolean value.
+
+    Parameters:
+    model (str): The model name to be checked.
+
+    Returns:
+    bool: True if the model supports system messages, False otherwise.
+
+    Raises:
+    Exception: If the given model is not found in model_prices_and_context_window.json.
+    """
+    try:
+        model_info = litellm.get_model_info(
+            model=model, custom_llm_provider=custom_llm_provider
+        )
+        if model_info.get("supports_system_messages", False) is True:
+            return True
+        return False
+    except Exception:
+        raise Exception(
+            f"Model not in model_prices_and_context_window.json. You passed model={model}, custom_llm_provider={custom_llm_provider}."
+        )
+
+
 def supports_function_calling(model: str) -> bool:
     """
     Check if the given model supports function calling and return a boolean value.
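This is the helper the commit title refers to: vertex_httpx.py can consult it before deciding whether to send system messages as a separate field. A minimal usage sketch, assuming the helper is exported at the package level like the other supports_* utilities; merge_system_into_user is a hypothetical stand-in for the caller's fallback logic, not a litellm API:

    import litellm

    try:
        has_system_role = litellm.supports_system_messages(
            model="gemini-1.5-pro", custom_llm_provider="vertex_ai"
        )
    except Exception:
        has_system_role = False  # unmapped model: assume no separate system field

    if not has_system_role:
        messages = merge_system_into_user(messages)  # hypothetical fallback helper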
@@ -1838,7 +1864,7 @@ def supports_function_calling(model: str) -> bool:
     """
     if model in litellm.model_cost:
         model_info = litellm.model_cost[model]
-        if model_info.get("supports_function_calling", False):
+        if model_info.get("supports_function_calling", False) is True:
             return True
         return False
     else:
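The move from a truthy check to `is True` is deliberate: dict.get returns whatever value the JSON map holds, and a malformed entry such as a string would previously pass the truthy check. A small sketch of the difference, using an invented entry:

    # Invented, malformed map entry for illustration.
    model_info = {"supports_function_calling": "yes"}

    bool(model_info.get("supports_function_calling", False))    # True  (any truthy value passes)
    model_info.get("supports_function_calling", False) is True  # False (only a real bool passes)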
@@ -1862,7 +1888,7 @@ def supports_vision(model: str):
     """
     if model in litellm.model_cost:
         model_info = litellm.model_cost[model]
-        if model_info.get("supports_vision", False):
+        if model_info.get("supports_vision", False) is True:
             return True
         return False
     else:
@@ -1884,7 +1910,7 @@ def supports_parallel_function_calling(model: str):
     """
     if model in litellm.model_cost:
         model_info = litellm.model_cost[model]
-        if model_info.get("supports_parallel_function_calling", False):
+        if model_info.get("supports_parallel_function_calling", False) is True:
             return True
         return False
     else:
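All three capability helpers share the same shape after this tightening. A quick usage sketch; the model names are examples and assume matching entries in model_prices_and_context_window.json:

    import litellm

    litellm.supports_function_calling("gpt-3.5-turbo")
    litellm.supports_vision("gpt-4-vision-preview")
    litellm.supports_parallel_function_calling("gpt-4-turbo-preview")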
@@ -4319,14 +4345,17 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
         )
         if custom_llm_provider == "huggingface":
             max_tokens = _get_max_position_embeddings(model_name=model)
-            return {
-                "max_tokens": max_tokens,  # type: ignore
-                "input_cost_per_token": 0,
-                "output_cost_per_token": 0,
-                "litellm_provider": "huggingface",
-                "mode": "chat",
-                "supported_openai_params": supported_openai_params,
-            }
+            return ModelInfo(
+                max_tokens=max_tokens,  # type: ignore
+                max_input_tokens=None,
+                max_output_tokens=None,
+                input_cost_per_token=0,
+                output_cost_per_token=0,
+                litellm_provider="huggingface",
+                mode="chat",
+                supported_openai_params=supported_openai_params,
+                supports_system_messages=None,
+            )
         else:
             """
             Check if: (in order of specificity)
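With this hunk the huggingface branch returns the same typed structure as the mapped-model paths instead of a bare dict. The sketch below approximates ModelInfo's shape as inferred from the fields this diff populates; it is a reconstruction, not the exact definition in the codebase:

    from typing import List, Optional, TypedDict

    class ModelInfo(TypedDict):
        # Approximate shape, reconstructed from the fields used in this diff.
        max_tokens: Optional[int]
        max_input_tokens: Optional[int]
        max_output_tokens: Optional[int]
        input_cost_per_token: float
        output_cost_per_token: float
        litellm_provider: str
        mode: str
        supported_openai_params: Optional[List[str]]
        supports_system_messages: Optional[bool]

Because TypedDict instances are plain dicts at runtime, the model_info.get("supports_system_messages", False) call in the new helper above works unchanged.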
@@ -4361,6 +4390,21 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
                     pass
                 else:
                     raise Exception
-                return _model_info
+                return ModelInfo(
+                    max_tokens=_model_info.get("max_tokens", None),
+                    max_input_tokens=_model_info.get("max_input_tokens", None),
+                    max_output_tokens=_model_info.get("max_output_tokens", None),
+                    input_cost_per_token=_model_info.get("input_cost_per_token", 0),
+                    output_cost_per_token=_model_info.get("output_cost_per_token", 0),
+                    litellm_provider=_model_info.get(
+                        "litellm_provider", custom_llm_provider
+                    ),
+                    mode=_model_info.get("mode"),
+                    supported_openai_params=supported_openai_params,
+                    supports_system_messages=_model_info.get(
+                        "supports_system_messages", None
+                    ),
+                )
             elif split_model in litellm.model_cost:
                 _model_info = litellm.model_cost[split_model]
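The .get(...) defaults matter here: entries in the JSON map are sparse, and this normalization guarantees every ModelInfo field exists downstream. A small sketch with an invented minimal entry:

    # Invented minimal map entry: only two keys present.
    _model_info = {"max_tokens": 4096, "mode": "chat"}

    _model_info.get("supports_system_messages", None)  # None, not a KeyError
    _model_info.get("input_cost_per_token", 0)         # 0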
@@ -4375,7 +4419,21 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
                     pass
                 else:
                     raise Exception
-                return _model_info
+                return ModelInfo(
+                    max_tokens=_model_info.get("max_tokens", None),
+                    max_input_tokens=_model_info.get("max_input_tokens", None),
+                    max_output_tokens=_model_info.get("max_output_tokens", None),
+                    input_cost_per_token=_model_info.get("input_cost_per_token", 0),
+                    output_cost_per_token=_model_info.get("output_cost_per_token", 0),
+                    litellm_provider=_model_info.get(
+                        "litellm_provider", custom_llm_provider
+                    ),
+                    mode=_model_info.get("mode"),
+                    supported_openai_params=supported_openai_params,
+                    supports_system_messages=_model_info.get(
+                        "supports_system_messages", None
+                    ),
+                )
             else:
                 raise ValueError(
                     "This model isn't mapped yet. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json"
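Since unmapped models raise, callers are expected to wrap get_model_info (and helpers built on it, like supports_system_messages) in a try/except. A usage sketch against the public API:

    import litellm

    try:
        info = litellm.get_model_info(model="gpt-4", custom_llm_provider="openai")
        print(info.get("supports_system_messages"))
    except Exception:
        # Unmapped model: treat the capability as unknown rather than crashing.
        print("model not in model_prices_and_context_window.json")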