feat(vertex_ai_anthropic.py): support response_schema for vertex ai anthropic calls

allows passing response_schema for anthropic calls on vertex ai. the returned json response is validated against the schema.
Krrish Dholakia 2024-07-18 16:57:38 -07:00
parent f8bdfe7cc3
commit f2401d6d5e
6 changed files with 189 additions and 48 deletions
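As context for the diff below, a minimal sketch of the call pattern this commit enables. The model, region, and response_format keys mirror the test changes later in this diff; the schema and messages are illustrative placeholders:

import litellm

response_schema = {
    "type": "array",
    "items": {
        "type": "object",
        "properties": {"recipe_name": {"type": "string"}},
        "required": ["recipe_name"],
    },
}

resp = litellm.completion(
    model="vertex_ai/claude-3-5-sonnet@20240620",
    messages=[{"role": "user", "content": "List 5 cookie recipes"}],
    response_format={
        "type": "json_object",
        "response_schema": response_schema,
        "enforce_validation": True,  # raise JSONSchemaValidationError on mismatch
    },
    vertex_location="us-east5",
)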


@@ -13,7 +13,12 @@ def validate_schema(schema: dict, response: str):
from litellm import JSONSchemaValidationError
response_dict = json.loads(response)
try:
response_dict = json.loads(response)
except json.JSONDecodeError:
raise JSONSchemaValidationError(
model="", llm_provider="", raw_response=response, schema=response
)
try:
validate(response_dict, schema=schema)
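Taken together, the hunk above gives validate_schema two failure modes. A condensed, self-contained sketch of the flow, raising ValueError in place of litellm's JSONSchemaValidationError to keep it runnable on its own:

import json

from jsonschema import ValidationError, validate

def validate_schema_sketch(schema: dict, response: str) -> None:
    # Failure mode 1: the model returned something that is not JSON at all.
    try:
        response_dict = json.loads(response)
    except json.JSONDecodeError:
        raise ValueError(f"not valid JSON: {response!r}")
    # Failure mode 2: valid JSON that does not satisfy the schema.
    try:
        validate(response_dict, schema=schema)
    except ValidationError:
        raise ValueError(f"does not match schema: {json.dumps(schema)}")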


@@ -16,6 +16,7 @@ from litellm import verbose_logger
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
_get_async_httpx_client,
_get_httpx_client,
)
@@ -538,7 +539,7 @@ class AnthropicChatCompletion(BaseLLM):
def __init__(self) -> None:
super().__init__()
def process_response(
def _process_response(
self,
model: str,
response: Union[requests.Response, httpx.Response],
@@ -551,6 +552,7 @@ class AnthropicChatCompletion(BaseLLM):
messages: List,
print_verbose,
encoding,
json_mode: bool,
) -> ModelResponse:
## LOGGING
logging_obj.post_call(
@@ -574,27 +576,40 @@
)
else:
text_content = ""
tool_calls = []
for content in completion_response["content"]:
tool_calls: List[ChatCompletionToolCallChunk] = []
for idx, content in enumerate(completion_response["content"]):
if content["type"] == "text":
text_content += content["text"]
## TOOL CALLING
elif content["type"] == "tool_use":
tool_calls.append(
{
"id": content["id"],
"type": "function",
"function": {
"name": content["name"],
"arguments": json.dumps(content["input"]),
},
}
ChatCompletionToolCallChunk(
id=content["id"],
type="function",
function=ChatCompletionToolCallFunctionChunk(
name=content["name"],
arguments=json.dumps(content["input"]),
),
index=idx,
)
)
_message = litellm.Message(
tool_calls=tool_calls,
content=text_content or None,
)
## HANDLE JSON MODE - anthropic returns single function call
if json_mode and len(tool_calls) == 1:
json_mode_content_str: Optional[str] = tool_calls[0]["function"].get(
"arguments"
)
if json_mode_content_str is not None:
args = json.loads(json_mode_content_str)
values: Optional[dict] = args.get("values")
if values is not None:
_message = litellm.Message(content=json.dumps(values))
completion_response["stop_reason"] = "stop"
model_response.choices[0].message = _message # type: ignore
model_response._hidden_params["original_response"] = completion_response[
"content"
@@ -687,9 +702,11 @@ class AnthropicChatCompletion(BaseLLM):
_is_function_call,
data: dict,
optional_params: dict,
json_mode: bool,
litellm_params=None,
logger_fn=None,
headers={},
client=None,
) -> Union[ModelResponse, CustomStreamWrapper]:
async_handler = _get_async_httpx_client()
@@ -705,7 +722,7 @@
)
raise e
return self.process_response(
return self._process_response(
model=model,
response=response,
model_response=model_response,
@@ -717,6 +734,7 @@
print_verbose=print_verbose,
optional_params=optional_params,
encoding=encoding,
json_mode=json_mode,
)
def completion(
@@ -731,10 +749,12 @@
api_key,
logging_obj,
optional_params: dict,
timeout: Union[float, httpx.Timeout],
acompletion=None,
litellm_params=None,
logger_fn=None,
headers={},
client=None,
):
headers = validate_environment(api_key, headers, model)
_is_function_call = False
@@ -787,14 +807,18 @@
anthropic_tools = []
for tool in optional_params["tools"]:
new_tool = tool["function"]
new_tool["input_schema"] = new_tool.pop("parameters") # rename key
anthropic_tools.append(new_tool)
if "input_schema" in tool: # assume in anthropic format
anthropic_tools.append(tool)
else: # assume openai tool call
new_tool = tool["function"]
new_tool["input_schema"] = new_tool.pop("parameters") # rename key
anthropic_tools.append(new_tool)
optional_params["tools"] = anthropic_tools
stream = optional_params.pop("stream", None)
is_vertex_request: bool = optional_params.pop("is_vertex_request", False)
json_mode: bool = optional_params.pop("json_mode", False)
data = {
"messages": messages,
@@ -815,7 +839,7 @@
},
)
print_verbose(f"_is_function_call: {_is_function_call}")
if acompletion == True:
if acompletion is True:
if (
stream is True
): # if function call - fake the streaming (need complete blocks for output parsing in openai format)
@@ -857,15 +881,21 @@
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=headers,
client=client,
json_mode=json_mode,
)
else:
## COMPLETION CALL
if client is None or isinstance(client, AsyncHTTPHandler):
client = HTTPHandler(timeout=timeout) # type: ignore
if (
stream is True
): # if function call - fake the streaming (need complete blocks for output parsing in openai format)
print_verbose("makes anthropic streaming POST request")
data["stream"] = stream
response = requests.post(
response = client.post(
api_base,
headers=headers,
data=json.dumps(data),
@@ -889,15 +919,13 @@
return streaming_response
else:
response = requests.post(
api_base, headers=headers, data=json.dumps(data)
)
response = client.post(api_base, headers=headers, data=json.dumps(data))
if response.status_code != 200:
raise AnthropicError(
status_code=response.status_code, message=response.text
)
return self.process_response(
return self._process_response(
model=model,
response=response,
model_response=model_response,
@@ -909,6 +937,7 @@
print_verbose=print_verbose,
optional_params=optional_params,
encoding=encoding,
json_mode=json_mode,
)
def embedding(self):
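The json_mode branch in _process_response above unwraps Anthropic's forced tool call back into ordinary message content. A self-contained sketch of just that unwrap, assuming the simplified dict shape shown here (the real code operates on litellm's ChatCompletionToolCallChunk objects):

import json
from typing import List, Optional

def unwrap_json_tool_call(tool_calls: List[dict]) -> Optional[str]:
    # Anthropic returns the JSON payload as the single forced tool call's
    # arguments, nested under the synthetic "values" key added by
    # VertexAIAnthropicConfig (see the next file in this diff).
    if len(tool_calls) != 1:
        return None
    args_str = tool_calls[0]["function"].get("arguments")
    if args_str is None:
        return None
    values = json.loads(args_str).get("values")
    return json.dumps(values) if values is not None else None

calls = [{"function": {"arguments": '{"values": [{"recipe_name": "Sugar Cookies"}]}'}}]
print(unwrap_json_tool_call(calls))  # [{"recipe_name": "Sugar Cookies"}]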


@@ -7,7 +7,7 @@ import time
import types
import uuid
from enum import Enum
from typing import Any, Callable, List, Optional, Tuple
from typing import Any, Callable, List, Optional, Tuple, Union
import httpx # type: ignore
import requests # type: ignore
@@ -15,7 +15,14 @@ import requests # type: ignore
import litellm
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.types.llms.anthropic import AnthropicMessagesToolChoice
from litellm.types.llms.anthropic import (
AnthropicMessagesTool,
AnthropicMessagesToolChoice,
)
from litellm.types.llms.openai import (
ChatCompletionToolParam,
ChatCompletionToolParamFunctionChunk,
)
from litellm.types.utils import ResponseFormatChunk
from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
@@ -142,7 +149,27 @@ class VertexAIAnthropicConfig:
if param == "top_p":
optional_params["top_p"] = value
if param == "response_format" and "response_schema" in value:
optional_params["response_format"] = ResponseFormatChunk(**value) # type: ignore
"""
When using tools for JSON mode (https://docs.anthropic.com/en/docs/build-with-claude/tool-use#json-mode):
- You usually want to provide a single tool
- You should set tool_choice (see Forcing tool use) to instruct the model to explicitly use that tool
- Remember that the model will pass the input to the tool, so the name of the tool and description should be from the model's perspective.
"""
_tool_choice = {"name": "json_tool_call", "type": "tool"}
_tool = AnthropicMessagesTool(
name="json_tool_call",
input_schema={
"type": "object",
"properties": {"values": value["response_schema"]}, # type: ignore
},
)
optional_params["tools"] = [_tool]
optional_params["tool_choice"] = _tool_choice
optional_params["json_mode"] = True
return optional_params
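In other words, the branch above turns a response_format carrying a response_schema into a forced Anthropic tool. A sketch of the transformation on a concrete input (plain dicts stand in for AnthropicMessagesTool):

value = {
    "type": "json_object",
    "response_schema": {"type": "array", "items": {"type": "object"}},
}

# Resulting optional_params after the branch above:
expected = {
    "tools": [
        {
            "name": "json_tool_call",
            "input_schema": {
                "type": "object",
                "properties": {"values": value["response_schema"]},
            },
        }
    ],
    "tool_choice": {"name": "json_tool_call", "type": "tool"},
    "json_mode": True,
}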
@@ -222,6 +249,7 @@ def completion(
optional_params: dict,
custom_prompt_dict: dict,
headers: Optional[dict],
timeout: Union[float, httpx.Timeout],
vertex_project=None,
vertex_location=None,
vertex_credentials=None,
@@ -301,6 +329,8 @@ def completion(
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=vertex_headers,
client=client,
timeout=timeout,
)
except Exception as e:


@@ -1528,6 +1528,8 @@ def completion(
api_key=api_key,
logging_obj=logging,
headers=headers,
timeout=timeout,
client=client,
)
if optional_params.get("stream", False) or acompletion == True:
## LOGGING
@@ -2046,7 +2048,10 @@ def completion(
acompletion=acompletion,
headers=headers,
custom_prompt_dict=custom_prompt_dict,
timeout=timeout,
client=client,
)
else:
model_response = vertex_ai.completion(
model=model,


@@ -1,5 +1,13 @@
model_list:
- model_name: llama-3
- model_name: bad-azure-model
litellm_params:
model: gpt-4
request_timeout: 1
model: azure/chatgpt-v-2
azure_ad_token: ""
api_base: os.environ/AZURE_API_BASE
- model_name: good-openai-model
litellm_params:
model: gpt-3.5-turbo
litellm_settings:
fallbacks: [{"bad-azure-model": ["good-openai-model"]}]
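For reference, a sketch of exercising this test config against a running proxy; the URL, key, and startup command are assumptions, not part of the diff:

import openai

# Assumes the proxy was started with: litellm --config config.yaml
client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")

# bad-azure-model is deliberately misconfigured (empty azure_ad_token), so the
# fallbacks setting should route this call to good-openai-model instead.
resp = client.chat.completions.create(
    model="bad-azure-model",
    messages=[{"role": "user", "content": "hello"}],
)
print(resp.choices[0].message.content)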


@@ -1128,6 +1128,39 @@ def vertex_httpx_mock_post_valid_response(*args, **kwargs):
return mock_response
def vertex_httpx_mock_post_valid_response_anthropic(*args, **kwargs):
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.headers = {"Content-Type": "application/json"}
mock_response.json.return_value = {
"id": "msg_vrtx_013Wki5RFQXAspL7rmxRFjZg",
"type": "message",
"role": "assistant",
"model": "claude-3-5-sonnet-20240620",
"content": [
{
"type": "tool_use",
"id": "toolu_vrtx_01YMnYZrToPPfcmY2myP2gEB",
"name": "json_tool_call",
"input": {
"values": [
{"recipe_name": "Chocolate Chip Cookies"},
{"recipe_name": "Oatmeal Raisin Cookies"},
{"recipe_name": "Peanut Butter Cookies"},
{"recipe_name": "Snickerdoodle Cookies"},
{"recipe_name": "Sugar Cookies"},
]
},
}
],
"stop_reason": "tool_use",
"stop_sequence": None,
"usage": {"input_tokens": 368, "output_tokens": 118},
}
return mock_response
def vertex_httpx_mock_post_invalid_schema_response(*args, **kwargs):
mock_response = MagicMock()
mock_response.status_code = 200
@@ -1183,11 +1216,29 @@ def vertex_httpx_mock_post_invalid_schema_response(*args, **kwargs):
return mock_response
def vertex_httpx_mock_post_invalid_schema_response_anthropic(*args, **kwargs):
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.headers = {"Content-Type": "application/json"}
mock_response.json.return_value = {
"id": "msg_vrtx_013Wki5RFQXAspL7rmxRFjZg",
"type": "message",
"role": "assistant",
"model": "claude-3-5-sonnet-20240620",
"content": [{"text": "Hi! My name is Claude.", "type": "text"}],
"stop_reason": "end_turn",
"stop_sequence": None,
"usage": {"input_tokens": 368, "output_tokens": 118},
}
return mock_response
@pytest.mark.parametrize(
"model, vertex_location, supports_response_schema",
[
("vertex_ai_beta/gemini-1.5-pro-001", "us-central1", True),
("vertex_ai_beta/gemini-1.5-flash", "us-central1", False),
("vertex_ai/claude-3-5-sonnet@20240620", "us-east5", False),
],
)
@pytest.mark.parametrize(
@@ -1231,12 +1282,21 @@ async def test_gemini_pro_json_schema_args_sent_httpx(
httpx_response = MagicMock()
if invalid_response is True:
httpx_response.side_effect = vertex_httpx_mock_post_invalid_schema_response
if "claude" in model:
httpx_response.side_effect = (
vertex_httpx_mock_post_invalid_schema_response_anthropic
)
else:
httpx_response.side_effect = vertex_httpx_mock_post_invalid_schema_response
else:
httpx_response.side_effect = vertex_httpx_mock_post_valid_response
if "claude" in model:
httpx_response.side_effect = vertex_httpx_mock_post_valid_response_anthropic
else:
httpx_response.side_effect = vertex_httpx_mock_post_valid_response
with patch.object(client, "post", new=httpx_response) as mock_call:
print("SENDING CLIENT POST={}".format(client.post))
try:
_ = completion(
resp = completion(
model=model,
messages=messages,
response_format={
@@ -1247,30 +1307,34 @@
vertex_location=vertex_location,
client=client,
)
print("Received={}".format(resp))
if invalid_response is True and enforce_validation is True:
pytest.fail("Expected this to fail")
except litellm.JSONSchemaValidationError as e:
if invalid_response is False and "claude-3" not in model:
if invalid_response is False:
pytest.fail("Expected this to pass. Got={}".format(e))
mock_call.assert_called_once()
print(mock_call.call_args.kwargs)
print(mock_call.call_args.kwargs["json"]["generationConfig"])
if "claude" not in model:
print(mock_call.call_args.kwargs)
print(mock_call.call_args.kwargs["json"]["generationConfig"])
if supports_response_schema:
assert (
"response_schema"
in mock_call.call_args.kwargs["json"]["generationConfig"]
)
else:
assert (
"response_schema"
not in mock_call.call_args.kwargs["json"]["generationConfig"]
)
assert (
"Use this JSON schema:"
in mock_call.call_args.kwargs["json"]["contents"][0]["parts"][1]["text"]
)
if supports_response_schema:
assert (
"response_schema"
in mock_call.call_args.kwargs["json"]["generationConfig"]
)
else:
assert (
"response_schema"
not in mock_call.call_args.kwargs["json"]["generationConfig"]
)
assert (
"Use this JSON schema:"
in mock_call.call_args.kwargs["json"]["contents"][0]["parts"][1][
"text"
]
)
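A hypothetical extra assertion for the Claude branch, based on the VertexAIAnthropicConfig mapping earlier in this diff (not an assertion the commit actually adds). Note the Anthropic path posts its body via data=json.dumps(...), not json=, so the payload must be decoded first:

if "claude" in model:
    sent = json.loads(mock_call.call_args.kwargs["data"])
    assert sent["tool_choice"] == {"name": "json_tool_call", "type": "tool"}
    assert sent["tools"][0]["name"] == "json_tool_call"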
@pytest.mark.parametrize("provider", ["vertex_ai_beta"]) # "vertex_ai",