diff --git a/litellm/litellm_core_utils/json_validation_rule.py b/litellm/litellm_core_utils/json_validation_rule.py
index f19144aaf..0f37e6737 100644
--- a/litellm/litellm_core_utils/json_validation_rule.py
+++ b/litellm/litellm_core_utils/json_validation_rule.py
@@ -13,7 +13,12 @@ def validate_schema(schema: dict, response: str):
 
     from litellm import JSONSchemaValidationError
 
-    response_dict = json.loads(response)
+    try:
+        response_dict = json.loads(response)
+    except json.JSONDecodeError:
+        raise JSONSchemaValidationError(
+            model="", llm_provider="", raw_response=response, schema=response
+        )
 
     try:
         validate(response_dict, schema=schema)
diff --git a/litellm/llms/anthropic.py b/litellm/llms/anthropic.py
index af5ccf828..b666d9494 100644
--- a/litellm/llms/anthropic.py
+++ b/litellm/llms/anthropic.py
@@ -16,6 +16,7 @@ from litellm import verbose_logger
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
 from litellm.llms.custom_httpx.http_handler import (
     AsyncHTTPHandler,
+    HTTPHandler,
     _get_async_httpx_client,
     _get_httpx_client,
 )
@@ -538,7 +539,7 @@ class AnthropicChatCompletion(BaseLLM):
     def __init__(self) -> None:
         super().__init__()
 
-    def process_response(
+    def _process_response(
         self,
         model: str,
         response: Union[requests.Response, httpx.Response],
@@ -551,6 +552,7 @@
         messages: List,
         print_verbose,
         encoding,
+        json_mode: bool,
     ) -> ModelResponse:
         ## LOGGING
         logging_obj.post_call(
@@ -574,27 +576,40 @@
             )
         else:
             text_content = ""
-            tool_calls = []
-            for content in completion_response["content"]:
+            tool_calls: List[ChatCompletionToolCallChunk] = []
+            for idx, content in enumerate(completion_response["content"]):
                 if content["type"] == "text":
                     text_content += content["text"]
                 ## TOOL CALLING
                 elif content["type"] == "tool_use":
                     tool_calls.append(
-                        {
-                            "id": content["id"],
-                            "type": "function",
-                            "function": {
-                                "name": content["name"],
-                                "arguments": json.dumps(content["input"]),
-                            },
-                        }
+                        ChatCompletionToolCallChunk(
+                            id=content["id"],
+                            type="function",
+                            function=ChatCompletionToolCallFunctionChunk(
+                                name=content["name"],
+                                arguments=json.dumps(content["input"]),
+                            ),
+                            index=idx,
+                        )
                     )
             _message = litellm.Message(
                 tool_calls=tool_calls,
                 content=text_content or None,
             )
+
+            ## HANDLE JSON MODE - anthropic returns single function call
+            if json_mode and len(tool_calls) == 1:
+                json_mode_content_str: Optional[str] = tool_calls[0]["function"].get(
+                    "arguments"
+                )
+                if json_mode_content_str is not None:
+                    args = json.loads(json_mode_content_str)
+                    values: Optional[dict] = args.get("values")
+                    if values is not None:
+                        _message = litellm.Message(content=json.dumps(values))
+                        completion_response["stop_reason"] = "stop"
         model_response.choices[0].message = _message  # type: ignore
         model_response._hidden_params["original_response"] = completion_response[
             "content"
         ]
@@ -687,9 +702,11 @@
         _is_function_call,
         data: dict,
         optional_params: dict,
+        json_mode: bool,
         litellm_params=None,
         logger_fn=None,
         headers={},
+        client=None,
     ) -> Union[ModelResponse, CustomStreamWrapper]:
         async_handler = _get_async_httpx_client()
 
@@ -705,7 +722,7 @@
             )
             raise e
 
