forked from phoenix/litellm-mirror
fix(utils.py): new helper function to check if provider/model supports 'response_schema' param
This commit is contained in:
parent be8a6377f6
commit 5718d1e205
5 changed files with 114 additions and 93 deletions
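Before the per-file diffs, a minimal usage sketch of the new helper. The model names and expected results mirror the parametrized unit test added in this commit; loading the local model cost map first (exactly as that test does) keeps the check offline.

```python
import os

import litellm
from litellm.utils import supports_response_schema

# Use the bundled model_prices_and_context_window.json instead of fetching it,
# the same setup the new unit test uses.
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
litellm.model_cost = litellm.get_model_cost_map(url="")

print(supports_response_schema(model="vertex_ai/gemini-1.5-pro", custom_llm_provider=None))      # True
print(supports_response_schema(model="predibase/llama3-8b-instruct", custom_llm_provider=None))  # True (provider-wide)
print(supports_response_schema(model="gpt-4o", custom_llm_provider=None))                        # False
```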
@@ -1486,6 +1486,7 @@
         "supports_system_messages": true,
         "supports_function_calling": true,
         "supports_tool_choice": true,
+        "supports_response_schema": true,
         "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
     },
     "gemini-1.5-pro-001": {
@@ -1511,6 +1512,7 @@
         "supports_system_messages": true,
         "supports_function_calling": true,
         "supports_tool_choice": true,
+        "supports_response_schema": true,
         "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
     },
     "gemini-1.5-pro-preview-0514": {
@@ -2007,6 +2009,7 @@
         "supports_function_calling": true,
         "supports_vision": true,
         "supports_tool_choice": true,
+        "supports_response_schema": true,
         "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
     },
     "gemini/gemini-1.5-pro-latest": {
@@ -2023,6 +2026,7 @@
         "supports_function_calling": true,
         "supports_vision": true,
         "supports_tool_choice": true,
+        "supports_response_schema": true,
         "source": "https://ai.google.dev/models/gemini"
     },
     "gemini/gemini-pro-vision": {
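These JSON entries only declare a per-model capability flag; at runtime it surfaces through litellm.get_model_info, which is what the new helper consults. A short sketch of reading the flag directly, assuming the local cost map has been loaded as in the snippet above:

```python
import litellm

# get_model_info returns a ModelInfo dict; after this change the Gemini 1.5 Pro
# entries above carry the optional "supports_response_schema" key.
info = litellm.get_model_info(model="gemini/gemini-1.5-pro-latest")
if info.get("supports_response_schema", False) is True:
    print("model map advertises response_schema support")
```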
@@ -663,3 +663,29 @@ def test_convert_model_response_object():
         e.message
         == '{"type":"error","error":{"type":"invalid_request_error","message":"Output blocked by content filtering policy"}}'
     )
+
+
+@pytest.mark.parametrize(
+    "model, expected_bool",
+    [
+        ("vertex_ai/gemini-1.5-pro", True),
+        ("gemini/gemini-1.5-pro", True),
+        ("predibase/llama3-8b-instruct", True),
+        ("gpt-4o", False),
+    ],
+)
+def test_supports_response_schema(model, expected_bool):
+    """
+    Unit tests for 'supports_response_schema' helper function.
+
+    Should be true for gemini-1.5-pro on google ai studio / vertex ai AND predibase models
+    Should be false otherwise
+    """
+    os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
+    litellm.model_cost = litellm.get_model_cost_map(url="")
+
+    from litellm.utils import supports_response_schema
+
+    response = supports_response_schema(model=model, custom_llm_provider=None)
+
+    assert expected_bool == response
@@ -71,6 +71,7 @@ class ModelInfo(TypedDict, total=False):
     ]
     supported_openai_params: Required[Optional[List[str]]]
     supports_system_messages: Optional[bool]
+    supports_response_schema: Optional[bool]
 
 
 class GenericStreamingChunk(TypedDict):
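Because ModelInfo is a TypedDict declared with total=False and the new field is Optional[bool], the key may be absent entirely; treating an absent key as "unknown / not supported" is how the helper below behaves. A tiny defensive-read sketch (the function name is hypothetical, for illustration only):

```python
from typing import Optional


def response_schema_flag(model_info: dict) -> Optional[bool]:
    # Absent key -> None (unknown); an explicit true/false comes straight
    # from model_prices_and_context_window.json.
    return model_info.get("supports_response_schema")
```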
litellm/utils.py (172 changed lines)
@@ -1847,9 +1847,10 @@ def supports_system_messages(model: str, custom_llm_provider: Optional[str]) -> bool:
 
     Parameters:
     model (str): The model name to be checked.
+    custom_llm_provider (str): The provider to be checked.
 
     Returns:
-    bool: True if the model supports function calling, False otherwise.
+    bool: True if the model supports system messages, False otherwise.
 
     Raises:
     Exception: If the given model is not found in model_prices_and_context_window.json.
@@ -1867,6 +1868,43 @@ def supports_system_messages(model: str, custom_llm_provider: Optional[str]) -> bool:
         )
 
 
+def supports_response_schema(model: str, custom_llm_provider: Optional[str]) -> bool:
+    """
+    Check if the given model + provider supports 'response_schema' as a param.
+
+    Parameters:
+    model (str): The model name to be checked.
+    custom_llm_provider (str): The provider to be checked.
+
+    Returns:
+    bool: True if the model supports response_schema, False otherwise.
+
+    Raises:
+    Exception: If the given model is not found in model_prices_and_context_window.json.
+    """
+    try:
+        ## GET LLM PROVIDER ##
+        model, custom_llm_provider, _, _ = get_llm_provider(
+            model=model, custom_llm_provider=custom_llm_provider
+        )
+
+        if custom_llm_provider == "predibase":  # predibase supports this globally
+            return True
+
+        ## GET MODEL INFO
+        model_info = litellm.get_model_info(
+            model=model, custom_llm_provider=custom_llm_provider
+        )
+
+        if model_info.get("supports_response_schema", False) is True:
+            return True
+        return False
+    except Exception:
+        raise Exception(
+            f"Model not in model_prices_and_context_window.json. You passed model={model}, custom_llm_provider={custom_llm_provider}."
+        )
+
+
 def supports_function_calling(model: str) -> bool:
     """
     Check if the given model supports function calling and return a boolean value.
@@ -4434,8 +4472,7 @@ def get_max_tokens(model: str) -> Optional[int]:
 
 def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> ModelInfo:
     """
-    Get a dict for the maximum tokens (context window),
-    input_cost_per_token, output_cost_per_token for a given model.
+    Get a dict for the maximum tokens (context window), input_cost_per_token, output_cost_per_token for a given model.
 
     Parameters:
     - model (str): The name of the model.
@@ -4520,6 +4557,7 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> ModelInfo:
                 mode="chat",
                 supported_openai_params=supported_openai_params,
                 supports_system_messages=None,
+                supports_response_schema=None,
             )
         else:
             """
@@ -4541,36 +4579,6 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> ModelInfo:
                         pass
                     else:
                         raise Exception
-                return ModelInfo(
-                    max_tokens=_model_info.get("max_tokens", None),
-                    max_input_tokens=_model_info.get("max_input_tokens", None),
-                    max_output_tokens=_model_info.get("max_output_tokens", None),
-                    input_cost_per_token=_model_info.get("input_cost_per_token", 0),
-                    input_cost_per_character=_model_info.get(
-                        "input_cost_per_character", None
-                    ),
-                    input_cost_per_token_above_128k_tokens=_model_info.get(
-                        "input_cost_per_token_above_128k_tokens", None
-                    ),
-                    output_cost_per_token=_model_info.get("output_cost_per_token", 0),
-                    output_cost_per_character=_model_info.get(
-                        "output_cost_per_character", None
-                    ),
-                    output_cost_per_token_above_128k_tokens=_model_info.get(
-                        "output_cost_per_token_above_128k_tokens", None
-                    ),
-                    output_cost_per_character_above_128k_tokens=_model_info.get(
-                        "output_cost_per_character_above_128k_tokens", None
-                    ),
-                    litellm_provider=_model_info.get(
-                        "litellm_provider", custom_llm_provider
-                    ),
-                    mode=_model_info.get("mode"),
-                    supported_openai_params=supported_openai_params,
-                    supports_system_messages=_model_info.get(
-                        "supports_system_messages", None
-                    ),
-                )
             elif model in litellm.model_cost:
                 _model_info = litellm.model_cost[model]
                 _model_info["supported_openai_params"] = supported_openai_params
@@ -4584,36 +4592,6 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> ModelInfo:
                         pass
                     else:
                         raise Exception
-                return ModelInfo(
-                    max_tokens=_model_info.get("max_tokens", None),
-                    max_input_tokens=_model_info.get("max_input_tokens", None),
-                    max_output_tokens=_model_info.get("max_output_tokens", None),
-                    input_cost_per_token=_model_info.get("input_cost_per_token", 0),
-                    input_cost_per_character=_model_info.get(
-                        "input_cost_per_character", None
-                    ),
-                    input_cost_per_token_above_128k_tokens=_model_info.get(
-                        "input_cost_per_token_above_128k_tokens", None
-                    ),
-                    output_cost_per_token=_model_info.get("output_cost_per_token", 0),
-                    output_cost_per_character=_model_info.get(
-                        "output_cost_per_character", None
-                    ),
-                    output_cost_per_token_above_128k_tokens=_model_info.get(
-                        "output_cost_per_token_above_128k_tokens", None
-                    ),
-                    output_cost_per_character_above_128k_tokens=_model_info.get(
-                        "output_cost_per_character_above_128k_tokens", None
-                    ),
-                    litellm_provider=_model_info.get(
-                        "litellm_provider", custom_llm_provider
-                    ),
-                    mode=_model_info.get("mode"),
-                    supported_openai_params=supported_openai_params,
-                    supports_system_messages=_model_info.get(
-                        "supports_system_messages", None
-                    ),
-                )
             elif split_model in litellm.model_cost:
                 _model_info = litellm.model_cost[split_model]
                 _model_info["supported_openai_params"] = supported_openai_params
@@ -4627,40 +4605,48 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> ModelInfo:
                         pass
                     else:
                         raise Exception
-                return ModelInfo(
-                    max_tokens=_model_info.get("max_tokens", None),
-                    max_input_tokens=_model_info.get("max_input_tokens", None),
-                    max_output_tokens=_model_info.get("max_output_tokens", None),
-                    input_cost_per_token=_model_info.get("input_cost_per_token", 0),
-                    input_cost_per_character=_model_info.get(
-                        "input_cost_per_character", None
-                    ),
-                    input_cost_per_token_above_128k_tokens=_model_info.get(
-                        "input_cost_per_token_above_128k_tokens", None
-                    ),
-                    output_cost_per_token=_model_info.get("output_cost_per_token", 0),
-                    output_cost_per_character=_model_info.get(
-                        "output_cost_per_character", None
-                    ),
-                    output_cost_per_token_above_128k_tokens=_model_info.get(
-                        "output_cost_per_token_above_128k_tokens", None
-                    ),
-                    output_cost_per_character_above_128k_tokens=_model_info.get(
-                        "output_cost_per_character_above_128k_tokens", None
-                    ),
-                    litellm_provider=_model_info.get(
-                        "litellm_provider", custom_llm_provider
-                    ),
-                    mode=_model_info.get("mode"),
-                    supported_openai_params=supported_openai_params,
-                    supports_system_messages=_model_info.get(
-                        "supports_system_messages", None
-                    ),
-                )
             else:
                 raise ValueError(
                     "This model isn't mapped yet. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json"
                 )
+
+            ## PROVIDER-SPECIFIC INFORMATION
+            if custom_llm_provider == "predibase":
+                _model_info["supports_response_schema"] = True
+
+            return ModelInfo(
+                max_tokens=_model_info.get("max_tokens", None),
+                max_input_tokens=_model_info.get("max_input_tokens", None),
+                max_output_tokens=_model_info.get("max_output_tokens", None),
+                input_cost_per_token=_model_info.get("input_cost_per_token", 0),
+                input_cost_per_character=_model_info.get(
+                    "input_cost_per_character", None
+                ),
+                input_cost_per_token_above_128k_tokens=_model_info.get(
+                    "input_cost_per_token_above_128k_tokens", None
+                ),
+                output_cost_per_token=_model_info.get("output_cost_per_token", 0),
+                output_cost_per_character=_model_info.get(
+                    "output_cost_per_character", None
+                ),
+                output_cost_per_token_above_128k_tokens=_model_info.get(
+                    "output_cost_per_token_above_128k_tokens", None
+                ),
+                output_cost_per_character_above_128k_tokens=_model_info.get(
+                    "output_cost_per_character_above_128k_tokens", None
+                ),
+                litellm_provider=_model_info.get(
+                    "litellm_provider", custom_llm_provider
+                ),
+                mode=_model_info.get("mode"),
+                supported_openai_params=supported_openai_params,
+                supports_system_messages=_model_info.get(
+                    "supports_system_messages", None
+                ),
+                supports_response_schema=_model_info.get(
+                    "supports_response_schema", None
+                ),
+            )
     except Exception:
         raise Exception(
             "This model isn't mapped yet. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json"
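One behavioral detail worth noting from the refactored helper above: an unmapped model does not return False, it raises, so callers that only want a boolean must catch. A minimal sketch (the model name below is deliberately fake):

```python
from litellm.utils import supports_response_schema

try:
    supports_response_schema(model="not-a-real-model", custom_llm_provider=None)
except Exception as err:
    # Lookup failures are re-raised with a pointer to
    # model_prices_and_context_window.json rather than returning False.
    print(f"unmapped model: {err}")
```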
@@ -1486,6 +1486,7 @@
         "supports_system_messages": true,
         "supports_function_calling": true,
         "supports_tool_choice": true,
+        "supports_response_schema": true,
         "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
     },
     "gemini-1.5-pro-001": {
@@ -1511,6 +1512,7 @@
         "supports_system_messages": true,
         "supports_function_calling": true,
         "supports_tool_choice": true,
+        "supports_response_schema": true,
         "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
     },
     "gemini-1.5-pro-preview-0514": {
@@ -2007,6 +2009,7 @@
         "supports_function_calling": true,
         "supports_vision": true,
         "supports_tool_choice": true,
+        "supports_response_schema": true,
         "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
     },
     "gemini/gemini-1.5-pro-latest": {
@@ -2023,6 +2026,7 @@
         "supports_function_calling": true,
         "supports_vision": true,
         "supports_tool_choice": true,
+        "supports_response_schema": true,
         "source": "https://ai.google.dev/models/gemini"
     },
     "gemini/gemini-pro-vision": {