Move tool definitions from system prompt to parameter and refactor tool calling parse

This commit is contained in:
Zihao Li 2024-04-05 16:01:40 +08:00
parent b0d80de14d
commit d2cf9d2cf1
2 changed files with 38 additions and 45 deletions

View file

@ -118,7 +118,6 @@ def completion(
):
headers = validate_environment(api_key, headers)
_is_function_call = False
json_schemas: dict = {}
messages = copy.deepcopy(messages)
optional_params = copy.deepcopy(optional_params)
if model in custom_prompt_dict:
@ -162,17 +161,15 @@ def completion(
## Handle Tool Calling
if "tools" in optional_params:
_is_function_call = True
headers["anthropic-beta"] = "tools-2024-04-04"
anthropic_tools = []
for tool in optional_params["tools"]:
json_schemas[tool["function"]["name"]] = tool["function"].get(
"parameters", None
)
tool_calling_system_prompt = construct_tool_use_system_prompt(
tools=optional_params["tools"]
)
optional_params["system"] = (
optional_params.get("system", "\n") + tool_calling_system_prompt
) # add the anthropic tool calling prompt to the system prompt
optional_params.pop("tools")
new_tool = tool["function"]
new_tool["input_schema"] = new_tool.pop("parameters") # rename key
anthropic_tools.append(new_tool)
optional_params["tools"] = anthropic_tools
stream = optional_params.pop("stream", None)
@ -195,9 +192,9 @@ def completion(
print_verbose(f"_is_function_call: {_is_function_call}")
## COMPLETION CALL
if (
stream is not None and stream == True and _is_function_call == False
stream and not _is_function_call
): # if function call - fake the streaming (need complete blocks for output parsing in openai format)
print_verbose(f"makes anthropic streaming POST request")
print_verbose("makes anthropic streaming POST request")
data["stream"] = stream
response = requests.post(
api_base,
@ -245,46 +242,40 @@ def completion(
status_code=response.status_code,
)
else:
text_content = completion_response["content"][0].get("text", None)
## TOOL CALLING - OUTPUT PARSE
if text_content is not None and contains_tag("invoke", text_content):
function_name = extract_between_tags("tool_name", text_content)[0]
function_arguments_str = extract_between_tags("invoke", text_content)[
0
].strip()
function_arguments_str = f"<invoke>{function_arguments_str}</invoke>"
function_arguments = parse_xml_params(
function_arguments_str,
json_schema=json_schemas.get(
function_name, None
), # check if we have a json schema for this function name
)
_message = litellm.Message(
tool_calls=[
text_content = ""
tool_calls = []
for content in completion_response["content"]:
if content["type"] == "text":
text_content += content["text"]
## TOOL CALLING
elif content["type"] == "tool_use":
tool_calls.append(
{
"id": f"call_{uuid.uuid4()}",
"id": content["id"],
"type": "function",
"function": {
"name": function_name,
"arguments": json.dumps(function_arguments),
"name": content["name"],
"arguments": json.dumps(content["input"]),
},
}
],
content=None,
)
_message = litellm.Message(
tool_calls=tool_calls,
content=text_content or None,
)
model_response.choices[0].message = _message # type: ignore
model_response._hidden_params["original_response"] = (
text_content # allow user to access raw anthropic tool calling response
)
else:
model_response.choices[0].message.content = text_content # type: ignore
model_response._hidden_params["original_response"] = completion_response[
"content"
] # allow user to access raw anthropic tool calling response
model_response.choices[0].finish_reason = map_finish_reason(
completion_response["stop_reason"]
)
print_verbose(f"_is_function_call: {_is_function_call}; stream: {stream}")
if _is_function_call == True and stream is not None and stream == True:
print_verbose(f"INSIDE ANTHROPIC STREAMING TOOL CALLING CONDITION BLOCK")
if _is_function_call and stream:
print_verbose("INSIDE ANTHROPIC STREAMING TOOL CALLING CONDITION BLOCK")
# return an iterator
streaming_model_response = ModelResponse(stream=True)
streaming_model_response.choices[0].finish_reason = model_response.choices[
@ -318,7 +309,7 @@ def completion(
model_response=streaming_model_response
)
print_verbose(
f"Returns anthropic CustomStreamWrapper with 'cached_response' streaming object"
"Returns anthropic CustomStreamWrapper with 'cached_response' streaming object"
)
return CustomStreamWrapper(
completion_stream=completion_stream,
@ -337,7 +328,7 @@ def completion(
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
total_tokens=total_tokens,
)
model_response.usage = usage
return model_response

View file

@ -207,6 +207,8 @@ def map_finish_reason(
return "stop"
elif finish_reason == "max_tokens": # anthropic
return "length"
elif finish_reason == "tool_use": # anthropic
return "tool_calls"
return finish_reason