From 4d963ab7893eeb1d58f783ffe649beee35c458ab Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Thu, 18 Jul 2024 16:57:38 -0700 Subject: [PATCH 1/2] feat(vertex_ai_anthropic.py): support response_schema for vertex ai anthropic calls allows passing response_schema for anthropic calls. supports schema validation. --- .../json_validation_rule.py | 7 +- litellm/llms/anthropic.py | 71 ++++++++---- litellm/llms/vertex_ai_anthropic.py | 36 +++++- litellm/main.py | 5 + litellm/proxy/_new_secret_config.yaml | 14 ++- .../tests/test_amazing_vertex_completion.py | 104 ++++++++++++++---- 6 files changed, 189 insertions(+), 48 deletions(-) diff --git a/litellm/litellm_core_utils/json_validation_rule.py b/litellm/litellm_core_utils/json_validation_rule.py index f19144aaf..0f37e6737 100644 --- a/litellm/litellm_core_utils/json_validation_rule.py +++ b/litellm/litellm_core_utils/json_validation_rule.py @@ -13,7 +13,12 @@ def validate_schema(schema: dict, response: str): from litellm import JSONSchemaValidationError - response_dict = json.loads(response) + try: + response_dict = json.loads(response) + except json.JSONDecodeError: + raise JSONSchemaValidationError( + model="", llm_provider="", raw_response=response, schema=response + ) try: validate(response_dict, schema=schema) diff --git a/litellm/llms/anthropic.py b/litellm/llms/anthropic.py index af5ccf828..b666d9494 100644 --- a/litellm/llms/anthropic.py +++ b/litellm/llms/anthropic.py @@ -16,6 +16,7 @@ from litellm import verbose_logger from litellm.litellm_core_utils.core_helpers import map_finish_reason from litellm.llms.custom_httpx.http_handler import ( AsyncHTTPHandler, + HTTPHandler, _get_async_httpx_client, _get_httpx_client, ) @@ -538,7 +539,7 @@ class AnthropicChatCompletion(BaseLLM): def __init__(self) -> None: super().__init__() - def process_response( + def _process_response( self, model: str, response: Union[requests.Response, httpx.Response], @@ -551,6 +552,7 @@ class AnthropicChatCompletion(BaseLLM): messages: List, print_verbose, encoding, + json_mode: bool, ) -> ModelResponse: ## LOGGING logging_obj.post_call( @@ -574,27 +576,40 @@ class AnthropicChatCompletion(BaseLLM): ) else: text_content = "" - tool_calls = [] - for content in completion_response["content"]: + tool_calls: List[ChatCompletionToolCallChunk] = [] + for idx, content in enumerate(completion_response["content"]): if content["type"] == "text": text_content += content["text"] ## TOOL CALLING elif content["type"] == "tool_use": tool_calls.append( - { - "id": content["id"], - "type": "function", - "function": { - "name": content["name"], - "arguments": json.dumps(content["input"]), - }, - } + ChatCompletionToolCallChunk( + id=content["id"], + type="function", + function=ChatCompletionToolCallFunctionChunk( + name=content["name"], + arguments=json.dumps(content["input"]), + ), + index=idx, + ) ) _message = litellm.Message( tool_calls=tool_calls, content=text_content or None, ) + + ## HANDLE JSON MODE - anthropic returns single function call + if json_mode and len(tool_calls) == 1: + json_mode_content_str: Optional[str] = tool_calls[0]["function"].get( + "arguments" + ) + if json_mode_content_str is not None: + args = json.loads(json_mode_content_str) + values: Optional[dict] = args.get("values") + if values is not None: + _message = litellm.Message(content=json.dumps(values)) + completion_response["stop_reason"] = "stop" model_response.choices[0].message = _message # type: ignore model_response._hidden_params["original_response"] = completion_response[ "content" @@ -687,9 
+702,11 @@ class AnthropicChatCompletion(BaseLLM): _is_function_call, data: dict, optional_params: dict, + json_mode: bool, litellm_params=None, logger_fn=None, headers={}, + client=None, ) -> Union[ModelResponse, CustomStreamWrapper]: async_handler = _get_async_httpx_client() @@ -705,7 +722,7 @@ class AnthropicChatCompletion(BaseLLM): ) raise e - return self.process_response( + return self._process_response( model=model, response=response, model_response=model_response, @@ -717,6 +734,7 @@ class AnthropicChatCompletion(BaseLLM): print_verbose=print_verbose, optional_params=optional_params, encoding=encoding, + json_mode=json_mode, ) def completion( @@ -731,10 +749,12 @@ class AnthropicChatCompletion(BaseLLM): api_key, logging_obj, optional_params: dict, + timeout: Union[float, httpx.Timeout], acompletion=None, litellm_params=None, logger_fn=None, headers={}, + client=None, ): headers = validate_environment(api_key, headers, model) _is_function_call = False @@ -787,14 +807,18 @@ class AnthropicChatCompletion(BaseLLM): anthropic_tools = [] for tool in optional_params["tools"]: - new_tool = tool["function"] - new_tool["input_schema"] = new_tool.pop("parameters") # rename key - anthropic_tools.append(new_tool) + if "input_schema" in tool: # assume in anthropic format + anthropic_tools.append(tool) + else: # assume openai tool call + new_tool = tool["function"] + new_tool["input_schema"] = new_tool.pop("parameters") # rename key + anthropic_tools.append(new_tool) optional_params["tools"] = anthropic_tools stream = optional_params.pop("stream", None) is_vertex_request: bool = optional_params.pop("is_vertex_request", False) + json_mode: bool = optional_params.pop("json_mode", False) data = { "messages": messages, @@ -815,7 +839,7 @@ class AnthropicChatCompletion(BaseLLM): }, ) print_verbose(f"_is_function_call: {_is_function_call}") - if acompletion == True: + if acompletion is True: if ( stream is True ): # if function call - fake the streaming (need complete blocks for output parsing in openai format) @@ -857,15 +881,21 @@ class AnthropicChatCompletion(BaseLLM): litellm_params=litellm_params, logger_fn=logger_fn, headers=headers, + client=client, + json_mode=json_mode, ) else: ## COMPLETION CALL + if client is None or isinstance(client, AsyncHTTPHandler): + client = HTTPHandler(timeout=timeout) # type: ignore + else: + client = client if ( stream is True ): # if function call - fake the streaming (need complete blocks for output parsing in openai format) print_verbose("makes anthropic streaming POST request") data["stream"] = stream - response = requests.post( + response = client.post( api_base, headers=headers, data=json.dumps(data), @@ -889,15 +919,13 @@ class AnthropicChatCompletion(BaseLLM): return streaming_response else: - response = requests.post( - api_base, headers=headers, data=json.dumps(data) - ) + response = client.post(api_base, headers=headers, data=json.dumps(data)) if response.status_code != 200: raise AnthropicError( status_code=response.status_code, message=response.text ) - return self.process_response( + return self._process_response( model=model, response=response, model_response=model_response, @@ -909,6 +937,7 @@ class AnthropicChatCompletion(BaseLLM): print_verbose=print_verbose, optional_params=optional_params, encoding=encoding, + json_mode=json_mode, ) def embedding(self): diff --git a/litellm/llms/vertex_ai_anthropic.py b/litellm/llms/vertex_ai_anthropic.py index b8362d5a5..900e7795f 100644 --- a/litellm/llms/vertex_ai_anthropic.py +++ 
b/litellm/llms/vertex_ai_anthropic.py @@ -7,7 +7,7 @@ import time import types import uuid from enum import Enum -from typing import Any, Callable, List, Optional, Tuple +from typing import Any, Callable, List, Optional, Tuple, Union import httpx # type: ignore import requests # type: ignore @@ -15,7 +15,14 @@ import requests # type: ignore import litellm from litellm.litellm_core_utils.core_helpers import map_finish_reason from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler -from litellm.types.llms.anthropic import AnthropicMessagesToolChoice +from litellm.types.llms.anthropic import ( + AnthropicMessagesTool, + AnthropicMessagesToolChoice, +) +from litellm.types.llms.openai import ( + ChatCompletionToolParam, + ChatCompletionToolParamFunctionChunk, +) from litellm.types.utils import ResponseFormatChunk from litellm.utils import CustomStreamWrapper, ModelResponse, Usage @@ -142,7 +149,27 @@ class VertexAIAnthropicConfig: if param == "top_p": optional_params["top_p"] = value if param == "response_format" and "response_schema" in value: - optional_params["response_format"] = ResponseFormatChunk(**value) # type: ignore + """ + When using tools in this way: - https://docs.anthropic.com/en/docs/build-with-claude/tool-use#json-mode + - You usually want to provide a single tool + - You should set tool_choice (see Forcing tool use) to instruct the model to explicitly use that tool + - Remember that the model will pass the input to the tool, so the name of the tool and description should be from the model’s perspective. + """ + _tool_choice = None + _tool_choice = {"name": "json_tool_call", "type": "tool"} + + _tool = AnthropicMessagesTool( + name="json_tool_call", + input_schema={ + "type": "object", + "properties": {"values": value["response_schema"]}, # type: ignore + }, + ) + + optional_params["tools"] = [_tool] + optional_params["tool_choice"] = _tool_choice + optional_params["json_mode"] = True + return optional_params @@ -222,6 +249,7 @@ def completion( optional_params: dict, custom_prompt_dict: dict, headers: Optional[dict], + timeout: Union[float, httpx.Timeout], vertex_project=None, vertex_location=None, vertex_credentials=None, @@ -301,6 +329,8 @@ def completion( litellm_params=litellm_params, logger_fn=logger_fn, headers=vertex_headers, + client=client, + timeout=timeout, ) except Exception as e: diff --git a/litellm/main.py b/litellm/main.py index e01603b7e..69c845ad8 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -1528,6 +1528,8 @@ def completion( api_key=api_key, logging_obj=logging, headers=headers, + timeout=timeout, + client=client, ) if optional_params.get("stream", False) or acompletion == True: ## LOGGING @@ -2046,7 +2048,10 @@ def completion( acompletion=acompletion, headers=headers, custom_prompt_dict=custom_prompt_dict, + timeout=timeout, + client=client, ) + else: model_response = vertex_ai.completion( model=model, diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index 711b516ab..99e8e41d7 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -1,5 +1,13 @@ model_list: - - model_name: llama-3 + - model_name: bad-azure-model litellm_params: - model: gpt-4 - request_timeout: 1 \ No newline at end of file + model: azure/chatgpt-v-2 + azure_ad_token: "" + api_base: os.environ/AZURE_API_BASE + + - model_name: good-openai-model + litellm_params: + model: gpt-3.5-turbo + +litellm_settings: + fallbacks: [{"bad-azure-model": ["good-openai-model"]}] \ No newline 
at end of file diff --git a/litellm/tests/test_amazing_vertex_completion.py b/litellm/tests/test_amazing_vertex_completion.py index 6a381022e..4b3143453 100644 --- a/litellm/tests/test_amazing_vertex_completion.py +++ b/litellm/tests/test_amazing_vertex_completion.py @@ -1150,6 +1150,39 @@ def vertex_httpx_mock_post_valid_response(*args, **kwargs): return mock_response +def vertex_httpx_mock_post_valid_response_anthropic(*args, **kwargs): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.headers = {"Content-Type": "application/json"} + mock_response.json.return_value = { + "id": "msg_vrtx_013Wki5RFQXAspL7rmxRFjZg", + "type": "message", + "role": "assistant", + "model": "claude-3-5-sonnet-20240620", + "content": [ + { + "type": "tool_use", + "id": "toolu_vrtx_01YMnYZrToPPfcmY2myP2gEB", + "name": "json_tool_call", + "input": { + "values": [ + {"recipe_name": "Chocolate Chip Cookies"}, + {"recipe_name": "Oatmeal Raisin Cookies"}, + {"recipe_name": "Peanut Butter Cookies"}, + {"recipe_name": "Snickerdoodle Cookies"}, + {"recipe_name": "Sugar Cookies"}, + ] + }, + } + ], + "stop_reason": "tool_use", + "stop_sequence": None, + "usage": {"input_tokens": 368, "output_tokens": 118}, + } + + return mock_response + + def vertex_httpx_mock_post_invalid_schema_response(*args, **kwargs): mock_response = MagicMock() mock_response.status_code = 200 @@ -1205,11 +1238,29 @@ def vertex_httpx_mock_post_invalid_schema_response(*args, **kwargs): return mock_response +def vertex_httpx_mock_post_invalid_schema_response_anthropic(*args, **kwargs): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.headers = {"Content-Type": "application/json"} + mock_response.json.return_value = { + "id": "msg_vrtx_013Wki5RFQXAspL7rmxRFjZg", + "type": "message", + "role": "assistant", + "model": "claude-3-5-sonnet-20240620", + "content": [{"text": "Hi! 
My name is Claude.", "type": "text"}], + "stop_reason": "end_turn", + "stop_sequence": None, + "usage": {"input_tokens": 368, "output_tokens": 118}, + } + return mock_response + + @pytest.mark.parametrize( "model, vertex_location, supports_response_schema", [ ("vertex_ai_beta/gemini-1.5-pro-001", "us-central1", True), ("vertex_ai_beta/gemini-1.5-flash", "us-central1", False), + ("vertex_ai/claude-3-5-sonnet@20240620", "us-east5", False), ], ) @pytest.mark.parametrize( @@ -1253,12 +1304,21 @@ async def test_gemini_pro_json_schema_args_sent_httpx( httpx_response = MagicMock() if invalid_response is True: - httpx_response.side_effect = vertex_httpx_mock_post_invalid_schema_response + if "claude" in model: + httpx_response.side_effect = ( + vertex_httpx_mock_post_invalid_schema_response_anthropic + ) + else: + httpx_response.side_effect = vertex_httpx_mock_post_invalid_schema_response else: - httpx_response.side_effect = vertex_httpx_mock_post_valid_response + if "claude" in model: + httpx_response.side_effect = vertex_httpx_mock_post_valid_response_anthropic + else: + httpx_response.side_effect = vertex_httpx_mock_post_valid_response with patch.object(client, "post", new=httpx_response) as mock_call: + print("SENDING CLIENT POST={}".format(client.post)) try: - _ = completion( + resp = completion( model=model, messages=messages, response_format={ @@ -1269,30 +1329,34 @@ async def test_gemini_pro_json_schema_args_sent_httpx( vertex_location=vertex_location, client=client, ) + print("Received={}".format(resp)) if invalid_response is True and enforce_validation is True: pytest.fail("Expected this to fail") except litellm.JSONSchemaValidationError as e: - if invalid_response is False and "claude-3" not in model: + if invalid_response is False: pytest.fail("Expected this to pass. 
Got={}".format(e)) mock_call.assert_called_once() - print(mock_call.call_args.kwargs) - print(mock_call.call_args.kwargs["json"]["generationConfig"]) + if "claude" not in model: + print(mock_call.call_args.kwargs) + print(mock_call.call_args.kwargs["json"]["generationConfig"]) - if supports_response_schema: - assert ( - "response_schema" - in mock_call.call_args.kwargs["json"]["generationConfig"] - ) - else: - assert ( - "response_schema" - not in mock_call.call_args.kwargs["json"]["generationConfig"] - ) - assert ( - "Use this JSON schema:" - in mock_call.call_args.kwargs["json"]["contents"][0]["parts"][1]["text"] - ) + if supports_response_schema: + assert ( + "response_schema" + in mock_call.call_args.kwargs["json"]["generationConfig"] + ) + else: + assert ( + "response_schema" + not in mock_call.call_args.kwargs["json"]["generationConfig"] + ) + assert ( + "Use this JSON schema:" + in mock_call.call_args.kwargs["json"]["contents"][0]["parts"][1][ + "text" + ] + ) @pytest.mark.parametrize("provider", ["vertex_ai_beta"]) # "vertex_ai", From 6d741a54241dd619ea54db4883785ed0360ee434 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Thu, 18 Jul 2024 17:20:19 -0700 Subject: [PATCH 2/2] docs(json_mode.md): add json mode to docs --- docs/my-website/docs/completion/json_mode.md | 137 +++++++++++++++++++ docs/my-website/sidebars.js | 1 + 2 files changed, 138 insertions(+) create mode 100644 docs/my-website/docs/completion/json_mode.md diff --git a/docs/my-website/docs/completion/json_mode.md b/docs/my-website/docs/completion/json_mode.md new file mode 100644 index 000000000..0e7e64a8e --- /dev/null +++ b/docs/my-website/docs/completion/json_mode.md @@ -0,0 +1,137 @@ +import Tabs from '@theme/Tabs'; +import TabItem from '@theme/TabItem'; + +# JSON Mode + +## Quick Start + + + + +```python +from litellm import completion +import os + +os.environ["OPENAI_API_KEY"] = "" + +response = completion( + model="gpt-4o-mini", + response_format={ "type": "json_object" }, + messages=[ + {"role": "system", "content": "You are a helpful assistant designed to output JSON."}, + {"role": "user", "content": "Who won the world series in 2020?"} + ] +) +print(response.choices[0].message.content) +``` + + + +```bash +curl http://0.0.0.0:4000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $LITELLM_KEY" \ + -d '{ + "model": "gpt-4o-mini", + "response_format": { "type": "json_object" }, + "messages": [ + { + "role": "system", + "content": "You are a helpful assistant designed to output JSON." + }, + { + "role": "user", + "content": "Who won the world series in 2020?" + } + ] + }' +``` + + + +## Check Model Support + +Call `litellm.get_supported_openai_params` to check if a model/provider supports `response_format`. + +```python +from litellm import get_supported_openai_params + +params = get_supported_openai_params(model="anthropic.claude-3", custom_llm_provider="bedrock") + +assert "response_format" in params +``` + +## Validate JSON Schema + +For VertexAI models, LiteLLM supports passing the `response_schema` and validating the JSON output. + +This works across Gemini (`vertex_ai_beta/`) + Anthropic (`vertex_ai/`) models. 
+
+<Tabs>
+<TabItem value="sdk" label="SDK">
+
+```python
+# !gcloud auth application-default login - run this to add vertex credentials to your env
+
+from litellm import completion
+
+messages = [{"role": "user", "content": "List 5 cookie recipes"}]
+
+response_schema = {
+    "type": "array",
+    "items": {
+        "type": "object",
+        "properties": {
+            "recipe_name": {
+                "type": "string",
+            },
+        },
+        "required": ["recipe_name"],
+    },
+}
+
+resp = completion(
+    model="vertex_ai_beta/gemini-1.5-pro",
+    messages=messages,
+    response_format={
+        "type": "json_object",
+        "response_schema": response_schema,
+        "enforce_validation": True,  # client-side json schema validation
+    },
+    vertex_location="us-east5",
+)
+
+print("Received={}".format(resp))
+```
+
+</TabItem>
+<TabItem value="proxy" label="PROXY">
+
+```bash
+curl http://0.0.0.0:4000/v1/chat/completions \
+  -H "Content-Type: application/json" \
+  -H "Authorization: Bearer $LITELLM_API_KEY" \
+  -d '{
+    "model": "vertex_ai_beta/gemini-1.5-pro",
+    "messages": [{"role": "user", "content": "List 5 cookie recipes"}],
+    "response_format": {
+        "type": "json_object",
+        "enforce_validation": true,
+        "response_schema": {
+            "type": "array",
+            "items": {
+                "type": "object",
+                "properties": {
+                    "recipe_name": {
+                        "type": "string"
+                    }
+                },
+                "required": ["recipe_name"]
+            }
+        }
+    }
+  }'
+```
+
+</TabItem>
+</Tabs>
\ No newline at end of file
diff --git a/docs/my-website/sidebars.js b/docs/my-website/sidebars.js
index d2179cafc..cc23092e6 100644
--- a/docs/my-website/sidebars.js
+++ b/docs/my-website/sidebars.js
@@ -91,6 +91,7 @@ const sidebars = {
       items: [
         "completion/input",
         "completion/provider_specific_params",
+        "completion/json_mode",
         "completion/drop_params",
         "completion/prompt_formatting",
         "completion/output",
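For reference, a minimal usage sketch of the new `vertex_ai/` Anthropic `response_schema` path added by this patch, mirroring the schema, model, and location used in `test_gemini_pro_json_schema_args_sent_httpx` above (the model name, `vertex_location`, and credential setup are illustrative, not prescribed by the patch):

```python
# Sketch: JSON mode with client-side schema validation on Vertex AI Anthropic.
# Assumes Vertex credentials are configured (e.g. `gcloud auth application-default login`)
# and that the model/location below are enabled for your project.
from litellm import completion, JSONSchemaValidationError

response_schema = {
    "type": "array",
    "items": {
        "type": "object",
        "properties": {"recipe_name": {"type": "string"}},
        "required": ["recipe_name"],
    },
}

try:
    resp = completion(
        model="vertex_ai/claude-3-5-sonnet@20240620",
        messages=[{"role": "user", "content": "List 5 cookie recipes"}],
        response_format={
            "type": "json_object",
            "response_schema": response_schema,
            "enforce_validation": True,  # validate the returned JSON against response_schema
        },
        vertex_location="us-east5",
    )
    print(resp.choices[0].message.content)  # JSON string produced via the forced json_tool_call tool
except JSONSchemaValidationError as e:
    # raised when enforce_validation=True and the model output does not match response_schema
    print("Schema validation failed: {}".format(e))
```

With `enforce_validation` set, the JSON extracted from the forced `json_tool_call` tool use is validated against `response_schema`, and `litellm.JSONSchemaValidationError` is raised on a mismatch.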