diff --git a/litellm/model_prices_and_context_window_backup.json b/litellm/model_prices_and_context_window_backup.json
index 92f97ebe54..49f2f0c286 100644
--- a/litellm/model_prices_and_context_window_backup.json
+++ b/litellm/model_prices_and_context_window_backup.json
@@ -1486,6 +1486,7 @@
         "supports_system_messages": true,
         "supports_function_calling": true,
         "supports_tool_choice": true,
+        "supports_response_schema": true,
         "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
     },
     "gemini-1.5-pro-001": {
@@ -1511,6 +1512,7 @@
         "supports_system_messages": true,
         "supports_function_calling": true,
         "supports_tool_choice": true,
+        "supports_response_schema": true,
         "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
     },
     "gemini-1.5-pro-preview-0514": {
@@ -2007,6 +2009,7 @@
         "supports_function_calling": true,
         "supports_vision": true,
         "supports_tool_choice": true,
+        "supports_response_schema": true,
         "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
     },
     "gemini/gemini-1.5-pro-latest": {
@@ -2023,6 +2026,7 @@
         "supports_function_calling": true,
         "supports_vision": true,
         "supports_tool_choice": true,
+        "supports_response_schema": true,
         "source": "https://ai.google.dev/models/gemini"
     },
     "gemini/gemini-pro-vision": {
diff --git a/litellm/tests/test_utils.py b/litellm/tests/test_utils.py
index 78b64270c6..8225b309dc 100644
--- a/litellm/tests/test_utils.py
+++ b/litellm/tests/test_utils.py
@@ -663,3 +663,29 @@ def test_convert_model_response_object():
         e.message
         == '{"type":"error","error":{"type":"invalid_request_error","message":"Output blocked by content filtering policy"}}'
     )
+
+
+@pytest.mark.parametrize(
+    "model, expected_bool",
+    [
+        ("vertex_ai/gemini-1.5-pro", True),
+        ("gemini/gemini-1.5-pro", True),
+        ("predibase/llama3-8b-instruct", True),
+        ("gpt-4o", False),
+    ],
+)
+def test_supports_response_schema(model, expected_bool):
+    """
+    Unit tests for 'supports_response_schema' helper function.
+
+    Should be true for gemini-1.5-pro on google ai studio / vertex ai AND predibase models
+    Should be false otherwise
+    """
+    os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
+    litellm.model_cost = litellm.get_model_cost_map(url="")
+
+    from litellm.utils import supports_response_schema
+
+    response = supports_response_schema(model=model, custom_llm_provider=None)
+
+    assert expected_bool == response
diff --git a/litellm/types/utils.py b/litellm/types/utils.py
index a63e34738a..51ce086711 100644
--- a/litellm/types/utils.py
+++ b/litellm/types/utils.py
@@ -71,6 +71,7 @@ class ModelInfo(TypedDict, total=False):
     ]
     supported_openai_params: Required[Optional[List[str]]]
     supports_system_messages: Optional[bool]
+    supports_response_schema: Optional[bool]
 
 
 class GenericStreamingChunk(TypedDict):
diff --git a/litellm/utils.py b/litellm/utils.py
index dc2bcb25aa..227274d3a9 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -1847,9 +1847,10 @@ def supports_system_messages(model: str, custom_llm_provider: Optional[str]) ->
 
     Parameters:
     model (str): The model name to be checked.
+    custom_llm_provider (str): The provider to be checked.
 
     Returns:
-    bool: True if the model supports function calling, False otherwise.
+    bool: True if the model supports system messages, False otherwise.
 
     Raises:
     Exception: If the given model is not found in model_prices_and_context_window.json.
@@ -1867,6 +1868,43 @@ def supports_system_messages(model: str, custom_llm_provider: Optional[str]) ->
         )
 
 
+def supports_response_schema(model: str, custom_llm_provider: Optional[str]) -> bool:
+    """
+    Check if the given model + provider supports 'response_schema' as a param.
+
+    Parameters:
+    model (str): The model name to be checked.
+    custom_llm_provider (str): The provider to be checked.
+
+    Returns:
+    bool: True if the model supports response_schema, False otherwise.
+
+    Raises:
+    Exception: If the given model is not found in model_prices_and_context_window.json.
+    """
+    try:
+        ## GET LLM PROVIDER ##
+        model, custom_llm_provider, _, _ = get_llm_provider(
+            model=model, custom_llm_provider=custom_llm_provider
+        )
+
+        if custom_llm_provider == "predibase":  # predibase supports this globally
+            return True
+
+        ## GET MODEL INFO
+        model_info = litellm.get_model_info(
+            model=model, custom_llm_provider=custom_llm_provider
+        )
+
+        if model_info.get("supports_response_schema", False) is True:
+            return True
+        return False
+    except Exception:
+        raise Exception(
+            f"Model not in model_prices_and_context_window.json. You passed model={model}, custom_llm_provider={custom_llm_provider}."
+        )
+
+
 def supports_function_calling(model: str) -> bool:
     """
     Check if the given model supports function calling and return a boolean value.
@@ -4434,8 +4472,7 @@ def get_max_tokens(model: str) -> Optional[int]:
 
 def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> ModelInfo:
     """
-    Get a dict for the maximum tokens (context window),
-    input_cost_per_token, output_cost_per_token for a given model.
+    Get a dict for the maximum tokens (context window), input_cost_per_token, output_cost_per_token for a given model.
 
     Parameters:
     - model (str): The name of the model.
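A minimal usage sketch of the new helper (not part of the patch), mirroring the local cost-map setup used in test_supports_response_schema above; the model names and expected values come from the test's parametrize list:

```python
import os

import litellm
from litellm.utils import supports_response_schema

# Use the bundled model_prices_and_context_window_backup.json instead of the remote
# cost map, exactly as the new unit test does.
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
litellm.model_cost = litellm.get_model_cost_map(url="")

# The Gemini 1.5 Pro entries above now carry "supports_response_schema": true.
print(supports_response_schema(model="vertex_ai/gemini-1.5-pro", custom_llm_provider=None))  # expected: True

# Predibase is short-circuited inside the helper, so no cost-map flag is needed.
print(supports_response_schema(model="predibase/llama3-8b-instruct", custom_llm_provider=None))  # expected: True

# Models without the flag fall through to False.
print(supports_response_schema(model="gpt-4o", custom_llm_provider=None))  # expected: False
```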
@@ -4520,6 +4557,7 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
                 mode="chat",
                 supported_openai_params=supported_openai_params,
                 supports_system_messages=None,
+                supports_response_schema=None,
             )
         else:
             """
@@ -4541,36 +4579,6 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
                         pass
                     else:
                         raise Exception
-                return ModelInfo(
-                    max_tokens=_model_info.get("max_tokens", None),
-                    max_input_tokens=_model_info.get("max_input_tokens", None),
-                    max_output_tokens=_model_info.get("max_output_tokens", None),
-                    input_cost_per_token=_model_info.get("input_cost_per_token", 0),
-                    input_cost_per_character=_model_info.get(
-                        "input_cost_per_character", None
-                    ),
-                    input_cost_per_token_above_128k_tokens=_model_info.get(
-                        "input_cost_per_token_above_128k_tokens", None
-                    ),
-                    output_cost_per_token=_model_info.get("output_cost_per_token", 0),
-                    output_cost_per_character=_model_info.get(
-                        "output_cost_per_character", None
-                    ),
-                    output_cost_per_token_above_128k_tokens=_model_info.get(
-                        "output_cost_per_token_above_128k_tokens", None
-                    ),
-                    output_cost_per_character_above_128k_tokens=_model_info.get(
-                        "output_cost_per_character_above_128k_tokens", None
-                    ),
-                    litellm_provider=_model_info.get(
-                        "litellm_provider", custom_llm_provider
-                    ),
-                    mode=_model_info.get("mode"),
-                    supported_openai_params=supported_openai_params,
-                    supports_system_messages=_model_info.get(
-                        "supports_system_messages", None
-                    ),
-                )
             elif model in litellm.model_cost:
                 _model_info = litellm.model_cost[model]
                 _model_info["supported_openai_params"] = supported_openai_params
@@ -4584,36 +4592,6 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
                         pass
                     else:
                         raise Exception
-                return ModelInfo(
-                    max_tokens=_model_info.get("max_tokens", None),
-                    max_input_tokens=_model_info.get("max_input_tokens", None),
-                    max_output_tokens=_model_info.get("max_output_tokens", None),
-                    input_cost_per_token=_model_info.get("input_cost_per_token", 0),
-                    input_cost_per_character=_model_info.get(
-                        "input_cost_per_character", None
-                    ),
-                    input_cost_per_token_above_128k_tokens=_model_info.get(
-                        "input_cost_per_token_above_128k_tokens", None
-                    ),
-                    output_cost_per_token=_model_info.get("output_cost_per_token", 0),
-                    output_cost_per_character=_model_info.get(
-                        "output_cost_per_character", None
-                    ),
-                    output_cost_per_token_above_128k_tokens=_model_info.get(
-                        "output_cost_per_token_above_128k_tokens", None
-                    ),
-                    output_cost_per_character_above_128k_tokens=_model_info.get(
-                        "output_cost_per_character_above_128k_tokens", None
-                    ),
-                    litellm_provider=_model_info.get(
-                        "litellm_provider", custom_llm_provider
-                    ),
-                    mode=_model_info.get("mode"),
-                    supported_openai_params=supported_openai_params,
-                    supports_system_messages=_model_info.get(
-                        "supports_system_messages", None
-                    ),
-                )
             elif split_model in litellm.model_cost:
                 _model_info = litellm.model_cost[split_model]
                 _model_info["supported_openai_params"] = supported_openai_params
@@ -4627,40 +4605,48 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
                         pass
                     else:
                         raise Exception
-                return ModelInfo(
-                    max_tokens=_model_info.get("max_tokens", None),
-                    max_input_tokens=_model_info.get("max_input_tokens", None),
-                    max_output_tokens=_model_info.get("max_output_tokens", None),
-                    input_cost_per_token=_model_info.get("input_cost_per_token", 0),
-                    input_cost_per_character=_model_info.get(
-                        "input_cost_per_character", None
-                    ),
-                    input_cost_per_token_above_128k_tokens=_model_info.get(
-                        "input_cost_per_token_above_128k_tokens", None
-                    ),
-                    output_cost_per_token=_model_info.get("output_cost_per_token", 0),
-                    output_cost_per_character=_model_info.get(
-                        "output_cost_per_character", None
-                    ),
-                    output_cost_per_token_above_128k_tokens=_model_info.get(
-                        "output_cost_per_token_above_128k_tokens", None
-                    ),
-                    output_cost_per_character_above_128k_tokens=_model_info.get(
-                        "output_cost_per_character_above_128k_tokens", None
-                    ),
-                    litellm_provider=_model_info.get(
-                        "litellm_provider", custom_llm_provider
-                    ),
-                    mode=_model_info.get("mode"),
-                    supported_openai_params=supported_openai_params,
-                    supports_system_messages=_model_info.get(
-                        "supports_system_messages", None
-                    ),
-                )
             else:
                 raise ValueError(
                     "This model isn't mapped yet. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json"
                 )
+
+            ## PROVIDER-SPECIFIC INFORMATION
+            if custom_llm_provider == "predibase":
+                _model_info["supports_response_schema"] = True
+
+            return ModelInfo(
+                max_tokens=_model_info.get("max_tokens", None),
+                max_input_tokens=_model_info.get("max_input_tokens", None),
+                max_output_tokens=_model_info.get("max_output_tokens", None),
+                input_cost_per_token=_model_info.get("input_cost_per_token", 0),
+                input_cost_per_character=_model_info.get(
+                    "input_cost_per_character", None
+                ),
+                input_cost_per_token_above_128k_tokens=_model_info.get(
+                    "input_cost_per_token_above_128k_tokens", None
+                ),
+                output_cost_per_token=_model_info.get("output_cost_per_token", 0),
+                output_cost_per_character=_model_info.get(
+                    "output_cost_per_character", None
+                ),
+                output_cost_per_token_above_128k_tokens=_model_info.get(
+                    "output_cost_per_token_above_128k_tokens", None
+                ),
+                output_cost_per_character_above_128k_tokens=_model_info.get(
+                    "output_cost_per_character_above_128k_tokens", None
+                ),
+                litellm_provider=_model_info.get(
+                    "litellm_provider", custom_llm_provider
+                ),
+                mode=_model_info.get("mode"),
+                supported_openai_params=supported_openai_params,
+                supports_system_messages=_model_info.get(
+                    "supports_system_messages", None
+                ),
+                supports_response_schema=_model_info.get(
+                    "supports_response_schema", None
+                ),
+            )
     except Exception:
         raise Exception(
             "This model isn't mapped yet. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json"
diff --git a/model_prices_and_context_window.json b/model_prices_and_context_window.json
index 92f97ebe54..49f2f0c286 100644
--- a/model_prices_and_context_window.json
+++ b/model_prices_and_context_window.json
@@ -1486,6 +1486,7 @@
         "supports_system_messages": true,
         "supports_function_calling": true,
         "supports_tool_choice": true,
+        "supports_response_schema": true,
         "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
     },
     "gemini-1.5-pro-001": {
@@ -1511,6 +1512,7 @@
         "supports_system_messages": true,
         "supports_function_calling": true,
         "supports_tool_choice": true,
+        "supports_response_schema": true,
         "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
     },
     "gemini-1.5-pro-preview-0514": {
@@ -2007,6 +2009,7 @@
         "supports_function_calling": true,
         "supports_vision": true,
         "supports_tool_choice": true,
+        "supports_response_schema": true,
         "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
     },
     "gemini/gemini-1.5-pro-latest": {
@@ -2023,6 +2026,7 @@
         "supports_function_calling": true,
         "supports_vision": true,
         "supports_tool_choice": true,
+        "supports_response_schema": true,
         "source": "https://ai.google.dev/models/gemini"
     },
     "gemini/gemini-pro-vision": {
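For completeness, a sketch (not part of the patch, same local cost-map assumption as above) of how the new ModelInfo field added in litellm/types/utils.py surfaces through litellm.get_model_info; the lookup key follows the gemini/gemini-1.5-pro case exercised by the unit test:

```python
import os

import litellm

# Same local cost-map setup as in test_supports_response_schema.
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
litellm.model_cost = litellm.get_model_cost_map(url="")

info = litellm.get_model_info(model="gemini/gemini-1.5-pro")
# The Gemini entries updated in this diff set the flag to true; entries without the
# key surface as None, since the TypedDict field is Optional[bool].
print(info.get("supports_response_schema"))  # expected: True
```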