feat(vertex_ai_anthropic.py): support response_schema for vertex ai anthropic calls

Allows passing response_schema for Anthropic calls, including Vertex AI Anthropic. Supports schema validation.
Krrish Dholakia 2024-07-18 16:57:38 -07:00
parent f8bdfe7cc3
commit f2401d6d5e
6 changed files with 189 additions and 48 deletions
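
For context, a minimal usage sketch of the feature this commit enables. The model id and schema are placeholders, and the assumption that the schema rides on response_format follows litellm's convention for other Vertex AI models:

import litellm

# Hypothetical JSON schema for the structured output we want back.
response_schema = {
    "type": "object",
    "properties": {"recipe_name": {"type": "string"}},
    "required": ["recipe_name"],
}

# Assumed invocation: response_format carries the schema, as litellm does
# for other Vertex AI models; the model id is a placeholder.
resp = litellm.completion(
    model="vertex_ai/claude-3-5-sonnet@20240620",
    messages=[{"role": "user", "content": "Name one cookie recipe."}],
    response_format={"type": "json_object", "response_schema": response_schema},
)
print(resp.choices[0].message.content)  # JSON string matching the schema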


@@ -16,6 +16,7 @@ from litellm import verbose_logger
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
 from litellm.llms.custom_httpx.http_handler import (
     AsyncHTTPHandler,
+    HTTPHandler,
     _get_async_httpx_client,
     _get_httpx_client,
 )
@@ -538,7 +539,7 @@ class AnthropicChatCompletion(BaseLLM):
     def __init__(self) -> None:
         super().__init__()
 
-    def process_response(
+    def _process_response(
         self,
         model: str,
         response: Union[requests.Response, httpx.Response],
@@ -551,6 +552,7 @@ class AnthropicChatCompletion(BaseLLM):
         messages: List,
         print_verbose,
         encoding,
+        json_mode: bool,
     ) -> ModelResponse:
         ## LOGGING
         logging_obj.post_call(
@@ -574,27 +576,40 @@ class AnthropicChatCompletion(BaseLLM):
             )
         else:
             text_content = ""
-            tool_calls = []
-            for content in completion_response["content"]:
+            tool_calls: List[ChatCompletionToolCallChunk] = []
+            for idx, content in enumerate(completion_response["content"]):
                 if content["type"] == "text":
                     text_content += content["text"]
                 ## TOOL CALLING
                 elif content["type"] == "tool_use":
                     tool_calls.append(
-                        {
-                            "id": content["id"],
-                            "type": "function",
-                            "function": {
-                                "name": content["name"],
-                                "arguments": json.dumps(content["input"]),
-                            },
-                        }
+                        ChatCompletionToolCallChunk(
+                            id=content["id"],
+                            type="function",
+                            function=ChatCompletionToolCallFunctionChunk(
+                                name=content["name"],
+                                arguments=json.dumps(content["input"]),
+                            ),
+                            index=idx,
+                        )
                     )
 
             _message = litellm.Message(
                 tool_calls=tool_calls,
                 content=text_content or None,
             )
+
+            ## HANDLE JSON MODE - anthropic returns single function call
+            if json_mode and len(tool_calls) == 1:
+                json_mode_content_str: Optional[str] = tool_calls[0]["function"].get(
+                    "arguments"
+                )
+                if json_mode_content_str is not None:
+                    args = json.loads(json_mode_content_str)
+                    values: Optional[dict] = args.get("values")
+                    if values is not None:
+                        _message = litellm.Message(content=json.dumps(values))
+                        completion_response["stop_reason"] = "stop"
             model_response.choices[0].message = _message  # type: ignore
         model_response._hidden_params["original_response"] = completion_response[
             "content"
@@ -687,9 +702,11 @@ class AnthropicChatCompletion(BaseLLM):
         _is_function_call,
         data: dict,
         optional_params: dict,
+        json_mode: bool,
         litellm_params=None,
         logger_fn=None,
         headers={},
+        client=None,
     ) -> Union[ModelResponse, CustomStreamWrapper]:
         async_handler = _get_async_httpx_client()
 
@@ -705,7 +722,7 @@ class AnthropicChatCompletion(BaseLLM):
             )
             raise e
 
-        return self.process_response(
+        return self._process_response(
             model=model,
             response=response,
             model_response=model_response,
@@ -717,6 +734,7 @@ class AnthropicChatCompletion(BaseLLM):
             print_verbose=print_verbose,
             optional_params=optional_params,
             encoding=encoding,
+            json_mode=json_mode,
         )
 
     def completion(
@@ -731,10 +749,12 @@ class AnthropicChatCompletion(BaseLLM):
         api_key,
         logging_obj,
         optional_params: dict,
+        timeout: Union[float, httpx.Timeout],
         acompletion=None,
         litellm_params=None,
         logger_fn=None,
         headers={},
+        client=None,
     ):
         headers = validate_environment(api_key, headers, model)
         _is_function_call = False
@@ -787,14 +807,18 @@ class AnthropicChatCompletion(BaseLLM):
             anthropic_tools = []
             for tool in optional_params["tools"]:
-                new_tool = tool["function"]
-                new_tool["input_schema"] = new_tool.pop("parameters")  # rename key
-                anthropic_tools.append(new_tool)
+                if "input_schema" in tool:  # assume in anthropic format
+                    anthropic_tools.append(tool)
+                else:  # assume openai tool call
+                    new_tool = tool["function"]
+                    new_tool["input_schema"] = new_tool.pop("parameters")  # rename key
+                    anthropic_tools.append(new_tool)
             optional_params["tools"] = anthropic_tools
 
         stream = optional_params.pop("stream", None)
         is_vertex_request: bool = optional_params.pop("is_vertex_request", False)
+        json_mode: bool = optional_params.pop("json_mode", False)
 
         data = {
             "messages": messages,
@@ -815,7 +839,7 @@ class AnthropicChatCompletion(BaseLLM):
             },
         )
         print_verbose(f"_is_function_call: {_is_function_call}")
-        if acompletion == True:
+        if acompletion is True:
             if (
                 stream is True
             ):  # if function call - fake the streaming (need complete blocks for output parsing in openai format)
@@ -857,15 +881,21 @@ class AnthropicChatCompletion(BaseLLM):
                 litellm_params=litellm_params,
                 logger_fn=logger_fn,
                 headers=headers,
+                client=client,
+                json_mode=json_mode,
             )
         else:
             ## COMPLETION CALL
+            if client is None or isinstance(client, AsyncHTTPHandler):
+                client = HTTPHandler(timeout=timeout)  # type: ignore
+            else:
+                client = client
             if (
                 stream is True
             ):  # if function call - fake the streaming (need complete blocks for output parsing in openai format)
                 print_verbose("makes anthropic streaming POST request")
                 data["stream"] = stream
-                response = requests.post(
+                response = client.post(
                     api_base,
                     headers=headers,
                     data=json.dumps(data),
@@ -889,15 +919,13 @@ class AnthropicChatCompletion(BaseLLM):
                 return streaming_response
 
             else:
-                response = requests.post(
-                    api_base, headers=headers, data=json.dumps(data)
-                )
+                response = client.post(api_base, headers=headers, data=json.dumps(data))
                 if response.status_code != 200:
                     raise AnthropicError(
                         status_code=response.status_code, message=response.text
                     )
 
-        return self.process_response(
+        return self._process_response(
             model=model,
             response=response,
             model_response=model_response,
@@ -909,6 +937,7 @@ class AnthropicChatCompletion(BaseLLM):
             print_verbose=print_verbose,
             optional_params=optional_params,
             encoding=encoding,
+            json_mode=json_mode,
         )
 
     def embedding(self):
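
The JSON-mode branch added to _process_response assumes Anthropic was forced into a single tool call whose arguments wrap the JSON payload under a "values" key; it unwraps that payload back into plain message content and rewrites the stop reason. A standalone replica of that unwrapping (the tool id and name are placeholders):

import json

# A tool_use block as Anthropic might return it in forced-tool JSON mode.
raw_tool_use = {
    "type": "tool_use",
    "id": "toolu_123",
    "name": "json_tool_call",
    "input": {"values": {"recipe_name": "Snickerdoodle"}},
}

# _process_response serializes the tool input into an arguments string...
arguments = json.dumps(raw_tool_use["input"])

# ...and in JSON mode parses it back and lifts out the "values" wrapper.
args = json.loads(arguments)
values = args.get("values")
if values is not None:
    message_content = json.dumps(values)  # becomes the Message content
    stop_reason = "stop"  # overrides Anthropic's "tool_use" stop reason

print(message_content)  # {"recipe_name": "Snickerdoodle"}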