diff --git a/litellm/__init__.py b/litellm/__init__.py
index a8d9a80a2..73c382516 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -749,6 +749,7 @@ from .utils import (
     create_pretrained_tokenizer,
     create_tokenizer,
     supports_function_calling,
+    supports_response_schema,
     supports_parallel_function_calling,
     supports_vision,
     supports_system_messages,
diff --git a/litellm/llms/prompt_templates/factory.py b/litellm/llms/prompt_templates/factory.py
index b35914584..87af2a6bd 100644
--- a/litellm/llms/prompt_templates/factory.py
+++ b/litellm/llms/prompt_templates/factory.py
@@ -2033,6 +2033,50 @@ def function_call_prompt(messages: list, functions: list):
     return messages
 
 
+def response_schema_prompt(model: str, response_schema: dict) -> str:
+    """
+    Decides if a user-defined custom prompt or default needs to be used
+
+    Returns the prompt str that's passed to the model as a user message
+    """
+    custom_prompt_details: Optional[dict] = None
+    response_schema_as_message = [
+        {"role": "user", "content": "{}".format(response_schema)}
+    ]
+    if f"{model}/response_schema_prompt" in litellm.custom_prompt_dict:
+
+        custom_prompt_details = litellm.custom_prompt_dict[
+            f"{model}/response_schema_prompt"
+        ]  # allow user to define custom response schema prompt by model
+    elif "response_schema_prompt" in litellm.custom_prompt_dict:
+        custom_prompt_details = litellm.custom_prompt_dict["response_schema_prompt"]
+
+    if custom_prompt_details is not None:
+        return custom_prompt(
+            role_dict=custom_prompt_details["roles"],
+            initial_prompt_value=custom_prompt_details["initial_prompt_value"],
+            final_prompt_value=custom_prompt_details["final_prompt_value"],
+            messages=response_schema_as_message,
+        )
+    else:
+        return default_response_schema_prompt(response_schema=response_schema)
+
+
+def default_response_schema_prompt(response_schema: dict) -> str:
+    """
+    Used if provider/model doesn't support 'response_schema' param.
+
+    This is the default prompt. Allow user to override this with a custom_prompt.
+    """
+    prompt_str = """Use this JSON schema:
+    ```json
+    {}
+    ```""".format(
+        response_schema
+    )
+    return prompt_str
+
+
 # Custom prompt template
 def custom_prompt(
     role_dict: dict,
diff --git a/litellm/llms/vertex_ai.py b/litellm/llms/vertex_ai.py
index 4a4abaef4..c1e628d17 100644
--- a/litellm/llms/vertex_ai.py
+++ b/litellm/llms/vertex_ai.py
@@ -12,6 +12,7 @@ import requests  # type: ignore
 from pydantic import BaseModel
 
 import litellm
+from litellm._logging import verbose_logger
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
 from litellm.llms.prompt_templates.factory import (
     convert_to_anthropic_image_obj,
@@ -328,80 +329,86 @@ def _gemini_convert_messages_with_history(messages: list) -> List[ContentType]:
     contents: List[ContentType] = []
 
     msg_i = 0
-    while msg_i < len(messages):
-        user_content: List[PartType] = []
-        init_msg_i = msg_i
-        ## MERGE CONSECUTIVE USER CONTENT ##
-        while msg_i < len(messages) and messages[msg_i]["role"] in user_message_types:
-            if isinstance(messages[msg_i]["content"], list):
-                _parts: List[PartType] = []
-                for element in messages[msg_i]["content"]:
-                    if isinstance(element, dict):
-                        if element["type"] == "text" and len(element["text"]) > 0:
-                            _part = PartType(text=element["text"])
-                            _parts.append(_part)
-                        elif element["type"] == "image_url":
-                            image_url = element["image_url"]["url"]
-                            _part = _process_gemini_image(image_url=image_url)
-                            _parts.append(_part)  # type: ignore
-                user_content.extend(_parts)
-            elif (
-                isinstance(messages[msg_i]["content"], str)
-                and len(messages[msg_i]["content"]) > 0
+    try:
+        while msg_i < len(messages):
+            user_content: List[PartType] = []
+            init_msg_i = msg_i
+            ## MERGE CONSECUTIVE USER CONTENT ##
+            while (
+                msg_i < len(messages) and messages[msg_i]["role"] in user_message_types
             ):
-                _part = PartType(text=messages[msg_i]["content"])
-                user_content.append(_part)
+                if isinstance(messages[msg_i]["content"], list):
+                    _parts: List[PartType] = []
+                    for element in messages[msg_i]["content"]:
+                        if isinstance(element, dict):
+                            if element["type"] == "text" and len(element["text"]) > 0:
+                                _part = PartType(text=element["text"])
+                                _parts.append(_part)
+                            elif element["type"] == "image_url":
+                                image_url = element["image_url"]["url"]
+                                _part = _process_gemini_image(image_url=image_url)
+                                _parts.append(_part)  # type: ignore
+                    user_content.extend(_parts)
+                elif (
+                    isinstance(messages[msg_i]["content"], str)
+                    and len(messages[msg_i]["content"]) > 0
+                ):
+                    _part = PartType(text=messages[msg_i]["content"])
+                    user_content.append(_part)
 
-            msg_i += 1
+                msg_i += 1
 
-        if user_content:
-            contents.append(ContentType(role="user", parts=user_content))
-        assistant_content = []
-        ## MERGE CONSECUTIVE ASSISTANT CONTENT ##
-        while msg_i < len(messages) and messages[msg_i]["role"] == "assistant":
-            if isinstance(messages[msg_i]["content"], list):
-                _parts = []
-                for element in messages[msg_i]["content"]:
-                    if isinstance(element, dict):
-                        if element["type"] == "text":
-                            _part = PartType(text=element["text"])
-                            _parts.append(_part)
-                        elif element["type"] == "image_url":
-                            image_url = element["image_url"]["url"]
-                            _part = _process_gemini_image(image_url=image_url)
-                            _parts.append(_part)  # type: ignore
-                assistant_content.extend(_parts)
-            elif messages[msg_i].get(
-                "tool_calls", []
-            ):  # support assistant tool invoke conversion
-                assistant_content.extend(
-                    convert_to_gemini_tool_call_invoke(messages[msg_i]["tool_calls"])
+            if user_content:
+                contents.append(ContentType(role="user", parts=user_content))
+            assistant_content = []
+            ## MERGE CONSECUTIVE ASSISTANT CONTENT ##
+            while msg_i < len(messages) and messages[msg_i]["role"] == "assistant":
+                if isinstance(messages[msg_i]["content"], list):
+                    _parts = []
+                    for element in messages[msg_i]["content"]:
+                        if isinstance(element, dict):
+                            if element["type"] == "text":
+                                _part = PartType(text=element["text"])
+                                _parts.append(_part)
+                            elif element["type"] == "image_url":
+                                image_url = element["image_url"]["url"]
+                                _part = _process_gemini_image(image_url=image_url)
+                                _parts.append(_part)  # type: ignore
+                    assistant_content.extend(_parts)
+                elif messages[msg_i].get(
+                    "tool_calls", []
+                ):  # support assistant tool invoke conversion
+                    assistant_content.extend(
+                        convert_to_gemini_tool_call_invoke(
+                            messages[msg_i]["tool_calls"]
+                        )
+                    )
+                else:
+                    assistant_text = (
+                        messages[msg_i].get("content") or ""
+                    )  # either string or none
+                    if assistant_text:
+                        assistant_content.append(PartType(text=assistant_text))
+
+                msg_i += 1
+
+            if assistant_content:
+                contents.append(ContentType(role="model", parts=assistant_content))
+
+            ## APPEND TOOL CALL MESSAGES ##
+            if msg_i < len(messages) and messages[msg_i]["role"] == "tool":
+                _part = convert_to_gemini_tool_call_result(messages[msg_i])
+                contents.append(ContentType(parts=[_part]))  # type: ignore
+                msg_i += 1
+            if msg_i == init_msg_i:  # prevent infinite loops
+                raise Exception(
+                    "Invalid Message passed in - {}. File an issue https://github.com/BerriAI/litellm/issues".format(
+                        messages[msg_i]
+                    )
                 )
-            else:
-                assistant_text = (
-                    messages[msg_i].get("content") or ""
-                )  # either string or none
-                if assistant_text:
-                    assistant_content.append(PartType(text=assistant_text))
-
-            msg_i += 1
-
-        if assistant_content:
-            contents.append(ContentType(role="model", parts=assistant_content))
-
-        ## APPEND TOOL CALL MESSAGES ##
-        if msg_i < len(messages) and messages[msg_i]["role"] == "tool":
-            _part = convert_to_gemini_tool_call_result(messages[msg_i])
-            contents.append(ContentType(parts=[_part]))  # type: ignore
-            msg_i += 1
-        if msg_i == init_msg_i:  # prevent infinite loops
-            raise Exception(
-                "Invalid Message passed in - {}. File an issue https://github.com/BerriAI/litellm/issues".format(
-                    messages[msg_i]
-                )
-            )
-
-    return contents
+        return contents
+    except Exception as e:
+        raise e
 
 
 def _get_client_cache_key(model: str, vertex_project: str, vertex_location: str):
diff --git a/litellm/llms/vertex_httpx.py b/litellm/llms/vertex_httpx.py
index 940016ecb..91a2b0276 100644
--- a/litellm/llms/vertex_httpx.py
+++ b/litellm/llms/vertex_httpx.py
@@ -21,7 +21,10 @@ import litellm.litellm_core_utils.litellm_logging
 from litellm import verbose_logger
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
-from litellm.llms.prompt_templates.factory import convert_url_to_base64
+from litellm.llms.prompt_templates.factory import (
+    convert_url_to_base64,
+    response_schema_prompt,
+)
 from litellm.llms.vertex_ai import _gemini_convert_messages_with_history
 from litellm.types.llms.openai import (
     ChatCompletionResponseMessage,
@@ -1011,35 +1014,53 @@ class VertexLLM(BaseLLM):
         if len(system_prompt_indices) > 0:
             for idx in reversed(system_prompt_indices):
                 messages.pop(idx)
-        content = _gemini_convert_messages_with_history(messages=messages)
-        tools: Optional[Tools] = optional_params.pop("tools", None)
-        tool_choice: Optional[ToolConfig] = optional_params.pop("tool_choice", None)
-        safety_settings: Optional[List[SafetSettingsConfig]] = optional_params.pop(
-            "safety_settings", None
-        )  # type: ignore
-        generation_config: Optional[GenerationConfig] = GenerationConfig(
-            **optional_params
-        )
-        data = RequestBody(contents=content)
-        if len(system_content_blocks) > 0:
-            system_instructions = SystemInstructions(parts=system_content_blocks)
-            data["system_instruction"] = system_instructions
-        if tools is not None:
-            data["tools"] = tools
-        if tool_choice is not None:
-            data["toolConfig"] = tool_choice
-        if safety_settings is not None:
-            data["safetySettings"] = safety_settings
-        if generation_config is not None:
-            data["generationConfig"] = generation_config
-        headers = {
-            "Content-Type": "application/json",
-        }
-        if auth_header is not None:
-            headers["Authorization"] = f"Bearer {auth_header}"
-        if extra_headers is not None:
-            headers.update(extra_headers)
+        # Checks for 'response_schema' support - if passed in
+        if "response_schema" in optional_params:
+            supports_response_schema = litellm.supports_response_schema(
+                model=model, custom_llm_provider="vertex_ai"
+            )
+            if supports_response_schema is False:
+                user_response_schema_message = response_schema_prompt(
+                    model=model, response_schema=optional_params.get("response_schema")  # type: ignore
+                )
+                messages.append(
+                    {"role": "user", "content": user_response_schema_message}
+                )
+                optional_params.pop("response_schema")
+
+        try:
+            content = _gemini_convert_messages_with_history(messages=messages)
+            tools: Optional[Tools] = optional_params.pop("tools", None)
+            tool_choice: Optional[ToolConfig] = optional_params.pop("tool_choice", None)
+            safety_settings: Optional[List[SafetSettingsConfig]] = optional_params.pop(
+                "safety_settings", None
+            )  # type: ignore
+            generation_config: Optional[GenerationConfig] = GenerationConfig(
+                **optional_params
+            )
+            data = RequestBody(contents=content)
+            if len(system_content_blocks) > 0:
+                system_instructions = SystemInstructions(parts=system_content_blocks)
+                data["system_instruction"] = system_instructions
+            if tools is not None:
+                data["tools"] = tools
+            if tool_choice is not None:
+                data["toolConfig"] = tool_choice
+            if safety_settings is not None:
+                data["safetySettings"] = safety_settings
+            if generation_config is not None:
+                data["generationConfig"] = generation_config
+
+            headers = {
+                "Content-Type": "application/json",
+            }
+            if auth_header is not None:
+                headers["Authorization"] = f"Bearer {auth_header}"
+            if extra_headers is not None:
+                headers.update(extra_headers)
+        except Exception as e:
+            raise e
 
         ## LOGGING
         logging_obj.pre_call(
diff --git a/litellm/model_prices_and_context_window_backup.json b/litellm/model_prices_and_context_window_backup.json
index 49f2f0c28..7f08b9eb1 100644
--- a/litellm/model_prices_and_context_window_backup.json
+++ b/litellm/model_prices_and_context_window_backup.json
@@ -1538,6 +1538,7 @@
         "supports_system_messages": true,
         "supports_function_calling": true,
         "supports_tool_choice": true,
+        "supports_response_schema": true,
         "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
     },
     "gemini-1.5-pro-preview-0215": {
@@ -1563,6 +1564,7 @@
         "supports_system_messages": true,
         "supports_function_calling": true,
         "supports_tool_choice": true,
+        "supports_response_schema": true,
         "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
     },
     "gemini-1.5-pro-preview-0409": {
@@ -1586,7 +1588,8 @@
         "litellm_provider": "vertex_ai-language-models",
         "mode": "chat",
         "supports_function_calling": true,
-        "supports_tool_choice": true,
+        "supports_tool_choice": true,
+        "supports_response_schema": true,
         "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
     },
     "gemini-1.5-flash": {
diff --git a/litellm/tests/test_amazing_vertex_completion.py b/litellm/tests/test_amazing_vertex_completion.py
index e6f2634f4..4cb79affa 100644
--- a/litellm/tests/test_amazing_vertex_completion.py
+++ b/litellm/tests/test_amazing_vertex_completion.py
@@ -880,10 +880,19 @@ Using this JSON schema:
     mock_call.assert_called_once()
 
 
-@pytest.mark.parametrize("provider", ["vertex_ai_beta"])  # "vertex_ai",
+@pytest.mark.parametrize(
+    "model, supports_response_schema",
+    [
+        ("vertex_ai_beta/gemini-1.5-pro-001", True),
+        ("vertex_ai_beta/gemini-1.5-flash", False),
+    ],
+)  # "vertex_ai",
 @pytest.mark.asyncio
-async def test_gemini_pro_json_schema_httpx(provider):
+async def test_gemini_pro_json_schema_httpx(model, supports_response_schema):
     load_vertex_ai_credentials()
+    os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
+    litellm.model_cost = litellm.get_model_cost_map(url="")
+
     litellm.set_verbose = True
     messages = [{"role": "user", "content": "List 5 cookie recipes"}]
     from litellm.llms.custom_httpx.http_handler import HTTPHandler
@@ -905,8 +914,8 @@ async def test_gemini_pro_json_schema_httpx(provider):
 
     with patch.object(client, "post", new=MagicMock()) as mock_call:
         try:
-            response = completion(
-                model="vertex_ai_beta/gemini-1.5-pro-001",
+            _ = completion(
+                model=model,
                 messages=messages,
                 response_format={
                     "type": "json_object",
@@ -914,15 +923,27 @@ async def test_gemini_pro_json_schema_httpx(provider):
                 },
                 client=client,
             )
-        except Exception as e:
+        except Exception:
             pass
 
         mock_call.assert_called_once()
         print(mock_call.call_args.kwargs)
         print(mock_call.call_args.kwargs["json"]["generationConfig"])
-        assert (
-            "response_schema" in mock_call.call_args.kwargs["json"]["generationConfig"]
-        )
+
+        if supports_response_schema:
+            assert (
+                "response_schema"
+                in mock_call.call_args.kwargs["json"]["generationConfig"]
+            )
+        else:
+            assert (
+                "response_schema"
+                not in mock_call.call_args.kwargs["json"]["generationConfig"]
+            )
+            assert (
+                "Use this JSON schema:"
+                in mock_call.call_args.kwargs["json"]["contents"][0]["parts"][1]["text"]
+            )
 
 
 @pytest.mark.parametrize("provider", ["vertex_ai_beta"])  # "vertex_ai",
diff --git a/litellm/utils.py b/litellm/utils.py
index 227274d3a..91cf75424 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -1879,8 +1879,7 @@ def supports_response_schema(model: str, custom_llm_provider: Optional[str]) ->
     Returns:
         bool: True if the model supports response_schema, False otherwise.
 
-    Raises:
-        Exception: If the given model is not found in model_prices_and_context_window.json.
+    Does not raise error. Defaults to 'False'. Outputs logging.error.
     """
     try:
         ## GET LLM PROVIDER ##
@@ -1900,9 +1899,10 @@ def supports_response_schema(model: str, custom_llm_provider: Optional[str]) ->
             return True
         return False
     except Exception:
-        raise Exception(
+        verbose_logger.error(
             f"Model not in model_prices_and_context_window.json. You passed model={model}, custom_llm_provider={custom_llm_provider}."
         )
+        return False
 
 
 def supports_function_calling(model: str) -> bool:
diff --git a/model_prices_and_context_window.json b/model_prices_and_context_window.json
index 49f2f0c28..7f08b9eb1 100644
--- a/model_prices_and_context_window.json
+++ b/model_prices_and_context_window.json
@@ -1538,6 +1538,7 @@
         "supports_system_messages": true,
         "supports_function_calling": true,
         "supports_tool_choice": true,
+        "supports_response_schema": true,
         "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
     },
     "gemini-1.5-pro-preview-0215": {
@@ -1563,6 +1564,7 @@
         "supports_system_messages": true,
         "supports_function_calling": true,
         "supports_tool_choice": true,
+        "supports_response_schema": true,
         "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
     },
     "gemini-1.5-pro-preview-0409": {
@@ -1586,7 +1588,8 @@
         "litellm_provider": "vertex_ai-language-models",
        "mode": "chat",
         "supports_function_calling": true,
-        "supports_tool_choice": true,
+        "supports_tool_choice": true,
+        "supports_response_schema": true,
         "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
     },
     "gemini-1.5-flash": {