feat(vertex_ai_anthropic.py): support response_schema for vertex ai anthropic calls
Allows passing response_schema for anthropic calls. Supports schema validation.
parent f8bdfe7cc3
commit f2401d6d5e
6 changed files with 189 additions and 48 deletions
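Only one of the six changed files is excerpted below; judging by the `AnthropicChatCompletion` class, the hunks are from litellm's Anthropic chat handler. For context, here is a sketch of the caller-side usage this commit enables. The model string, the schema, and the `response_format` shape are illustrative assumptions based on litellm's conventions, not taken from the diff:

```python
# Hedged sketch: passing a response_schema to a Vertex AI Anthropic model.
# Model name and schema are illustrative; the response_format shape follows
# litellm's conventions, not this diff.
from litellm import completion

response_schema = {
    "type": "object",
    "properties": {"recipe_name": {"type": "string"}},
    "required": ["recipe_name"],
}

response = completion(
    model="vertex_ai/claude-3-5-sonnet@20240620",  # assumed model name
    messages=[{"role": "user", "content": "Name one pasta recipe as JSON."}],
    response_format={
        "type": "json_object",
        "response_schema": response_schema,
    },
)
print(response.choices[0].message.content)  # a JSON string shaped by the schema
```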
```diff
@@ -16,6 +16,7 @@ from litellm import verbose_logger
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
 from litellm.llms.custom_httpx.http_handler import (
     AsyncHTTPHandler,
+    HTTPHandler,
     _get_async_httpx_client,
     _get_httpx_client,
 )
```
```diff
@@ -538,7 +539,7 @@ class AnthropicChatCompletion(BaseLLM):
     def __init__(self) -> None:
         super().__init__()
 
-    def process_response(
+    def _process_response(
         self,
         model: str,
         response: Union[requests.Response, httpx.Response],
```
```diff
@@ -551,6 +552,7 @@ class AnthropicChatCompletion(BaseLLM):
         messages: List,
         print_verbose,
         encoding,
+        json_mode: bool,
     ) -> ModelResponse:
         ## LOGGING
         logging_obj.post_call(
```
```diff
@@ -574,27 +576,40 @@ class AnthropicChatCompletion(BaseLLM):
             )
         else:
             text_content = ""
-            tool_calls = []
-            for content in completion_response["content"]:
+            tool_calls: List[ChatCompletionToolCallChunk] = []
+            for idx, content in enumerate(completion_response["content"]):
                 if content["type"] == "text":
                     text_content += content["text"]
                 ## TOOL CALLING
                 elif content["type"] == "tool_use":
                     tool_calls.append(
-                        {
-                            "id": content["id"],
-                            "type": "function",
-                            "function": {
-                                "name": content["name"],
-                                "arguments": json.dumps(content["input"]),
-                            },
-                        }
+                        ChatCompletionToolCallChunk(
+                            id=content["id"],
+                            type="function",
+                            function=ChatCompletionToolCallFunctionChunk(
+                                name=content["name"],
+                                arguments=json.dumps(content["input"]),
+                            ),
+                            index=idx,
+                        )
                     )
 
             _message = litellm.Message(
                 tool_calls=tool_calls,
                 content=text_content or None,
             )
+
+            ## HANDLE JSON MODE - anthropic returns single function call
+            if json_mode and len(tool_calls) == 1:
+                json_mode_content_str: Optional[str] = tool_calls[0]["function"].get(
+                    "arguments"
+                )
+                if json_mode_content_str is not None:
+                    args = json.loads(json_mode_content_str)
+                    values: Optional[dict] = args.get("values")
+                    if values is not None:
+                        _message = litellm.Message(content=json.dumps(values))
+                        completion_response["stop_reason"] = "stop"
         model_response.choices[0].message = _message  # type: ignore
         model_response._hidden_params["original_response"] = completion_response[
             "content"
```
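The `## HANDLE JSON MODE` branch above assumes the request side wrapped the caller's schema in a single tool whose arguments carry the payload under a `values` key; when exactly one tool call comes back, its arguments are unwrapped into plain message content. A standalone sketch of that unwrap (the helper name is mine; the logic mirrors the added block):

```python
import json
from typing import Optional


def unwrap_json_mode_tool_call(tool_calls: list) -> Optional[str]:
    """Mirror the added JSON-mode handling: Anthropic returns the JSON-mode
    payload as a single tool call whose arguments hold a "values" object."""
    if len(tool_calls) != 1:
        return None
    arguments: Optional[str] = tool_calls[0]["function"].get("arguments")
    if arguments is None:
        return None
    values = json.loads(arguments).get("values")
    if values is None:
        return None
    return json.dumps(values)  # becomes the message content in json_mode


# A tool call with arguments '{"values": {"recipe_name": "carbonara"}}'
# unwraps to the content string '{"recipe_name": "carbonara"}'.
```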
```diff
@@ -687,9 +702,11 @@ class AnthropicChatCompletion(BaseLLM):
         _is_function_call,
         data: dict,
         optional_params: dict,
+        json_mode: bool,
         litellm_params=None,
         logger_fn=None,
         headers={},
+        client=None,
     ) -> Union[ModelResponse, CustomStreamWrapper]:
         async_handler = _get_async_httpx_client()
 
```
```diff
@@ -705,7 +722,7 @@ class AnthropicChatCompletion(BaseLLM):
             )
             raise e
 
-        return self.process_response(
+        return self._process_response(
             model=model,
             response=response,
             model_response=model_response,
```
```diff
@@ -717,6 +734,7 @@ class AnthropicChatCompletion(BaseLLM):
             print_verbose=print_verbose,
             optional_params=optional_params,
             encoding=encoding,
+            json_mode=json_mode,
         )
 
     def completion(
```
```diff
@@ -731,10 +749,12 @@ class AnthropicChatCompletion(BaseLLM):
         api_key,
         logging_obj,
         optional_params: dict,
+        timeout: Union[float, httpx.Timeout],
         acompletion=None,
         litellm_params=None,
         logger_fn=None,
         headers={},
+        client=None,
     ):
         headers = validate_environment(api_key, headers, model)
         _is_function_call = False
```
```diff
@@ -787,14 +807,18 @@ class AnthropicChatCompletion(BaseLLM):
 
             anthropic_tools = []
             for tool in optional_params["tools"]:
-                new_tool = tool["function"]
-                new_tool["input_schema"] = new_tool.pop("parameters")  # rename key
-                anthropic_tools.append(new_tool)
+                if "input_schema" in tool:  # assume in anthropic format
+                    anthropic_tools.append(tool)
+                else:  # assume openai tool call
+                    new_tool = tool["function"]
+                    new_tool["input_schema"] = new_tool.pop("parameters")  # rename key
+                    anthropic_tools.append(new_tool)
 
             optional_params["tools"] = anthropic_tools
 
         stream = optional_params.pop("stream", None)
+        is_vertex_request: bool = optional_params.pop("is_vertex_request", False)
+        json_mode: bool = optional_params.pop("json_mode", False)
 
         data = {
             "messages": messages,
```
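The rewritten tool loop now passes through tools that are already in Anthropic's format (detected by an `input_schema` key) and only renames `parameters` for OpenAI-style entries. Pulled out as a standalone helper (the function name is mine), the normalization reads:

```python
def to_anthropic_tools(tools: list) -> list:
    """Accept a mix of Anthropic-format and OpenAI-format tool definitions."""
    anthropic_tools = []
    for tool in tools:
        if "input_schema" in tool:
            # Already Anthropic format - pass through untouched.
            anthropic_tools.append(tool)
        else:
            # OpenAI format: {"type": "function", "function": {...}} -
            # hoist the function dict and rename parameters -> input_schema.
            new_tool = tool["function"]
            new_tool["input_schema"] = new_tool.pop("parameters")
            anthropic_tools.append(new_tool)
    return anthropic_tools
```

Presumably this is what lets the vertex path hand in a pre-built Anthropic tool (such as a json-mode tool carrying a response_schema) without it being mangled by the OpenAI rename.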
```diff
@@ -815,7 +839,7 @@ class AnthropicChatCompletion(BaseLLM):
             },
         )
         print_verbose(f"_is_function_call: {_is_function_call}")
-        if acompletion == True:
+        if acompletion is True:
             if (
                 stream is True
             ):  # if function call - fake the streaming (need complete blocks for output parsing in openai format)
```
```diff
@@ -857,15 +881,21 @@ class AnthropicChatCompletion(BaseLLM):
                 litellm_params=litellm_params,
                 logger_fn=logger_fn,
                 headers=headers,
+                client=client,
+                json_mode=json_mode,
             )
         else:
             ## COMPLETION CALL
+            if client is None or isinstance(client, AsyncHTTPHandler):
+                client = HTTPHandler(timeout=timeout)  # type: ignore
+            else:
+                client = client
             if (
                 stream is True
             ):  # if function call - fake the streaming (need complete blocks for output parsing in openai format)
                 print_verbose("makes anthropic streaming POST request")
                 data["stream"] = stream
-                response = requests.post(
+                response = client.post(
                     api_base,
                     headers=headers,
                     data=json.dumps(data),
```
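The new guard before the sync POST exists because `client` may arrive as litellm's async handler, which cannot serve a blocking call; anything unusable is replaced with a fresh `HTTPHandler`. Reduced to a function (a sketch, assuming litellm's handler types; the function name is mine):

```python
from typing import Optional, Union

import httpx
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler


def resolve_sync_client(
    client: Optional[Union[HTTPHandler, AsyncHTTPHandler]],
    timeout: Union[float, httpx.Timeout],
) -> HTTPHandler:
    # An AsyncHTTPHandler cannot serve a blocking .post(), so fall back to a
    # fresh sync handler when none (or the wrong kind) was passed in.
    if client is None or isinstance(client, AsyncHTTPHandler):
        return HTTPHandler(timeout=timeout)
    return client
```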
```diff
@@ -889,15 +919,13 @@ class AnthropicChatCompletion(BaseLLM):
                 return streaming_response
 
         else:
-            response = requests.post(
-                api_base, headers=headers, data=json.dumps(data)
-            )
+            response = client.post(api_base, headers=headers, data=json.dumps(data))
             if response.status_code != 200:
                 raise AnthropicError(
                     status_code=response.status_code, message=response.text
                 )
 
-        return self.process_response(
+        return self._process_response(
             model=model,
             response=response,
             model_response=model_response,
```
```diff
@@ -909,6 +937,7 @@ class AnthropicChatCompletion(BaseLLM):
             print_verbose=print_verbose,
             optional_params=optional_params,
             encoding=encoding,
+            json_mode=json_mode,
         )
 
     def embedding(self):
```
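None of the hunks above show the vertex_ai_anthropic.py side of the commit, but the `values` unwrapping in `_process_response` implies its shape: json mode is emulated by sending the caller's response_schema as a forced tool whose only argument is `values`. A hedged sketch of that request-side wrapper, inferred from the response handling rather than copied from the diff:

```python
def build_json_mode_tool(response_schema: dict) -> dict:
    """Hypothetical helper: wrap a response_schema as an Anthropic tool so the
    model's tool arguments come back as {"values": <schema-shaped payload>}."""
    return {
        "name": "json_tool_call",  # assumed name; any fixed identifier works
        "input_schema": {
            "type": "object",
            "properties": {"values": response_schema},
        },
    }
```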