feat(vertex_ai_anthropic.py): support response_schema for vertex ai anthropic calls
Allows passing `response_schema` for Anthropic calls made via Vertex AI; the returned JSON response can be validated against that schema.
parent f8bdfe7cc3
commit f2401d6d5e
6 changed files with 189 additions and 48 deletions
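
For context, a minimal usage sketch of the feature this commit adds. The model name, region, and schema values mirror the test fixtures below; `enforce_validation` is the flag the new test exercises. This is a hedged sketch, not documentation of a finalized API surface.

```python
# Usage sketch: pass a JSON schema to a Vertex AI Anthropic model via
# response_format; with enforce_validation, a mismatching response raises
# litellm.JSONSchemaValidationError. Values mirror this commit's test fixtures.
import litellm

response_schema = {
    "type": "array",
    "items": {
        "type": "object",
        "properties": {"recipe_name": {"type": "string"}},
        "required": ["recipe_name"],
    },
}

resp = litellm.completion(
    model="vertex_ai/claude-3-5-sonnet@20240620",
    messages=[{"role": "user", "content": "List 5 cookie recipes"}],
    response_format={
        "type": "json_object",
        "response_schema": response_schema,
        "enforce_validation": True,
    },
    vertex_location="us-east5",
)
print(resp.choices[0].message.content)  # JSON string matching the schema
```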
@@ -13,7 +13,12 @@ def validate_schema(schema: dict, response: str):
     from litellm import JSONSchemaValidationError
 
-    response_dict = json.loads(response)
+    try:
+        response_dict = json.loads(response)
+    except json.JSONDecodeError:
+        raise JSONSchemaValidationError(
+            model="", llm_provider="", raw_response=response, schema=response
+        )
 
     try:
         validate(response_dict, schema=schema)
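
A small sketch of the behavior this hunk adds, assuming the validator's litellm module path; previously a non-JSON model response surfaced as a raw `json.JSONDecodeError`:

```python
# Sketch (assumed import path): malformed JSON from the model now surfaces as
# litellm's own JSONSchemaValidationError rather than a bare JSONDecodeError.
from litellm import JSONSchemaValidationError
from litellm.litellm_core_utils.json_validation_rule import validate_schema

schema = {"type": "object", "properties": {"name": {"type": "string"}}}

try:
    validate_schema(schema=schema, response="not-json{")
except JSONSchemaValidationError:
    print("caught JSONSchemaValidationError")
```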
@@ -16,6 +16,7 @@ from litellm import verbose_logger
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
 from litellm.llms.custom_httpx.http_handler import (
     AsyncHTTPHandler,
     HTTPHandler,
     _get_async_httpx_client,
+    _get_httpx_client,
 )
@@ -538,7 +539,7 @@ class AnthropicChatCompletion(BaseLLM):
     def __init__(self) -> None:
         super().__init__()
 
-    def process_response(
+    def _process_response(
         self,
         model: str,
         response: Union[requests.Response, httpx.Response],
@@ -551,6 +552,7 @@ class AnthropicChatCompletion(BaseLLM):
         messages: List,
         print_verbose,
         encoding,
+        json_mode: bool,
     ) -> ModelResponse:
         ## LOGGING
         logging_obj.post_call(
@@ -574,27 +576,40 @@ class AnthropicChatCompletion(BaseLLM):
             )
         else:
             text_content = ""
-            tool_calls = []
-            for content in completion_response["content"]:
+            tool_calls: List[ChatCompletionToolCallChunk] = []
+            for idx, content in enumerate(completion_response["content"]):
                 if content["type"] == "text":
                     text_content += content["text"]
                 ## TOOL CALLING
                 elif content["type"] == "tool_use":
                     tool_calls.append(
-                        {
-                            "id": content["id"],
-                            "type": "function",
-                            "function": {
-                                "name": content["name"],
-                                "arguments": json.dumps(content["input"]),
-                            },
-                        }
+                        ChatCompletionToolCallChunk(
+                            id=content["id"],
+                            type="function",
+                            function=ChatCompletionToolCallFunctionChunk(
+                                name=content["name"],
+                                arguments=json.dumps(content["input"]),
+                            ),
+                            index=idx,
+                        )
                     )
 
             _message = litellm.Message(
                 tool_calls=tool_calls,
                 content=text_content or None,
             )
+
+            ## HANDLE JSON MODE - anthropic returns single function call
+            if json_mode and len(tool_calls) == 1:
+                json_mode_content_str: Optional[str] = tool_calls[0]["function"].get(
+                    "arguments"
+                )
+                if json_mode_content_str is not None:
+                    args = json.loads(json_mode_content_str)
+                    values: Optional[dict] = args.get("values")
+                    if values is not None:
+                        _message = litellm.Message(content=json.dumps(values))
+                        completion_response["stop_reason"] = "stop"
         model_response.choices[0].message = _message  # type: ignore
         model_response._hidden_params["original_response"] = completion_response[
             "content"
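
The JSON-mode branch above unwraps Anthropic's forced tool call back into plain JSON message content. A standalone sketch of that unwrapping (inputs are illustrative):

```python
# Sketch: how a single "json_tool_call" tool call becomes plain JSON content.
import json

# Illustrative tool call, shaped as _process_response sees it after parsing.
tool_calls = [
    {
        "id": "toolu_123",
        "type": "function",
        "function": {
            "name": "json_tool_call",
            "arguments": json.dumps({"values": [{"recipe_name": "Sugar Cookies"}]}),
        },
        "index": 0,
    }
]

json_mode = True
if json_mode and len(tool_calls) == 1:
    arguments = tool_calls[0]["function"].get("arguments")
    if arguments is not None:
        values = json.loads(arguments).get("values")
        if values is not None:
            content = json.dumps(values)
            print(content)  # [{"recipe_name": "Sugar Cookies"}]
```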
@@ -687,9 +702,11 @@ class AnthropicChatCompletion(BaseLLM):
         _is_function_call,
         data: dict,
         optional_params: dict,
+        json_mode: bool,
         litellm_params=None,
         logger_fn=None,
         headers={},
+        client=None,
     ) -> Union[ModelResponse, CustomStreamWrapper]:
         async_handler = _get_async_httpx_client()
 
@@ -705,7 +722,7 @@ class AnthropicChatCompletion(BaseLLM):
             )
             raise e
 
-        return self.process_response(
+        return self._process_response(
             model=model,
             response=response,
             model_response=model_response,
@@ -717,6 +734,7 @@ class AnthropicChatCompletion(BaseLLM):
             print_verbose=print_verbose,
             optional_params=optional_params,
             encoding=encoding,
+            json_mode=json_mode,
         )
 
     def completion(
@@ -731,10 +749,12 @@ class AnthropicChatCompletion(BaseLLM):
         api_key,
         logging_obj,
         optional_params: dict,
+        timeout: Union[float, httpx.Timeout],
         acompletion=None,
         litellm_params=None,
         logger_fn=None,
         headers={},
+        client=None,
     ):
         headers = validate_environment(api_key, headers, model)
         _is_function_call = False
@@ -787,14 +807,18 @@ class AnthropicChatCompletion(BaseLLM):
 
             anthropic_tools = []
             for tool in optional_params["tools"]:
-                new_tool = tool["function"]
-                new_tool["input_schema"] = new_tool.pop("parameters")  # rename key
-                anthropic_tools.append(new_tool)
+                if "input_schema" in tool:  # assume in anthropic format
+                    anthropic_tools.append(tool)
+                else:  # assume openai tool call
+                    new_tool = tool["function"]
+                    new_tool["input_schema"] = new_tool.pop("parameters")  # rename key
+                    anthropic_tools.append(new_tool)
 
             optional_params["tools"] = anthropic_tools
 
         stream = optional_params.pop("stream", None)
         is_vertex_request: bool = optional_params.pop("is_vertex_request", False)
+        json_mode: bool = optional_params.pop("json_mode", False)
 
         data = {
             "messages": messages,
@@ -815,7 +839,7 @@ class AnthropicChatCompletion(BaseLLM):
             },
         )
         print_verbose(f"_is_function_call: {_is_function_call}")
-        if acompletion == True:
+        if acompletion is True:
             if (
                 stream is True
             ):  # if function call - fake the streaming (need complete blocks for output parsing in openai format)
@@ -857,15 +881,21 @@ class AnthropicChatCompletion(BaseLLM):
                 litellm_params=litellm_params,
                 logger_fn=logger_fn,
                 headers=headers,
+                client=client,
+                json_mode=json_mode,
             )
         else:
             ## COMPLETION CALL
+            if client is None or isinstance(client, AsyncHTTPHandler):
+                client = HTTPHandler(timeout=timeout)  # type: ignore
+            else:
+                client = client
             if (
                 stream is True
             ):  # if function call - fake the streaming (need complete blocks for output parsing in openai format)
                 print_verbose("makes anthropic streaming POST request")
                 data["stream"] = stream
-                response = requests.post(
+                response = client.post(
                     api_base,
                     headers=headers,
                     data=json.dumps(data),
@@ -889,15 +919,13 @@ class AnthropicChatCompletion(BaseLLM):
                 return streaming_response
 
             else:
-                response = requests.post(
-                    api_base, headers=headers, data=json.dumps(data)
-                )
+                response = client.post(api_base, headers=headers, data=json.dumps(data))
                 if response.status_code != 200:
                     raise AnthropicError(
                         status_code=response.status_code, message=response.text
                     )
 
-        return self.process_response(
+        return self._process_response(
             model=model,
             response=response,
             model_response=model_response,
@@ -909,6 +937,7 @@ class AnthropicChatCompletion(BaseLLM):
             print_verbose=print_verbose,
             optional_params=optional_params,
             encoding=encoding,
+            json_mode=json_mode,
        )
 
     def embedding(self):
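
The `completion()` change above now accepts tools in either format. A standalone sketch of that normalization, using illustrative tool definitions:

```python
# Sketch: tools carrying "input_schema" are treated as already Anthropic-format;
# OpenAI-style tools get "parameters" renamed to "input_schema".
tools = [
    {"name": "json_tool_call", "input_schema": {"type": "object"}},  # anthropic-style
    {
        "type": "function",
        "function": {"name": "get_weather", "parameters": {"type": "object"}},
    },  # openai-style
]

anthropic_tools = []
for tool in tools:
    if "input_schema" in tool:  # assume in anthropic format
        anthropic_tools.append(tool)
    else:  # assume openai tool call
        new_tool = tool["function"]
        new_tool["input_schema"] = new_tool.pop("parameters")  # rename key
        anthropic_tools.append(new_tool)

print([t["name"] for t in anthropic_tools])  # ['json_tool_call', 'get_weather']
```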
@@ -7,7 +7,7 @@ import time
 import types
 import uuid
 from enum import Enum
-from typing import Any, Callable, List, Optional, Tuple
+from typing import Any, Callable, List, Optional, Tuple, Union
 
 import httpx  # type: ignore
 import requests  # type: ignore
@@ -15,7 +15,14 @@ import requests  # type: ignore
 import litellm
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
-from litellm.types.llms.anthropic import AnthropicMessagesToolChoice
+from litellm.types.llms.anthropic import (
+    AnthropicMessagesTool,
+    AnthropicMessagesToolChoice,
+)
+from litellm.types.llms.openai import (
+    ChatCompletionToolParam,
+    ChatCompletionToolParamFunctionChunk,
+)
+from litellm.types.utils import ResponseFormatChunk
 from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
 
 
@@ -142,7 +149,27 @@ class VertexAIAnthropicConfig:
             if param == "top_p":
                 optional_params["top_p"] = value
+            if param == "response_format" and "response_schema" in value:
+                optional_params["response_format"] = ResponseFormatChunk(**value)  # type: ignore
+                """
+                When using tools in this way: - https://docs.anthropic.com/en/docs/build-with-claude/tool-use#json-mode
+                - You usually want to provide a single tool
+                - You should set tool_choice (see Forcing tool use) to instruct the model to explicitly use that tool
+                - Remember that the model will pass the input to the tool, so the name of the tool and description should be from the model’s perspective.
+                """
+                _tool_choice = None
+                _tool_choice = {"name": "json_tool_call", "type": "tool"}
+
+                _tool = AnthropicMessagesTool(
+                    name="json_tool_call",
+                    input_schema={
+                        "type": "object",
+                        "properties": {"values": value["response_schema"]},  # type: ignore
+                    },
+                )
+
+                optional_params["tools"] = [_tool]
+                optional_params["tool_choice"] = _tool_choice
+                optional_params["json_mode"] = True
 
         return optional_params
 
 
@@ -222,6 +249,7 @@ def completion(
     optional_params: dict,
     custom_prompt_dict: dict,
     headers: Optional[dict],
+    timeout: Union[float, httpx.Timeout],
     vertex_project=None,
     vertex_location=None,
     vertex_credentials=None,
@@ -301,6 +329,8 @@ def completion(
             litellm_params=litellm_params,
             logger_fn=logger_fn,
             headers=vertex_headers,
+            client=client,
+            timeout=timeout,
        )
 
     except Exception as e:
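
Net effect of the `VertexAIAnthropicConfig` change above: a `response_schema` is rewrapped as a single forced tool, following Anthropic's JSON-mode guidance. A sketch of the resulting params, with plain dicts standing in for the typed classes:

```python
# Sketch: what map_openai_params now derives from a response_format that
# carries a response_schema. Plain dicts stand in for AnthropicMessagesTool.
response_schema = {"type": "array", "items": {"type": "object"}}

optional_params = {
    "tools": [
        {
            "name": "json_tool_call",
            "input_schema": {
                "type": "object",
                "properties": {"values": response_schema},
            },
        }
    ],
    "tool_choice": {"name": "json_tool_call", "type": "tool"},
    "json_mode": True,  # tells the Anthropic handler to unwrap "values" later
}
```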
@@ -1528,6 +1528,8 @@ def completion(
                 api_key=api_key,
                 logging_obj=logging,
                 headers=headers,
+                timeout=timeout,
+                client=client,
             )
             if optional_params.get("stream", False) or acompletion == True:
                 ## LOGGING
@@ -2046,7 +2048,10 @@ def completion(
                 acompletion=acompletion,
                 headers=headers,
                 custom_prompt_dict=custom_prompt_dict,
+                timeout=timeout,
+                client=client,
             )
+
         else:
             model_response = vertex_ai.completion(
                 model=model,
@@ -1,5 +1,13 @@
 model_list:
-  - model_name: llama-3
+  - model_name: bad-azure-model
     litellm_params:
-      model: gpt-4
-      request_timeout: 1
+      model: azure/chatgpt-v-2
+      azure_ad_token: ""
+      api_base: os.environ/AZURE_API_BASE
+
+  - model_name: good-openai-model
+    litellm_params:
+      model: gpt-3.5-turbo
+
+litellm_settings:
+  fallbacks: [{"bad-azure-model": ["good-openai-model"]}]
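
The proxy config above pairs a deliberately broken Azure deployment with an OpenAI fallback. A hedged sketch of the same wiring via `litellm.Router` (the `api_base` value is a placeholder):

```python
# Sketch: the fallback behavior encoded by the config above, via litellm.Router.
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "bad-azure-model",
            "litellm_params": {
                "model": "azure/chatgpt-v-2",
                "azure_ad_token": "",
                "api_base": "https://example-resource.openai.azure.com",  # placeholder
            },
        },
        {
            "model_name": "good-openai-model",
            "litellm_params": {"model": "gpt-3.5-turbo"},
        },
    ],
    fallbacks=[{"bad-azure-model": ["good-openai-model"]}],
)

# A failed call to bad-azure-model falls back to good-openai-model.
resp = router.completion(
    model="bad-azure-model",
    messages=[{"role": "user", "content": "hello"}],
)
```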
@@ -1128,6 +1128,39 @@ def vertex_httpx_mock_post_valid_response(*args, **kwargs):
     return mock_response
 
 
+def vertex_httpx_mock_post_valid_response_anthropic(*args, **kwargs):
+    mock_response = MagicMock()
+    mock_response.status_code = 200
+    mock_response.headers = {"Content-Type": "application/json"}
+    mock_response.json.return_value = {
+        "id": "msg_vrtx_013Wki5RFQXAspL7rmxRFjZg",
+        "type": "message",
+        "role": "assistant",
+        "model": "claude-3-5-sonnet-20240620",
+        "content": [
+            {
+                "type": "tool_use",
+                "id": "toolu_vrtx_01YMnYZrToPPfcmY2myP2gEB",
+                "name": "json_tool_call",
+                "input": {
+                    "values": [
+                        {"recipe_name": "Chocolate Chip Cookies"},
+                        {"recipe_name": "Oatmeal Raisin Cookies"},
+                        {"recipe_name": "Peanut Butter Cookies"},
+                        {"recipe_name": "Snickerdoodle Cookies"},
+                        {"recipe_name": "Sugar Cookies"},
+                    ]
+                },
+            }
+        ],
+        "stop_reason": "tool_use",
+        "stop_sequence": None,
+        "usage": {"input_tokens": 368, "output_tokens": 118},
+    }
+
+    return mock_response
+
+
 def vertex_httpx_mock_post_invalid_schema_response(*args, **kwargs):
     mock_response = MagicMock()
     mock_response.status_code = 200
@@ -1183,11 +1216,29 @@ def vertex_httpx_mock_post_invalid_schema_response(*args, **kwargs):
     return mock_response
 
 
+def vertex_httpx_mock_post_invalid_schema_response_anthropic(*args, **kwargs):
+    mock_response = MagicMock()
+    mock_response.status_code = 200
+    mock_response.headers = {"Content-Type": "application/json"}
+    mock_response.json.return_value = {
+        "id": "msg_vrtx_013Wki5RFQXAspL7rmxRFjZg",
+        "type": "message",
+        "role": "assistant",
+        "model": "claude-3-5-sonnet-20240620",
+        "content": [{"text": "Hi! My name is Claude.", "type": "text"}],
+        "stop_reason": "end_turn",
+        "stop_sequence": None,
+        "usage": {"input_tokens": 368, "output_tokens": 118},
+    }
+    return mock_response
+
+
 @pytest.mark.parametrize(
     "model, vertex_location, supports_response_schema",
     [
         ("vertex_ai_beta/gemini-1.5-pro-001", "us-central1", True),
         ("vertex_ai_beta/gemini-1.5-flash", "us-central1", False),
+        ("vertex_ai/claude-3-5-sonnet@20240620", "us-east5", False),
     ],
 )
 @pytest.mark.parametrize(
@@ -1231,12 +1282,21 @@ async def test_gemini_pro_json_schema_args_sent_httpx(
 
     httpx_response = MagicMock()
     if invalid_response is True:
-        httpx_response.side_effect = vertex_httpx_mock_post_invalid_schema_response
+        if "claude" in model:
+            httpx_response.side_effect = (
+                vertex_httpx_mock_post_invalid_schema_response_anthropic
+            )
+        else:
+            httpx_response.side_effect = vertex_httpx_mock_post_invalid_schema_response
     else:
-        httpx_response.side_effect = vertex_httpx_mock_post_valid_response
+        if "claude" in model:
+            httpx_response.side_effect = vertex_httpx_mock_post_valid_response_anthropic
+        else:
+            httpx_response.side_effect = vertex_httpx_mock_post_valid_response
     with patch.object(client, "post", new=httpx_response) as mock_call:
         print("SENDING CLIENT POST={}".format(client.post))
         try:
-            _ = completion(
+            resp = completion(
                 model=model,
                 messages=messages,
                 response_format={
@@ -1247,30 +1307,34 @@ async def test_gemini_pro_json_schema_args_sent_httpx(
                 vertex_location=vertex_location,
                 client=client,
             )
+            print("Received={}".format(resp))
             if invalid_response is True and enforce_validation is True:
                 pytest.fail("Expected this to fail")
         except litellm.JSONSchemaValidationError as e:
-            if invalid_response is False and "claude-3" not in model:
+            if invalid_response is False:
                 pytest.fail("Expected this to pass. Got={}".format(e))
 
         mock_call.assert_called_once()
-        print(mock_call.call_args.kwargs)
-        print(mock_call.call_args.kwargs["json"]["generationConfig"])
+        if "claude" not in model:
+            print(mock_call.call_args.kwargs)
+            print(mock_call.call_args.kwargs["json"]["generationConfig"])
 
-        if supports_response_schema:
-            assert (
-                "response_schema"
-                in mock_call.call_args.kwargs["json"]["generationConfig"]
-            )
-        else:
-            assert (
-                "response_schema"
-                not in mock_call.call_args.kwargs["json"]["generationConfig"]
-            )
-            assert (
-                "Use this JSON schema:"
-                in mock_call.call_args.kwargs["json"]["contents"][0]["parts"][1]["text"]
-            )
+            if supports_response_schema:
+                assert (
+                    "response_schema"
+                    in mock_call.call_args.kwargs["json"]["generationConfig"]
+                )
+            else:
+                assert (
+                    "response_schema"
+                    not in mock_call.call_args.kwargs["json"]["generationConfig"]
+                )
+                assert (
+                    "Use this JSON schema:"
+                    in mock_call.call_args.kwargs["json"]["contents"][0]["parts"][1][
+                        "text"
+                    ]
+                )
 
 
 @pytest.mark.parametrize("provider", ["vertex_ai_beta"])  # "vertex_ai",