-        return self.process_response(
+        return self._process_response(
             model=model,
             response=response,
             model_response=model_response,
@@ -717,6 +734,7 @@
             print_verbose=print_verbose,
             optional_params=optional_params,
             encoding=encoding,
+            json_mode=json_mode,
         )
 
     def completion(
@@ -731,10 +749,12 @@
         api_key,
         logging_obj,
         optional_params: dict,
+        timeout: Union[float, httpx.Timeout],
         acompletion=None,
         litellm_params=None,
         logger_fn=None,
         headers={},
+        client=None,
     ):
         headers = validate_environment(api_key, headers, model)
         _is_function_call = False
@@ -787,14 +807,18 @@
 
             anthropic_tools = []
             for tool in optional_params["tools"]:
-                new_tool = tool["function"]
-                new_tool["input_schema"] = new_tool.pop("parameters")  # rename key
-                anthropic_tools.append(new_tool)
+                if "input_schema" in tool:  # assume in anthropic format
+                    anthropic_tools.append(tool)
+                else:  # assume openai tool call
+                    new_tool = tool["function"]
+                    new_tool["input_schema"] = new_tool.pop("parameters")  # rename key
+                    anthropic_tools.append(new_tool)
 
             optional_params["tools"] = anthropic_tools
 
         stream = optional_params.pop("stream", None)
         is_vertex_request: bool = optional_params.pop("is_vertex_request", False)
+        json_mode: bool = optional_params.pop("json_mode", False)
 
         data = {
             "messages": messages,
@@ -815,7 +839,7 @@
             },
         )
         print_verbose(f"_is_function_call: {_is_function_call}")
-        if acompletion == True:
+        if acompletion is True:
             if (
                 stream is True
             ):  # if function call - fake the streaming (need complete blocks for output parsing in openai format)
@@ -857,15 +881,21 @@
                 litellm_params=litellm_params,
                 logger_fn=logger_fn,
                 headers=headers,
+                client=client,
+                json_mode=json_mode,
             )
         else:
             ## COMPLETION CALL
+            if client is None or isinstance(client, AsyncHTTPHandler):
+                client = HTTPHandler(timeout=timeout)  # type: ignore
+            else:
+                client = client
             if (
                 stream is True
             ):  # if function call - fake the streaming (need complete blocks for output parsing in openai format)
                 print_verbose("makes anthropic streaming POST request")
                 data["stream"] = stream
