diff --git a/.circleci/config.yml b/.circleci/config.yml index da0738fb7..40d498d6e 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -66,7 +66,7 @@ jobs: pip install "pydantic==2.7.1" pip install "diskcache==5.6.1" pip install "Pillow==10.3.0" - pip install "ijson==3.2.3" + pip install "jsonschema==4.22.0" - save_cache: paths: - ./venv @@ -128,7 +128,7 @@ jobs: pip install jinja2 pip install tokenizers pip install openai - pip install ijson + pip install jsonschema - run: name: Run tests command: | @@ -183,7 +183,7 @@ jobs: pip install numpydoc pip install prisma pip install fastapi - pip install ijson + pip install jsonschema pip install "httpx==0.24.1" pip install "gunicorn==21.2.0" pip install "anyio==3.7.1" diff --git a/litellm/__init__.py b/litellm/__init__.py index 5bd5d1a16..0fa822a98 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -749,6 +749,7 @@ from .utils import ( create_pretrained_tokenizer, create_tokenizer, supports_function_calling, + supports_response_schema, supports_parallel_function_calling, supports_vision, supports_system_messages, @@ -852,6 +853,7 @@ from .exceptions import ( APIResponseValidationError, UnprocessableEntityError, InternalServerError, + JSONSchemaValidationError, LITELLM_EXCEPTION_TYPES, ) from .budget_manager import BudgetManager diff --git a/litellm/exceptions.py b/litellm/exceptions.py index 98b519278..d85510b1d 100644 --- a/litellm/exceptions.py +++ b/litellm/exceptions.py @@ -551,7 +551,7 @@ class APIError(openai.APIError): # type: ignore message, llm_provider, model, - request: httpx.Request, + request: Optional[httpx.Request] = None, litellm_debug_info: Optional[str] = None, max_retries: Optional[int] = None, num_retries: Optional[int] = None, @@ -563,6 +563,8 @@ class APIError(openai.APIError): # type: ignore self.litellm_debug_info = litellm_debug_info self.max_retries = max_retries self.num_retries = num_retries + if request is None: + request = httpx.Request(method="POST", url="https://api.openai.com/v1") super().__init__(self.message, request=request, body=None) # type: ignore def __str__(self): @@ -664,6 +666,22 @@ class OpenAIError(openai.OpenAIError): # type: ignore self.llm_provider = "openai" +class JSONSchemaValidationError(APIError): + def __init__( + self, model: str, llm_provider: str, raw_response: str, schema: str + ) -> None: + self.raw_response = raw_response + self.schema = schema + self.model = model + message = "litellm.JSONSchemaValidationError: model={}, returned an invalid response={}, for schema={}.\nAccess raw response with `e.raw_response`".format( + model, raw_response, schema + ) + self.message = message + super().__init__( + model=model, message=message, llm_provider=llm_provider, status_code=500 + ) + + LITELLM_EXCEPTION_TYPES = [ AuthenticationError, NotFoundError, @@ -682,6 +700,7 @@ LITELLM_EXCEPTION_TYPES = [ APIResponseValidationError, OpenAIError, InternalServerError, + JSONSchemaValidationError, ] diff --git a/litellm/litellm_core_utils/json_validation_rule.py b/litellm/litellm_core_utils/json_validation_rule.py new file mode 100644 index 000000000..f19144aaf --- /dev/null +++ b/litellm/litellm_core_utils/json_validation_rule.py @@ -0,0 +1,23 @@ +import json + + +def validate_schema(schema: dict, response: str): + """ + Validate if the returned json response follows the schema. + + Params: + - schema - dict: JSON schema + - response - str: Received json response as string. 
+ """ + from jsonschema import ValidationError, validate + + from litellm import JSONSchemaValidationError + + response_dict = json.loads(response) + + try: + validate(response_dict, schema=schema) + except ValidationError: + raise JSONSchemaValidationError( + model="", llm_provider="", raw_response=response, schema=json.dumps(schema) + ) diff --git a/litellm/llms/prompt_templates/factory.py b/litellm/llms/prompt_templates/factory.py index b35914584..87af2a6bd 100644 --- a/litellm/llms/prompt_templates/factory.py +++ b/litellm/llms/prompt_templates/factory.py @@ -2033,6 +2033,50 @@ def function_call_prompt(messages: list, functions: list): return messages +def response_schema_prompt(model: str, response_schema: dict) -> str: + """ + Decides if a user-defined custom prompt or default needs to be used + + Returns the prompt str that's passed to the model as a user message + """ + custom_prompt_details: Optional[dict] = None + response_schema_as_message = [ + {"role": "user", "content": "{}".format(response_schema)} + ] + if f"{model}/response_schema_prompt" in litellm.custom_prompt_dict: + + custom_prompt_details = litellm.custom_prompt_dict[ + f"{model}/response_schema_prompt" + ] # allow user to define custom response schema prompt by model + elif "response_schema_prompt" in litellm.custom_prompt_dict: + custom_prompt_details = litellm.custom_prompt_dict["response_schema_prompt"] + + if custom_prompt_details is not None: + return custom_prompt( + role_dict=custom_prompt_details["roles"], + initial_prompt_value=custom_prompt_details["initial_prompt_value"], + final_prompt_value=custom_prompt_details["final_prompt_value"], + messages=response_schema_as_message, + ) + else: + return default_response_schema_prompt(response_schema=response_schema) + + +def default_response_schema_prompt(response_schema: dict) -> str: + """ + Used if provider/model doesn't support 'response_schema' param. + + This is the default prompt. Allow user to override this with a custom_prompt. 
+ """ + prompt_str = """Use this JSON schema: + ```json + {} + ```""".format( + response_schema + ) + return prompt_str + + # Custom prompt template def custom_prompt( role_dict: dict, diff --git a/litellm/llms/vertex_ai.py b/litellm/llms/vertex_ai.py index 4a4abaef4..c1e628d17 100644 --- a/litellm/llms/vertex_ai.py +++ b/litellm/llms/vertex_ai.py @@ -12,6 +12,7 @@ import requests # type: ignore from pydantic import BaseModel import litellm +from litellm._logging import verbose_logger from litellm.litellm_core_utils.core_helpers import map_finish_reason from litellm.llms.prompt_templates.factory import ( convert_to_anthropic_image_obj, @@ -328,80 +329,86 @@ def _gemini_convert_messages_with_history(messages: list) -> List[ContentType]: contents: List[ContentType] = [] msg_i = 0 - while msg_i < len(messages): - user_content: List[PartType] = [] - init_msg_i = msg_i - ## MERGE CONSECUTIVE USER CONTENT ## - while msg_i < len(messages) and messages[msg_i]["role"] in user_message_types: - if isinstance(messages[msg_i]["content"], list): - _parts: List[PartType] = [] - for element in messages[msg_i]["content"]: - if isinstance(element, dict): - if element["type"] == "text" and len(element["text"]) > 0: - _part = PartType(text=element["text"]) - _parts.append(_part) - elif element["type"] == "image_url": - image_url = element["image_url"]["url"] - _part = _process_gemini_image(image_url=image_url) - _parts.append(_part) # type: ignore - user_content.extend(_parts) - elif ( - isinstance(messages[msg_i]["content"], str) - and len(messages[msg_i]["content"]) > 0 + try: + while msg_i < len(messages): + user_content: List[PartType] = [] + init_msg_i = msg_i + ## MERGE CONSECUTIVE USER CONTENT ## + while ( + msg_i < len(messages) and messages[msg_i]["role"] in user_message_types ): - _part = PartType(text=messages[msg_i]["content"]) - user_content.append(_part) + if isinstance(messages[msg_i]["content"], list): + _parts: List[PartType] = [] + for element in messages[msg_i]["content"]: + if isinstance(element, dict): + if element["type"] == "text" and len(element["text"]) > 0: + _part = PartType(text=element["text"]) + _parts.append(_part) + elif element["type"] == "image_url": + image_url = element["image_url"]["url"] + _part = _process_gemini_image(image_url=image_url) + _parts.append(_part) # type: ignore + user_content.extend(_parts) + elif ( + isinstance(messages[msg_i]["content"], str) + and len(messages[msg_i]["content"]) > 0 + ): + _part = PartType(text=messages[msg_i]["content"]) + user_content.append(_part) - msg_i += 1 + msg_i += 1 - if user_content: - contents.append(ContentType(role="user", parts=user_content)) - assistant_content = [] - ## MERGE CONSECUTIVE ASSISTANT CONTENT ## - while msg_i < len(messages) and messages[msg_i]["role"] == "assistant": - if isinstance(messages[msg_i]["content"], list): - _parts = [] - for element in messages[msg_i]["content"]: - if isinstance(element, dict): - if element["type"] == "text": - _part = PartType(text=element["text"]) - _parts.append(_part) - elif element["type"] == "image_url": - image_url = element["image_url"]["url"] - _part = _process_gemini_image(image_url=image_url) - _parts.append(_part) # type: ignore - assistant_content.extend(_parts) - elif messages[msg_i].get( - "tool_calls", [] - ): # support assistant tool invoke conversion - assistant_content.extend( - convert_to_gemini_tool_call_invoke(messages[msg_i]["tool_calls"]) + if user_content: + contents.append(ContentType(role="user", parts=user_content)) + assistant_content = [] + ## 
MERGE CONSECUTIVE ASSISTANT CONTENT ## + while msg_i < len(messages) and messages[msg_i]["role"] == "assistant": + if isinstance(messages[msg_i]["content"], list): + _parts = [] + for element in messages[msg_i]["content"]: + if isinstance(element, dict): + if element["type"] == "text": + _part = PartType(text=element["text"]) + _parts.append(_part) + elif element["type"] == "image_url": + image_url = element["image_url"]["url"] + _part = _process_gemini_image(image_url=image_url) + _parts.append(_part) # type: ignore + assistant_content.extend(_parts) + elif messages[msg_i].get( + "tool_calls", [] + ): # support assistant tool invoke conversion + assistant_content.extend( + convert_to_gemini_tool_call_invoke( + messages[msg_i]["tool_calls"] + ) + ) + else: + assistant_text = ( + messages[msg_i].get("content") or "" + ) # either string or none + if assistant_text: + assistant_content.append(PartType(text=assistant_text)) + + msg_i += 1 + + if assistant_content: + contents.append(ContentType(role="model", parts=assistant_content)) + + ## APPEND TOOL CALL MESSAGES ## + if msg_i < len(messages) and messages[msg_i]["role"] == "tool": + _part = convert_to_gemini_tool_call_result(messages[msg_i]) + contents.append(ContentType(parts=[_part])) # type: ignore + msg_i += 1 + if msg_i == init_msg_i: # prevent infinite loops + raise Exception( + "Invalid Message passed in - {}. File an issue https://github.com/BerriAI/litellm/issues".format( + messages[msg_i] + ) ) - else: - assistant_text = ( - messages[msg_i].get("content") or "" - ) # either string or none - if assistant_text: - assistant_content.append(PartType(text=assistant_text)) - - msg_i += 1 - - if assistant_content: - contents.append(ContentType(role="model", parts=assistant_content)) - - ## APPEND TOOL CALL MESSAGES ## - if msg_i < len(messages) and messages[msg_i]["role"] == "tool": - _part = convert_to_gemini_tool_call_result(messages[msg_i]) - contents.append(ContentType(parts=[_part])) # type: ignore - msg_i += 1 - if msg_i == init_msg_i: # prevent infinite loops - raise Exception( - "Invalid Message passed in - {}. 
File an issue https://github.com/BerriAI/litellm/issues".format( - messages[msg_i] - ) - ) - - return contents + return contents + except Exception as e: + raise e def _get_client_cache_key(model: str, vertex_project: str, vertex_location: str): diff --git a/litellm/llms/vertex_httpx.py b/litellm/llms/vertex_httpx.py index 940016ecb..9e361d3cc 100644 --- a/litellm/llms/vertex_httpx.py +++ b/litellm/llms/vertex_httpx.py @@ -12,7 +12,6 @@ from functools import partial from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union import httpx # type: ignore -import ijson import requests # type: ignore import litellm @@ -21,7 +20,10 @@ import litellm.litellm_core_utils.litellm_logging from litellm import verbose_logger from litellm.litellm_core_utils.core_helpers import map_finish_reason from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler -from litellm.llms.prompt_templates.factory import convert_url_to_base64 +from litellm.llms.prompt_templates.factory import ( + convert_url_to_base64, + response_schema_prompt, +) from litellm.llms.vertex_ai import _gemini_convert_messages_with_history from litellm.types.llms.openai import ( ChatCompletionResponseMessage, @@ -1011,35 +1013,53 @@ class VertexLLM(BaseLLM): if len(system_prompt_indices) > 0: for idx in reversed(system_prompt_indices): messages.pop(idx) - content = _gemini_convert_messages_with_history(messages=messages) - tools: Optional[Tools] = optional_params.pop("tools", None) - tool_choice: Optional[ToolConfig] = optional_params.pop("tool_choice", None) - safety_settings: Optional[List[SafetSettingsConfig]] = optional_params.pop( - "safety_settings", None - ) # type: ignore - generation_config: Optional[GenerationConfig] = GenerationConfig( - **optional_params - ) - data = RequestBody(contents=content) - if len(system_content_blocks) > 0: - system_instructions = SystemInstructions(parts=system_content_blocks) - data["system_instruction"] = system_instructions - if tools is not None: - data["tools"] = tools - if tool_choice is not None: - data["toolConfig"] = tool_choice - if safety_settings is not None: - data["safetySettings"] = safety_settings - if generation_config is not None: - data["generationConfig"] = generation_config - headers = { - "Content-Type": "application/json", - } - if auth_header is not None: - headers["Authorization"] = f"Bearer {auth_header}" - if extra_headers is not None: - headers.update(extra_headers) + # Checks for 'response_schema' support - if passed in + if "response_schema" in optional_params: + supports_response_schema = litellm.supports_response_schema( + model=model, custom_llm_provider="vertex_ai" + ) + if supports_response_schema is False: + user_response_schema_message = response_schema_prompt( + model=model, response_schema=optional_params.get("response_schema") # type: ignore + ) + messages.append( + {"role": "user", "content": user_response_schema_message} + ) + optional_params.pop("response_schema") + + try: + content = _gemini_convert_messages_with_history(messages=messages) + tools: Optional[Tools] = optional_params.pop("tools", None) + tool_choice: Optional[ToolConfig] = optional_params.pop("tool_choice", None) + safety_settings: Optional[List[SafetSettingsConfig]] = optional_params.pop( + "safety_settings", None + ) # type: ignore + generation_config: Optional[GenerationConfig] = GenerationConfig( + **optional_params + ) + data = RequestBody(contents=content) + if len(system_content_blocks) > 0: + system_instructions = 
SystemInstructions(parts=system_content_blocks) + data["system_instruction"] = system_instructions + if tools is not None: + data["tools"] = tools + if tool_choice is not None: + data["toolConfig"] = tool_choice + if safety_settings is not None: + data["safetySettings"] = safety_settings + if generation_config is not None: + data["generationConfig"] = generation_config + + headers = { + "Content-Type": "application/json", + } + if auth_header is not None: + headers["Authorization"] = f"Bearer {auth_header}" + if extra_headers is not None: + headers.update(extra_headers) + except Exception as e: + raise e ## LOGGING logging_obj.pre_call( diff --git a/litellm/model_prices_and_context_window_backup.json b/litellm/model_prices_and_context_window_backup.json index 49f2f0c28..7f08b9eb1 100644 --- a/litellm/model_prices_and_context_window_backup.json +++ b/litellm/model_prices_and_context_window_backup.json @@ -1538,6 +1538,7 @@ "supports_system_messages": true, "supports_function_calling": true, "supports_tool_choice": true, + "supports_response_schema": true, "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" }, "gemini-1.5-pro-preview-0215": { @@ -1563,6 +1564,7 @@ "supports_system_messages": true, "supports_function_calling": true, "supports_tool_choice": true, + "supports_response_schema": true, "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" }, "gemini-1.5-pro-preview-0409": { @@ -1586,7 +1588,8 @@ "litellm_provider": "vertex_ai-language-models", "mode": "chat", "supports_function_calling": true, - "supports_tool_choice": true, + "supports_tool_choice": true, + "supports_response_schema": true, "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" }, "gemini-1.5-flash": { diff --git a/litellm/tests/test_amazing_vertex_completion.py b/litellm/tests/test_amazing_vertex_completion.py index e6f2634f4..3a48bcb6c 100644 --- a/litellm/tests/test_amazing_vertex_completion.py +++ b/litellm/tests/test_amazing_vertex_completion.py @@ -880,10 +880,141 @@ Using this JSON schema: mock_call.assert_called_once() -@pytest.mark.parametrize("provider", ["vertex_ai_beta"]) # "vertex_ai", +def vertex_httpx_mock_post_valid_response(*args, **kwargs): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.headers = {"Content-Type": "application/json"} + mock_response.json.return_value = { + "candidates": [ + { + "content": { + "role": "model", + "parts": [ + { + "text": '[{"recipe_name": "Chocolate Chip Cookies"}, {"recipe_name": "Oatmeal Raisin Cookies"}, {"recipe_name": "Peanut Butter Cookies"}, {"recipe_name": "Sugar Cookies"}, {"recipe_name": "Snickerdoodles"}]\n' + } + ], + }, + "finishReason": "STOP", + "safetyRatings": [ + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "probability": "NEGLIGIBLE", + "probabilityScore": 0.09790669, + "severity": "HARM_SEVERITY_NEGLIGIBLE", + "severityScore": 0.11736965, + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "probability": "NEGLIGIBLE", + "probabilityScore": 0.1261379, + "severity": "HARM_SEVERITY_NEGLIGIBLE", + "severityScore": 0.08601588, + }, + { + "category": "HARM_CATEGORY_HARASSMENT", + "probability": "NEGLIGIBLE", + "probabilityScore": 0.083441176, + "severity": "HARM_SEVERITY_NEGLIGIBLE", + "severityScore": 0.0355444, + }, + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "probability": "NEGLIGIBLE", + "probabilityScore": 0.071981624, + "severity": "HARM_SEVERITY_NEGLIGIBLE", + 
"severityScore": 0.08108212, + }, + ], + } + ], + "usageMetadata": { + "promptTokenCount": 60, + "candidatesTokenCount": 55, + "totalTokenCount": 115, + }, + } + return mock_response + + +def vertex_httpx_mock_post_invalid_schema_response(*args, **kwargs): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.headers = {"Content-Type": "application/json"} + mock_response.json.return_value = { + "candidates": [ + { + "content": { + "role": "model", + "parts": [ + {"text": '[{"recipe_world": "Chocolate Chip Cookies"}]\n'} + ], + }, + "finishReason": "STOP", + "safetyRatings": [ + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "probability": "NEGLIGIBLE", + "probabilityScore": 0.09790669, + "severity": "HARM_SEVERITY_NEGLIGIBLE", + "severityScore": 0.11736965, + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "probability": "NEGLIGIBLE", + "probabilityScore": 0.1261379, + "severity": "HARM_SEVERITY_NEGLIGIBLE", + "severityScore": 0.08601588, + }, + { + "category": "HARM_CATEGORY_HARASSMENT", + "probability": "NEGLIGIBLE", + "probabilityScore": 0.083441176, + "severity": "HARM_SEVERITY_NEGLIGIBLE", + "severityScore": 0.0355444, + }, + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "probability": "NEGLIGIBLE", + "probabilityScore": 0.071981624, + "severity": "HARM_SEVERITY_NEGLIGIBLE", + "severityScore": 0.08108212, + }, + ], + } + ], + "usageMetadata": { + "promptTokenCount": 60, + "candidatesTokenCount": 55, + "totalTokenCount": 115, + }, + } + return mock_response + + +@pytest.mark.parametrize( + "model, supports_response_schema", + [ + ("vertex_ai_beta/gemini-1.5-pro-001", True), + ("vertex_ai_beta/gemini-1.5-flash", False), + ], +) +@pytest.mark.parametrize( + "invalid_response", + [True, False], +) +@pytest.mark.parametrize( + "enforce_validation", + [True, False], +) @pytest.mark.asyncio -async def test_gemini_pro_json_schema_httpx(provider): +async def test_gemini_pro_json_schema_args_sent_httpx( + model, supports_response_schema, invalid_response, enforce_validation +): load_vertex_ai_credentials() + os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" + litellm.model_cost = litellm.get_model_cost_map(url="") + litellm.set_verbose = True messages = [{"role": "user", "content": "List 5 cookie recipes"}] from litellm.llms.custom_httpx.http_handler import HTTPHandler @@ -903,26 +1034,47 @@ async def test_gemini_pro_json_schema_httpx(provider): client = HTTPHandler() - with patch.object(client, "post", new=MagicMock()) as mock_call: + httpx_response = MagicMock() + if invalid_response is True: + httpx_response.side_effect = vertex_httpx_mock_post_invalid_schema_response + else: + httpx_response.side_effect = vertex_httpx_mock_post_valid_response + with patch.object(client, "post", new=httpx_response) as mock_call: try: - response = completion( - model="vertex_ai_beta/gemini-1.5-pro-001", + _ = completion( + model=model, messages=messages, response_format={ "type": "json_object", "response_schema": response_schema, + "enforce_validation": enforce_validation, }, client=client, ) - except Exception as e: - pass + if invalid_response is True and enforce_validation is True: + pytest.fail("Expected this to fail") + except litellm.JSONSchemaValidationError as e: + if invalid_response is False: + pytest.fail("Expected this to pass. 
Got={}".format(e)) mock_call.assert_called_once() print(mock_call.call_args.kwargs) print(mock_call.call_args.kwargs["json"]["generationConfig"]) - assert ( - "response_schema" in mock_call.call_args.kwargs["json"]["generationConfig"] - ) + + if supports_response_schema: + assert ( + "response_schema" + in mock_call.call_args.kwargs["json"]["generationConfig"] + ) + else: + assert ( + "response_schema" + not in mock_call.call_args.kwargs["json"]["generationConfig"] + ) + assert ( + "Use this JSON schema:" + in mock_call.call_args.kwargs["json"]["contents"][0]["parts"][1]["text"] + ) @pytest.mark.parametrize("provider", ["vertex_ai_beta"]) # "vertex_ai", @@ -959,48 +1111,6 @@ async def test_gemini_pro_httpx_custom_api_base(provider): assert "hello" in mock_call.call_args.kwargs["headers"] -@pytest.mark.parametrize("sync_mode", [True, False]) -@pytest.mark.parametrize("provider", ["vertex_ai_beta"]) # "vertex_ai", -@pytest.mark.asyncio -async def test_gemini_pro_httpx_custom_api_base_streaming_real_call( - provider, sync_mode -): - load_vertex_ai_credentials() - import random - - litellm.set_verbose = True - messages = [ - { - "role": "user", - "content": "Hey, how's it going?", - } - ] - - vertex_region = random.sample(["asia-southeast1", "us-central1"], k=1)[0] - if sync_mode is True: - response = completion( - model="vertex_ai_beta/gemini-1.5-flash", - messages=messages, - api_base="https://gateway.ai.cloudflare.com/v1/fa4cdcab1f32b95ca3b53fd36043d691/test/google-vertex-ai/v1/projects/adroit-crow-413218/locations/us-central1/publishers/google/models/gemini-1.5-flash", - stream=True, - vertex_region=vertex_region, - ) - - for chunk in response: - print(chunk) - else: - response = await litellm.acompletion( - model="vertex_ai_beta/gemini-1.5-flash", - messages=messages, - api_base="https://gateway.ai.cloudflare.com/v1/fa4cdcab1f32b95ca3b53fd36043d691/test/google-vertex-ai/v1/projects/adroit-crow-413218/locations/us-central1/publishers/google/models/gemini-1.5-flash", - stream=True, - vertex_region=vertex_region, - ) - - async for chunk in response: - print(chunk) - - @pytest.mark.skip(reason="exhausted vertex quota. 
need to refactor to mock the call") @pytest.mark.parametrize("sync_mode", [True]) @pytest.mark.parametrize("provider", ["vertex_ai"]) diff --git a/litellm/tests/test_jwt.py b/litellm/tests/test_jwt.py index 72d4d7b1b..e287946ae 100644 --- a/litellm/tests/test_jwt.py +++ b/litellm/tests/test_jwt.py @@ -61,7 +61,6 @@ async def test_token_single_public_key(): import jwt jwt_handler = JWTHandler() - backend_keys = { "keys": [ { diff --git a/litellm/utils.py b/litellm/utils.py index 08a5eb40d..4b80d203b 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -48,6 +48,7 @@ from tokenizers import Tokenizer import litellm import litellm._service_logger # for storing API inputs, outputs, and metadata import litellm.litellm_core_utils +import litellm.litellm_core_utils.json_validation_rule from litellm.caching import DualCache from litellm.litellm_core_utils.core_helpers import map_finish_reason from litellm.litellm_core_utils.exception_mapping_utils import get_error_message @@ -580,7 +581,7 @@ def client(original_function): else: return False - def post_call_processing(original_response, model): + def post_call_processing(original_response, model, optional_params: Optional[dict]): try: if original_response is None: pass @@ -595,11 +596,47 @@ def client(original_function): pass else: if isinstance(original_response, ModelResponse): - model_response = original_response.choices[ + model_response: Optional[str] = original_response.choices[ 0 - ].message.content - ### POST-CALL RULES ### - rules_obj.post_call_rules(input=model_response, model=model) + ].message.content # type: ignore + if model_response is not None: + ### POST-CALL RULES ### + rules_obj.post_call_rules( + input=model_response, model=model + ) + ### JSON SCHEMA VALIDATION ### + if ( + optional_params is not None + and "response_format" in optional_params + and isinstance( + optional_params["response_format"], dict + ) + and "type" in optional_params["response_format"] + and optional_params["response_format"]["type"] + == "json_object" + and "response_schema" + in optional_params["response_format"] + and isinstance( + optional_params["response_format"][ + "response_schema" + ], + dict, + ) + and "enforce_validation" + in optional_params["response_format"] + and optional_params["response_format"][ + "enforce_validation" + ] + is True + ): + # schema given, json response expected, and validation enforced + litellm.litellm_core_utils.json_validation_rule.validate_schema( + schema=optional_params["response_format"][ + "response_schema" + ], + response=model_response, + ) + except Exception as e: raise e @@ -868,7 +905,11 @@ def client(original_function): return result ### POST-CALL RULES ### - post_call_processing(original_response=result, model=model or None) + post_call_processing( + original_response=result, + model=model or None, + optional_params=kwargs, + ) # [OPTIONAL] ADD TO CACHE if ( @@ -1317,7 +1358,9 @@ def client(original_function): ).total_seconds() * 1000 # return response latency in ms like openai ### POST-CALL RULES ### - post_call_processing(original_response=result, model=model) + post_call_processing( + original_response=result, model=model, optional_params=kwargs + ) # [OPTIONAL] ADD TO CACHE if ( @@ -1880,8 +1923,7 @@ def supports_response_schema(model: str, custom_llm_provider: Optional[str]) -> Returns: bool: True if the model supports response_schema, False otherwise. - Raises: - Exception: If the given model is not found in model_prices_and_context_window.json. + Does not raise error. Defaults to 'False'. 
Outputs logging.error. """ try: ## GET LLM PROVIDER ## @@ -1901,9 +1943,10 @@ def supports_response_schema(model: str, custom_llm_provider: Optional[str]) -> return True return False except Exception: - raise Exception( + verbose_logger.error( f"Model not in model_prices_and_context_window.json. You passed model={model}, custom_llm_provider={custom_llm_provider}." ) + return False def supports_function_calling(model: str) -> bool: diff --git a/model_prices_and_context_window.json b/model_prices_and_context_window.json index 49f2f0c28..7f08b9eb1 100644 --- a/model_prices_and_context_window.json +++ b/model_prices_and_context_window.json @@ -1538,6 +1538,7 @@ "supports_system_messages": true, "supports_function_calling": true, "supports_tool_choice": true, + "supports_response_schema": true, "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" }, "gemini-1.5-pro-preview-0215": { @@ -1563,6 +1564,7 @@ "supports_system_messages": true, "supports_function_calling": true, "supports_tool_choice": true, + "supports_response_schema": true, "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" }, "gemini-1.5-pro-preview-0409": { @@ -1586,7 +1588,8 @@ "litellm_provider": "vertex_ai-language-models", "mode": "chat", "supports_function_calling": true, - "supports_tool_choice": true, + "supports_tool_choice": true, + "supports_response_schema": true, "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" }, "gemini-1.5-flash": { diff --git a/pyproject.toml b/pyproject.toml index 95aa18c08..2519c167f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,7 @@ jinja2 = "^3.1.2" aiohttp = "*" requests = "^2.31.0" pydantic = "^2.0.0" -ijson = "*" +jsonschema = "^4.22.0" uvicorn = {version = "^0.22.0", optional = true} gunicorn = {version = "^22.0.0", optional = true} diff --git a/requirements.txt b/requirements.txt index 00d3802da..e71ab450b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -46,5 +46,5 @@ aiohttp==3.9.0 # for network calls aioboto3==12.3.0 # for async sagemaker calls tenacity==8.2.3 # for retrying requests, when litellm.num_retries set pydantic==2.7.1 # proxy + openai req. -ijson==3.2.3 # for google ai studio streaming +jsonschema==4.22.0 # validating json schema #### \ No newline at end of file
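
Reviewer note: a minimal usage sketch of the behavior this diff adds, mirroring the new test `test_gemini_pro_json_schema_args_sent_httpx`. The model name and schema below are illustrative, and the call assumes valid Vertex AI credentials.

```python
import litellm
from litellm import completion

# Illustrative JSON schema (same shape as the one used in the new test).
response_schema = {
    "type": "array",
    "items": {
        "type": "object",
        "properties": {"recipe_name": {"type": "string"}},
        "required": ["recipe_name"],
    },
}

try:
    resp = completion(
        model="vertex_ai_beta/gemini-1.5-pro-001",  # marked supports_response_schema in the model map
        messages=[{"role": "user", "content": "List 5 cookie recipes"}],
        response_format={
            "type": "json_object",
            "response_schema": response_schema,
            # new flag: validate the returned JSON client-side via jsonschema
            "enforce_validation": True,
        },
    )
    print(resp.choices[0].message.content)
except litellm.JSONSchemaValidationError as e:
    # Raised when the model's JSON does not satisfy the schema;
    # the raw model output stays available on the exception.
    print("invalid response:", e.raw_response)
```

For models whose map entry lacks `supports_response_schema` (e.g. `vertex_ai_beta/gemini-1.5-flash` in this diff), the schema is not sent in `generationConfig`; it is instead appended as a "Use this JSON schema:" user message via `response_schema_prompt`. Note also that `litellm.supports_response_schema(model=..., custom_llm_provider=...)` now returns `False` and logs an error, rather than raising, when a model is missing from the cost map.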