feat(vertex_ai_anthropic.py): support response_schema for vertex ai anthropic calls

allows passing response_schema for anthropic calls routed through vertex ai: the schema is converted into a forced `json_tool_call` tool, the tool's output is unwrapped back into plain JSON content, and the result can be validated against the schema.
Krrish Dholakia 2024-07-18 16:57:38 -07:00
parent f8bdfe7cc3
commit f2401d6d5e
6 changed files with 189 additions and 48 deletions
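In practice the new path is exercised roughly like this; a minimal sketch, assuming the response_format keys used by the test at the bottom of this commit (the schema and model name are illustrative placeholders):

    import litellm

    response_schema = {
        "type": "array",
        "items": {
            "type": "object",
            "properties": {"recipe_name": {"type": "string"}},
            "required": ["recipe_name"],
        },
    }

    resp = litellm.completion(
        model="vertex_ai/claude-3-5-sonnet@20240620",
        messages=[{"role": "user", "content": "List 5 cookie recipes"}],
        response_format={
            "type": "json_object",
            "response_schema": response_schema,
            "enforce_validation": True,  # raise JSONSchemaValidationError on mismatch
        },
        vertex_location="us-east5",
    )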


@@ -13,7 +13,12 @@ def validate_schema(schema: dict, response: str):
     from litellm import JSONSchemaValidationError
 
-    response_dict = json.loads(response)
+    try:
+        response_dict = json.loads(response)
+    except json.JSONDecodeError:
+        raise JSONSchemaValidationError(
+            model="", llm_provider="", raw_response=response, schema=response
+        )
 
     try:
         validate(response_dict, schema=schema)
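A quick sketch of the validator's new behavior (assuming the module path litellm/litellm_core_utils/json_validation_rule.py implied by this hunk's context):

    from litellm import JSONSchemaValidationError
    from litellm.litellm_core_utils.json_validation_rule import validate_schema

    schema = {"type": "object", "properties": {"name": {"type": "string"}}}

    validate_schema(schema=schema, response='{"name": "Claude"}')  # passes

    try:
        validate_schema(schema=schema, response="not json at all")
    except JSONSchemaValidationError:
        # previously json.loads() raised a bare JSONDecodeError here;
        # this commit wraps it in litellm's own exception type
        print("raised JSONSchemaValidationError")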


@@ -16,6 +16,7 @@ from litellm import verbose_logger
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
 from litellm.llms.custom_httpx.http_handler import (
     AsyncHTTPHandler,
+    HTTPHandler,
     _get_async_httpx_client,
     _get_httpx_client,
 )
@@ -538,7 +539,7 @@ class AnthropicChatCompletion(BaseLLM):
     def __init__(self) -> None:
         super().__init__()
 
-    def process_response(
+    def _process_response(
         self,
         model: str,
         response: Union[requests.Response, httpx.Response],
@@ -551,6 +552,7 @@ class AnthropicChatCompletion(BaseLLM):
         messages: List,
         print_verbose,
         encoding,
+        json_mode: bool,
     ) -> ModelResponse:
         ## LOGGING
         logging_obj.post_call(
@@ -574,27 +576,40 @@ class AnthropicChatCompletion(BaseLLM):
             )
         else:
             text_content = ""
-            tool_calls = []
-            for content in completion_response["content"]:
+            tool_calls: List[ChatCompletionToolCallChunk] = []
+            for idx, content in enumerate(completion_response["content"]):
                 if content["type"] == "text":
                     text_content += content["text"]
                 ## TOOL CALLING
                 elif content["type"] == "tool_use":
                     tool_calls.append(
-                        {
-                            "id": content["id"],
-                            "type": "function",
-                            "function": {
-                                "name": content["name"],
-                                "arguments": json.dumps(content["input"]),
-                            },
-                        }
+                        ChatCompletionToolCallChunk(
+                            id=content["id"],
+                            type="function",
+                            function=ChatCompletionToolCallFunctionChunk(
+                                name=content["name"],
+                                arguments=json.dumps(content["input"]),
+                            ),
+                            index=idx,
+                        )
                     )
             _message = litellm.Message(
                 tool_calls=tool_calls,
                 content=text_content or None,
             )
+
+            ## HANDLE JSON MODE - anthropic returns single function call
+            if json_mode and len(tool_calls) == 1:
+                json_mode_content_str: Optional[str] = tool_calls[0]["function"].get(
+                    "arguments"
+                )
+                if json_mode_content_str is not None:
+                    args = json.loads(json_mode_content_str)
+                    values: Optional[dict] = args.get("values")
+                    if values is not None:
+                        _message = litellm.Message(content=json.dumps(values))
+                        completion_response["stop_reason"] = "stop"
+
         model_response.choices[0].message = _message  # type: ignore
         model_response._hidden_params["original_response"] = completion_response[
             "content"
@@ -687,9 +702,11 @@
         _is_function_call,
         data: dict,
         optional_params: dict,
+        json_mode: bool,
         litellm_params=None,
         logger_fn=None,
         headers={},
+        client=None,
     ) -> Union[ModelResponse, CustomStreamWrapper]:
         async_handler = _get_async_httpx_client()
@@ -705,7 +722,7 @@
             )
             raise e
 
-        return self.process_response(
+        return self._process_response(
             model=model,
             response=response,
             model_response=model_response,
@@ -717,6 +734,7 @@
             print_verbose=print_verbose,
             optional_params=optional_params,
             encoding=encoding,
+            json_mode=json_mode,
         )
 
     def completion(
@@ -731,10 +749,12 @@
         api_key,
         logging_obj,
         optional_params: dict,
+        timeout: Union[float, httpx.Timeout],
         acompletion=None,
         litellm_params=None,
         logger_fn=None,
         headers={},
+        client=None,
     ):
         headers = validate_environment(api_key, headers, model)
         _is_function_call = False
@@ -787,14 +807,18 @@
             anthropic_tools = []
             for tool in optional_params["tools"]:
-                new_tool = tool["function"]
-                new_tool["input_schema"] = new_tool.pop("parameters")  # rename key
-                anthropic_tools.append(new_tool)
+                if "input_schema" in tool:  # assume in anthropic format
+                    anthropic_tools.append(tool)
+                else:  # assume openai tool call
+                    new_tool = tool["function"]
+                    new_tool["input_schema"] = new_tool.pop("parameters")  # rename key
+                    anthropic_tools.append(new_tool)
             optional_params["tools"] = anthropic_tools
 
         stream = optional_params.pop("stream", None)
         is_vertex_request: bool = optional_params.pop("is_vertex_request", False)
+        json_mode: bool = optional_params.pop("json_mode", False)
 
         data = {
             "messages": messages,
@@ -815,7 +839,7 @@
             },
         )
         print_verbose(f"_is_function_call: {_is_function_call}")
-        if acompletion == True:
+        if acompletion is True:
             if (
                 stream is True
             ):  # if function call - fake the streaming (need complete blocks for output parsing in openai format)
@@ -857,15 +881,21 @@
                 litellm_params=litellm_params,
                 logger_fn=logger_fn,
                 headers=headers,
+                client=client,
+                json_mode=json_mode,
             )
         else:
             ## COMPLETION CALL
+            if client is None or isinstance(client, AsyncHTTPHandler):
+                client = HTTPHandler(timeout=timeout)  # type: ignore
+            else:
+                client = client
             if (
                 stream is True
             ):  # if function call - fake the streaming (need complete blocks for output parsing in openai format)
                 print_verbose("makes anthropic streaming POST request")
                 data["stream"] = stream
-                response = requests.post(
+                response = client.post(
                     api_base,
                     headers=headers,
                     data=json.dumps(data),
@@ -889,15 +919,13 @@
                 return streaming_response
 
             else:
-                response = requests.post(
-                    api_base, headers=headers, data=json.dumps(data)
-                )
+                response = client.post(api_base, headers=headers, data=json.dumps(data))
                 if response.status_code != 200:
                     raise AnthropicError(
                         status_code=response.status_code, message=response.text
                     )
 
-        return self.process_response(
+        return self._process_response(
             model=model,
             response=response,
             model_response=model_response,
@@ -909,6 +937,7 @@
             print_verbose=print_verbose,
             optional_params=optional_params,
             encoding=encoding,
+            json_mode=json_mode,
         )
 
     def embedding(self):
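With timeout and client threaded through, a caller can reuse one pooled synchronous handler across requests; a sketch, assuming litellm.completion forwards the client kwarg as wired up in main.py below (the Timeout values are arbitrary):

    import httpx
    import litellm
    from litellm.llms.custom_httpx.http_handler import HTTPHandler

    # one pooled sync client, reused across completion calls
    client = HTTPHandler(timeout=httpx.Timeout(600.0, connect=5.0))

    resp = litellm.completion(
        model="vertex_ai/claude-3-5-sonnet@20240620",
        messages=[{"role": "user", "content": "hi"}],
        client=client,
        timeout=600.0,
    )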


@@ -7,7 +7,7 @@ import time
 import types
 import uuid
 from enum import Enum
-from typing import Any, Callable, List, Optional, Tuple
+from typing import Any, Callable, List, Optional, Tuple, Union
 
 import httpx  # type: ignore
 import requests  # type: ignore
@@ -15,7 +15,14 @@ import requests  # type: ignore
 import litellm
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
-from litellm.types.llms.anthropic import AnthropicMessagesToolChoice
+from litellm.types.llms.anthropic import (
+    AnthropicMessagesTool,
+    AnthropicMessagesToolChoice,
+)
+from litellm.types.llms.openai import (
+    ChatCompletionToolParam,
+    ChatCompletionToolParamFunctionChunk,
+)
 from litellm.types.utils import ResponseFormatChunk
 from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
@@ -142,7 +149,27 @@ class VertexAIAnthropicConfig:
             if param == "top_p":
                 optional_params["top_p"] = value
             if param == "response_format" and "response_schema" in value:
-                optional_params["response_format"] = ResponseFormatChunk(**value)  # type: ignore
+                """
+                When using tools in this way: - https://docs.anthropic.com/en/docs/build-with-claude/tool-use#json-mode
+                - You usually want to provide a single tool
+                - You should set tool_choice (see Forcing tool use) to instruct the model to explicitly use that tool
+                - Remember that the model will pass the input to the tool, so the name of the tool and description should be from the model's perspective.
+                """
+                _tool_choice = None
+                _tool_choice = {"name": "json_tool_call", "type": "tool"}
+
+                _tool = AnthropicMessagesTool(
+                    name="json_tool_call",
+                    input_schema={
+                        "type": "object",
+                        "properties": {"values": value["response_schema"]},  # type: ignore
+                    },
+                )
+
+                optional_params["tools"] = [_tool]
+                optional_params["tool_choice"] = _tool_choice
+                optional_params["json_mode"] = True
 
         return optional_params
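The net effect of the new mapping, written out as plain dicts (the response_schema value is a placeholder; AnthropicMessagesTool is a typed dict in litellm.types.llms.anthropic, so plain dicts share its shape):

    response_format = {
        "type": "json_object",
        "response_schema": {"type": "object", "properties": {"name": {"type": "string"}}},
    }

    # what map_openai_params produces for the anthropic request
    optional_params = {
        "tools": [
            {
                "name": "json_tool_call",
                "input_schema": {
                    "type": "object",
                    # the user's schema is nested under a "values" property,
                    # which _process_response later unwraps into message.content
                    "properties": {"values": response_format["response_schema"]},
                },
            }
        ],
        "tool_choice": {"name": "json_tool_call", "type": "tool"},  # force the tool
        "json_mode": True,  # popped again in anthropic.py's completion()
    }
    print(optional_params)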
@@ -222,6 +249,7 @@ def completion(
     optional_params: dict,
     custom_prompt_dict: dict,
     headers: Optional[dict],
+    timeout: Union[float, httpx.Timeout],
     vertex_project=None,
     vertex_location=None,
     vertex_credentials=None,
@@ -301,6 +329,8 @@
             litellm_params=litellm_params,
             logger_fn=logger_fn,
             headers=vertex_headers,
+            client=client,
+            timeout=timeout,
         )
 
     except Exception as e:


@@ -1528,6 +1528,8 @@ def completion(
                 api_key=api_key,
                 logging_obj=logging,
                 headers=headers,
+                timeout=timeout,
+                client=client,
             )
             if optional_params.get("stream", False) or acompletion == True:
                 ## LOGGING
@@ -2046,7 +2048,10 @@
                     acompletion=acompletion,
                     headers=headers,
                     custom_prompt_dict=custom_prompt_dict,
+                    timeout=timeout,
+                    client=client,
                 )
             else:
                 model_response = vertex_ai.completion(
                     model=model,


@@ -1,5 +1,13 @@
 model_list:
-  - model_name: llama-3
+  - model_name: bad-azure-model
     litellm_params:
-      model: gpt-4
-      request_timeout: 1
+      model: azure/chatgpt-v-2
+      azure_ad_token: ""
+      api_base: os.environ/AZURE_API_BASE
+  - model_name: good-openai-model
+    litellm_params:
+      model: gpt-3.5-turbo
+
+litellm_settings:
+  fallbacks: [{"bad-azure-model": ["good-openai-model"]}]
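This config change looks like a local proxy fallback setup rather than part of the response_schema feature itself. For reference, a sketch of how it would be exercised, assuming the proxy is running on litellm's default port 4000 with a placeholder key:

    import openai

    # point the OpenAI SDK at the local litellm proxy
    client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")

    # "bad-azure-model" has an empty azure_ad_token, so the call should fail and
    # the proxy retries against "good-openai-model" per litellm_settings.fallbacks
    resp = client.chat.completions.create(
        model="bad-azure-model",
        messages=[{"role": "user", "content": "ping"}],
    )
    print(resp.choices[0].message.content)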


@@ -1128,6 +1128,39 @@ def vertex_httpx_mock_post_valid_response(*args, **kwargs):
     return mock_response
 
 
+def vertex_httpx_mock_post_valid_response_anthropic(*args, **kwargs):
+    mock_response = MagicMock()
+    mock_response.status_code = 200
+    mock_response.headers = {"Content-Type": "application/json"}
+    mock_response.json.return_value = {
+        "id": "msg_vrtx_013Wki5RFQXAspL7rmxRFjZg",
+        "type": "message",
+        "role": "assistant",
+        "model": "claude-3-5-sonnet-20240620",
+        "content": [
+            {
+                "type": "tool_use",
+                "id": "toolu_vrtx_01YMnYZrToPPfcmY2myP2gEB",
+                "name": "json_tool_call",
+                "input": {
+                    "values": [
+                        {"recipe_name": "Chocolate Chip Cookies"},
+                        {"recipe_name": "Oatmeal Raisin Cookies"},
+                        {"recipe_name": "Peanut Butter Cookies"},
+                        {"recipe_name": "Snickerdoodle Cookies"},
+                        {"recipe_name": "Sugar Cookies"},
+                    ]
+                },
+            }
+        ],
+        "stop_reason": "tool_use",
+        "stop_sequence": None,
+        "usage": {"input_tokens": 368, "output_tokens": 118},
+    }
+    return mock_response
+
+
 def vertex_httpx_mock_post_invalid_schema_response(*args, **kwargs):
     mock_response = MagicMock()
     mock_response.status_code = 200
@@ -1183,11 +1216,29 @@ def vertex_httpx_mock_post_invalid_schema_response(*args, **kwargs):
     return mock_response
 
 
+def vertex_httpx_mock_post_invalid_schema_response_anthropic(*args, **kwargs):
+    mock_response = MagicMock()
+    mock_response.status_code = 200
+    mock_response.headers = {"Content-Type": "application/json"}
+    mock_response.json.return_value = {
+        "id": "msg_vrtx_013Wki5RFQXAspL7rmxRFjZg",
+        "type": "message",
+        "role": "assistant",
+        "model": "claude-3-5-sonnet-20240620",
+        "content": [{"text": "Hi! My name is Claude.", "type": "text"}],
+        "stop_reason": "end_turn",
+        "stop_sequence": None,
+        "usage": {"input_tokens": 368, "output_tokens": 118},
+    }
+    return mock_response
+
+
 @pytest.mark.parametrize(
     "model, vertex_location, supports_response_schema",
     [
         ("vertex_ai_beta/gemini-1.5-pro-001", "us-central1", True),
         ("vertex_ai_beta/gemini-1.5-flash", "us-central1", False),
+        ("vertex_ai/claude-3-5-sonnet@20240620", "us-east5", False),
     ],
 )
 @pytest.mark.parametrize(
@@ -1231,12 +1282,21 @@ async def test_gemini_pro_json_schema_args_sent_httpx(
     httpx_response = MagicMock()
     if invalid_response is True:
-        httpx_response.side_effect = vertex_httpx_mock_post_invalid_schema_response
+        if "claude" in model:
+            httpx_response.side_effect = (
+                vertex_httpx_mock_post_invalid_schema_response_anthropic
+            )
+        else:
+            httpx_response.side_effect = vertex_httpx_mock_post_invalid_schema_response
     else:
-        httpx_response.side_effect = vertex_httpx_mock_post_valid_response
+        if "claude" in model:
+            httpx_response.side_effect = vertex_httpx_mock_post_valid_response_anthropic
+        else:
+            httpx_response.side_effect = vertex_httpx_mock_post_valid_response
 
     with patch.object(client, "post", new=httpx_response) as mock_call:
+        print("SENDING CLIENT POST={}".format(client.post))
         try:
-            _ = completion(
+            resp = completion(
                 model=model,
                 messages=messages,
                 response_format={
@@ -1247,30 +1307,34 @@
                 vertex_location=vertex_location,
                 client=client,
             )
+            print("Received={}".format(resp))
             if invalid_response is True and enforce_validation is True:
                 pytest.fail("Expected this to fail")
         except litellm.JSONSchemaValidationError as e:
-            if invalid_response is False and "claude-3" not in model:
+            if invalid_response is False:
                 pytest.fail("Expected this to pass. Got={}".format(e))
 
         mock_call.assert_called_once()
-        print(mock_call.call_args.kwargs)
-        print(mock_call.call_args.kwargs["json"]["generationConfig"])
-        if supports_response_schema:
-            assert (
-                "response_schema"
-                in mock_call.call_args.kwargs["json"]["generationConfig"]
-            )
-        else:
-            assert (
-                "response_schema"
-                not in mock_call.call_args.kwargs["json"]["generationConfig"]
-            )
-            assert (
-                "Use this JSON schema:"
-                in mock_call.call_args.kwargs["json"]["contents"][0]["parts"][1]["text"]
-            )
+        if "claude" not in model:
+            print(mock_call.call_args.kwargs)
+            print(mock_call.call_args.kwargs["json"]["generationConfig"])
+
+            if supports_response_schema:
+                assert (
+                    "response_schema"
+                    in mock_call.call_args.kwargs["json"]["generationConfig"]
+                )
+            else:
+                assert (
+                    "response_schema"
+                    not in mock_call.call_args.kwargs["json"]["generationConfig"]
+                )
+                assert (
+                    "Use this JSON schema:"
+                    in mock_call.call_args.kwargs["json"]["contents"][0]["parts"][1][
+                        "text"
+                    ]
+                )
 
 
 @pytest.mark.parametrize("provider", ["vertex_ai_beta"])  # "vertex_ai",