-                response = requests.post(
+                response = client.post(
                     api_base,
                     headers=headers,
                     data=json.dumps(data),
@@ -889,15 +919,13 @@
                 return streaming_response
 
         else:
-            response = requests.post(
-                api_base, headers=headers, data=json.dumps(data)
-            )
+            response = client.post(api_base, headers=headers, data=json.dumps(data))
             if response.status_code != 200:
                 raise AnthropicError(
                     status_code=response.status_code, message=response.text
                 )
 
-        return self.process_response(
+        return self._process_response(
             model=model,
             response=response,
             model_response=model_response,
@@ -909,6 +937,7 @@
             print_verbose=print_verbose,
             optional_params=optional_params,
             encoding=encoding,
+            json_mode=json_mode,
         )
 
     def embedding(self):
diff --git a/litellm/llms/vertex_ai_anthropic.py b/litellm/llms/vertex_ai_anthropic.py
index b8362d5a5..900e7795f 100644
--- a/litellm/llms/vertex_ai_anthropic.py
+++ b/litellm/llms/vertex_ai_anthropic.py
@@ -7,7 +7,7 @@ import time
 import types
 import uuid
 from enum import Enum
-from typing import Any, Callable, List, Optional, Tuple
+from typing import Any, Callable, List, Optional, Tuple, Union
 
 import httpx  # type: ignore
 import requests  # type: ignore
@@ -15,7 +15,14 @@ import requests  # type: ignore
 
 import litellm
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
-from litellm.types.llms.anthropic import AnthropicMessagesToolChoice
+from litellm.types.llms.anthropic import (
+    AnthropicMessagesTool,
+    AnthropicMessagesToolChoice,
+)
+from litellm.types.llms.openai import (
+    ChatCompletionToolParam,
+    ChatCompletionToolParamFunctionChunk,
+)
 from litellm.types.utils import ResponseFormatChunk
 from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
@@ -142,7 +149,27 @@
             if param == "top_p":
                 optional_params["top_p"] = value
             if param == "response_format" and "response_schema" in value:
-                optional_params["response_format"] = ResponseFormatChunk(**value)  # type: ignore
+                """
+                When using tools in this way: - https://docs.anthropic.com/en/docs/build-with-claude/tool-use#json-mode
+                - You usually want to provide a single tool
+                - You should set tool_choice (see Forcing tool use) to instruct the model to explicitly use that tool
+                - Remember that the model will pass the input to the tool, so the name of the tool and description should be from the model’s perspective.
+                """
+                _tool_choice = None
+                _tool_choice = {"name": "json_tool_call", "type": "tool"}
+
+                _tool = AnthropicMessagesTool(
+                    name="json_tool_call",
+                    input_schema={
+                        "type": "object",
+                        "properties": {"values": value["response_schema"]},  # type: ignore
+                    },
+                )
+
+                optional_params["tools"] = [_tool]
+                optional_params["tool_choice"] = _tool_choice
+                optional_params["json_mode"] = True
+
         return optional_params
 
 
@@ -222,6 +249,7 @@ def completion(
     optional_params: dict,
     custom_prompt_dict: dict,
    headers: Optional[dict],
+    timeout: Union[float, httpx.Timeout],
     vertex_project=None,
     vertex_location=None,
     vertex_credentials=None,
@@ -301,6 +329,8 @@
                 litellm_params=litellm_params,
                 logger_fn=logger_fn,
                 headers=vertex_headers,
+                client=client,
+                timeout=timeout,
             )
 
     except Exception as e:
diff --git a/litellm/main.py b/litellm/main.py
index e01603b7e..69c845ad8 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -1528,6 +1528,8 @@
                 api_key=api_key,
                 logging_obj=logging,
                 headers=headers,
+                timeout=timeout,
+                client=client,
             )
             if optional_params.get("stream", False) or acompletion == True:
                 ## LOGGING
@@ -2046,7 +2048,10 @@
                 acompletion=acompletion,
                 headers=headers,
                 custom_prompt_dict=custom_prompt_dict,
+                timeout=timeout,
+                client=client,
             )
+
         else:
             model_response = vertex_ai.completion(
                 model=model,
diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml
index 641c70ebc..1bd421f8d 100644
--- a/litellm/proxy/_new_secret_config.yaml
+++ b/litellm/proxy/_new_secret_config.yaml
@@ -1,5 +1,13 @@
 model_list:
-  - model_name: llama-3
+  - model_name: bad-azure-model
     litellm_params:
-      model: gpt-4
-      request_timeout: 1
+      model: azure/chatgpt-v-2
+      azure_ad_token: ""
+      api_base: os.environ/AZURE_API_BASE
+
+  - model_name: good-openai-model
+    litellm_params:
+      model: gpt-3.5-turbo
+
+litellm_settings:
+  fallbacks: [{"bad-azure-model": ["good-openai-model"]}]
diff --git a/litellm/tests/test_amazing_vertex_completion.py b/litellm/tests/test_amazing_vertex_completion.py
index b8ba54cb4..3def5a1ec 100644
--- a/litellm/tests/test_amazing_vertex_completion.py
+++ b/litellm/tests/test_amazing_vertex_completion.py
@@ -1128,6 +1128,39 @@ def vertex_httpx_mock_post_valid_response(*args, **kwargs):
     return mock_response
 
 
+def vertex_httpx_mock_post_valid_response_anthropic(*args, **kwargs):
+    mock_response = MagicMock()
+    mock_response.status_code = 200
+    mock_response.headers = {"Content-Type": "application/json"}
+    mock_response.json.return_value = {
+        "id": "msg_vrtx_013Wki5RFQXAspL7rmxRFjZg",
+        "type": "message",
+        "role": "assistant",
+        "model": "claude-3-5-sonnet-20240620",
+        "content": [
+            {
+                "type": "tool_use",
+                "id": "toolu_vrtx_01YMnYZrToPPfcmY2myP2gEB",
+                "name": "json_tool_call",
+                "input": {
+                    "values": [
+                        {"recipe_name": "Chocolate Chip Cookies"},
+                        {"recipe_name": "Oatmeal Raisin Cookies"},
+                        {"recipe_name": "Peanut Butter Cookies"},
+                        {"recipe_name": "Snickerdoodle Cookies"},
+                        {"recipe_name": "Sugar Cookies"},
+                    ]
+                },
+            }
+        ],
+        "stop_reason": "tool_use",
+        "stop_sequence": None,
+        "usage": {"input_tokens": 368, "output_tokens": 118},
+    }
+
+    return mock_response
+
+
 def vertex_httpx_mock_post_invalid_schema_response(*args, **kwargs):
     mock_response = MagicMock()
     mock_response.status_code = 200
@@ -1183,11 +1216,29 @@
     return mock_response
 
 
+def vertex_httpx_mock_post_invalid_schema_response_anthropic(*args, **kwargs):
+    mock_response = MagicMock()
+    mock_response.status_code = 200
+    mock_response.headers = {"Content-Type": "application/json"}
+    mock_response.json.return_value = {
+        "id": "msg_vrtx_013Wki5RFQXAspL7rmxRFjZg",
+        "type": "message",
+        "role": "assistant",
+        "model": "claude-3-5-sonnet-20240620",
+        "content": [{"text": "Hi! My name is Claude.", "type": "text"}],
+        "stop_reason": "end_turn",
+        "stop_sequence": None,
+        "usage": {"input_tokens": 368, "output_tokens": 118},
+    }
+    return mock_response
+
+
 @pytest.mark.parametrize(
     "model, vertex_location, supports_response_schema",
     [
         ("vertex_ai_beta/gemini-1.5-pro-001", "us-central1", True),
         ("vertex_ai_beta/gemini-1.5-flash", "us-central1", False),
+        ("vertex_ai/claude-3-5-sonnet@20240620", "us-east5", False),
     ],
 )
 @pytest.mark.parametrize(
@@ -1231,12 +1282,21 @@ async def test_gemini_pro_json_schema_args_sent_httpx(
     httpx_response = MagicMock()
     if invalid_response is True:
-        httpx_response.side_effect = vertex_httpx_mock_post_invalid_schema_response
+        if "claude" in model:
+            httpx_response.side_effect = (
+                vertex_httpx_mock_post_invalid_schema_response_anthropic
+            )
+        else:
+            httpx_response.side_effect = vertex_httpx_mock_post_invalid_schema_response
     else:
-        httpx_response.side_effect = vertex_httpx_mock_post_valid_response
+        if "claude" in model:
+            httpx_response.side_effect = vertex_httpx_mock_post_valid_response_anthropic
+        else:
+            httpx_response.side_effect = vertex_httpx_mock_post_valid_response
 
     with patch.object(client, "post", new=httpx_response) as mock_call:
+        print("SENDING CLIENT POST={}".format(client.post))
         try:
-            _ = completion(
+            resp = completion(
                 model=model,
                 messages=messages,
                 response_format={
@@ -1247,30 +1307,34 @@
                 vertex_location=vertex_location,
                 client=client,
             )
+            print("Received={}".format(resp))
             if invalid_response is True and enforce_validation is True:
                 pytest.fail("Expected this to fail")
         except litellm.JSONSchemaValidationError as e:
-            if invalid_response is False and "claude-3" not in model:
+            if invalid_response is False:
                 pytest.fail("Expected this to pass. Got={}".format(e))
 
         mock_call.assert_called_once()
-        print(mock_call.call_args.kwargs)
-        print(mock_call.call_args.kwargs["json"]["generationConfig"])
+        if "claude" not in model:
+            print(mock_call.call_args.kwargs)
+            print(mock_call.call_args.kwargs["json"]["generationConfig"])
 
-        if supports_response_schema:
-            assert (
-                "response_schema"
-                in mock_call.call_args.kwargs["json"]["generationConfig"]
-            )
-        else:
-            assert (
-                "response_schema"
-                not in mock_call.call_args.kwargs["json"]["generationConfig"]
-            )
-            assert (
-                "Use this JSON schema:"
-                in mock_call.call_args.kwargs["json"]["contents"][0]["parts"][1]["text"]
-            )
+            if supports_response_schema:
+                assert (
+                    "response_schema"
+                    in mock_call.call_args.kwargs["json"]["generationConfig"]
+                )
+            else:
+                assert (
+                    "response_schema"
+                    not in mock_call.call_args.kwargs["json"]["generationConfig"]
+                )
+                assert (
+                    "Use this JSON schema:"
+                    in mock_call.call_args.kwargs["json"]["contents"][0]["parts"][1][
+                        "text"
+                    ]
+                )
 
 
 @pytest.mark.parametrize("provider", ["vertex_ai_beta"])  # "vertex_ai",