Move tool definitions from the system prompt to the `tools` request parameter and refactor tool-call response parsing

This commit is contained in:
Zihao Li 2024-04-05 16:01:40 +08:00
parent b0d80de14d
commit d2cf9d2cf1
2 changed files with 38 additions and 45 deletions

View file

@ -118,7 +118,6 @@ def completion(
): ):
headers = validate_environment(api_key, headers) headers = validate_environment(api_key, headers)
_is_function_call = False _is_function_call = False
json_schemas: dict = {}
messages = copy.deepcopy(messages) messages = copy.deepcopy(messages)
optional_params = copy.deepcopy(optional_params) optional_params = copy.deepcopy(optional_params)
if model in custom_prompt_dict: if model in custom_prompt_dict:
@ -162,17 +161,15 @@ def completion(
## Handle Tool Calling ## Handle Tool Calling
if "tools" in optional_params: if "tools" in optional_params:
_is_function_call = True _is_function_call = True
headers["anthropic-beta"] = "tools-2024-04-04"
anthropic_tools = []
for tool in optional_params["tools"]: for tool in optional_params["tools"]:
json_schemas[tool["function"]["name"]] = tool["function"].get( new_tool = tool["function"]
"parameters", None new_tool["input_schema"] = new_tool.pop("parameters") # rename key
) anthropic_tools.append(new_tool)
tool_calling_system_prompt = construct_tool_use_system_prompt(
tools=optional_params["tools"] optional_params["tools"] = anthropic_tools
)
optional_params["system"] = (
optional_params.get("system", "\n") + tool_calling_system_prompt
) # add the anthropic tool calling prompt to the system prompt
optional_params.pop("tools")
stream = optional_params.pop("stream", None) stream = optional_params.pop("stream", None)
@ -195,9 +192,9 @@ def completion(
print_verbose(f"_is_function_call: {_is_function_call}") print_verbose(f"_is_function_call: {_is_function_call}")
## COMPLETION CALL ## COMPLETION CALL
if ( if (
stream is not None and stream == True and _is_function_call == False stream and not _is_function_call
): # if function call - fake the streaming (need complete blocks for output parsing in openai format) ): # if function call - fake the streaming (need complete blocks for output parsing in openai format)
print_verbose(f"makes anthropic streaming POST request") print_verbose("makes anthropic streaming POST request")
data["stream"] = stream data["stream"] = stream
response = requests.post( response = requests.post(
api_base, api_base,
@ -245,46 +242,40 @@ def completion(
status_code=response.status_code, status_code=response.status_code,
) )
else: else:
text_content = completion_response["content"][0].get("text", None) text_content = ""
## TOOL CALLING - OUTPUT PARSE tool_calls = []
if text_content is not None and contains_tag("invoke", text_content): for content in completion_response["content"]:
function_name = extract_between_tags("tool_name", text_content)[0] if content["type"] == "text":
function_arguments_str = extract_between_tags("invoke", text_content)[ text_content += content["text"]
0 ## TOOL CALLING
].strip() elif content["type"] == "tool_use":
function_arguments_str = f"<invoke>{function_arguments_str}</invoke>" tool_calls.append(
function_arguments = parse_xml_params(
function_arguments_str,
json_schema=json_schemas.get(
function_name, None
), # check if we have a json schema for this function name
)
_message = litellm.Message(
tool_calls=[
{ {
"id": f"call_{uuid.uuid4()}", "id": content["id"],
"type": "function", "type": "function",
"function": { "function": {
"name": function_name, "name": content["name"],
"arguments": json.dumps(function_arguments), "arguments": json.dumps(content["input"]),
}, },
} }
], )
content=None,
) _message = litellm.Message(
model_response.choices[0].message = _message # type: ignore tool_calls=tool_calls,
model_response._hidden_params["original_response"] = ( content=text_content or None,
text_content # allow user to access raw anthropic tool calling response )
) model_response.choices[0].message = _message # type: ignore
else: model_response._hidden_params["original_response"] = completion_response[
model_response.choices[0].message.content = text_content # type: ignore "content"
] # allow user to access raw anthropic tool calling response
model_response.choices[0].finish_reason = map_finish_reason( model_response.choices[0].finish_reason = map_finish_reason(
completion_response["stop_reason"] completion_response["stop_reason"]
) )
print_verbose(f"_is_function_call: {_is_function_call}; stream: {stream}") print_verbose(f"_is_function_call: {_is_function_call}; stream: {stream}")
if _is_function_call == True and stream is not None and stream == True: if _is_function_call and stream:
print_verbose(f"INSIDE ANTHROPIC STREAMING TOOL CALLING CONDITION BLOCK") print_verbose("INSIDE ANTHROPIC STREAMING TOOL CALLING CONDITION BLOCK")
# return an iterator # return an iterator
streaming_model_response = ModelResponse(stream=True) streaming_model_response = ModelResponse(stream=True)
streaming_model_response.choices[0].finish_reason = model_response.choices[ streaming_model_response.choices[0].finish_reason = model_response.choices[
@ -318,7 +309,7 @@ def completion(
model_response=streaming_model_response model_response=streaming_model_response
) )
print_verbose( print_verbose(
f"Returns anthropic CustomStreamWrapper with 'cached_response' streaming object" "Returns anthropic CustomStreamWrapper with 'cached_response' streaming object"
) )
return CustomStreamWrapper( return CustomStreamWrapper(
completion_stream=completion_stream, completion_stream=completion_stream,
@ -337,7 +328,7 @@ def completion(
usage = Usage( usage = Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens, total_tokens=total_tokens,
) )
model_response.usage = usage model_response.usage = usage
return model_response return model_response

View file

@ -207,6 +207,8 @@ def map_finish_reason(
return "stop" return "stop"
elif finish_reason == "max_tokens": # anthropic elif finish_reason == "max_tokens": # anthropic
return "length" return "length"
elif finish_reason == "tool_use": # anthropic
return "tool_calls"
return finish_reason return finish_reason