Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 11:14:04 +00:00)
Move tool definitions from system prompt to parameter and refactor tool calling parse
Commit d2cf9d2cf1 (parent b0d80de14d)
2 changed files with 38 additions and 45 deletions
Changed file 1 — the Anthropic `completion()` handler:

```diff
@@ -118,7 +118,6 @@ def completion(
 ):
     headers = validate_environment(api_key, headers)
     _is_function_call = False
-    json_schemas: dict = {}
     messages = copy.deepcopy(messages)
     optional_params = copy.deepcopy(optional_params)
     if model in custom_prompt_dict:
```
```diff
@@ -162,17 +161,15 @@ def completion(
     ## Handle Tool Calling
     if "tools" in optional_params:
         _is_function_call = True
+        headers["anthropic-beta"] = "tools-2024-04-04"
+
+        anthropic_tools = []
         for tool in optional_params["tools"]:
-            json_schemas[tool["function"]["name"]] = tool["function"].get(
-                "parameters", None
-            )
-        tool_calling_system_prompt = construct_tool_use_system_prompt(
-            tools=optional_params["tools"]
-        )
-        optional_params["system"] = (
-            optional_params.get("system", "\n") + tool_calling_system_prompt
-        )  # add the anthropic tool calling prompt to the system prompt
-        optional_params.pop("tools")
+            new_tool = tool["function"]
+            new_tool["input_schema"] = new_tool.pop("parameters")  # rename key
+            anthropic_tools.append(new_tool)
+
+        optional_params["tools"] = anthropic_tools
 
     stream = optional_params.pop("stream", None)
 
```
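The hunk above drops the XML system-prompt workaround (`construct_tool_use_system_prompt`, and with it the `json_schemas` bookkeeping removed in the first hunk) in favor of Anthropic's native `tools` parameter, gated by the `tools-2024-04-04` beta header. Each OpenAI-style tool definition only needs its `parameters` key renamed to `input_schema`. A minimal standalone sketch of that translation (the weather tool is an invented example):

```python
# OpenAI-style tool definition, as passed to litellm via `tools`.
# The tool itself ("get_weather") is an invented example.
openai_tool = {
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Get the current weather for a city",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}

# Same translation the diff performs (the diff mutates in place; copied
# here to keep the original intact): take the inner "function" dict and
# rename "parameters" to Anthropic's "input_schema".
anthropic_tool = dict(openai_tool["function"])
anthropic_tool["input_schema"] = anthropic_tool.pop("parameters")

assert "parameters" not in anthropic_tool
assert anthropic_tool["input_schema"]["required"] == ["city"]
```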
```diff
@@ -195,9 +192,9 @@ def completion(
     print_verbose(f"_is_function_call: {_is_function_call}")
     ## COMPLETION CALL
     if (
-        stream is not None and stream == True and _is_function_call == False
+        stream and not _is_function_call
     ):  # if function call - fake the streaming (need complete blocks for output parsing in openai format)
-        print_verbose(f"makes anthropic streaming POST request")
+        print_verbose("makes anthropic streaming POST request")
         data["stream"] = stream
         response = requests.post(
             api_base,
```
```diff
@@ -245,46 +242,40 @@ def completion(
                 status_code=response.status_code,
             )
         else:
-            text_content = completion_response["content"][0].get("text", None)
-            ## TOOL CALLING - OUTPUT PARSE
-            if text_content is not None and contains_tag("invoke", text_content):
-                function_name = extract_between_tags("tool_name", text_content)[0]
-                function_arguments_str = extract_between_tags("invoke", text_content)[
-                    0
-                ].strip()
-                function_arguments_str = f"<invoke>{function_arguments_str}</invoke>"
-                function_arguments = parse_xml_params(
-                    function_arguments_str,
-                    json_schema=json_schemas.get(
-                        function_name, None
-                    ),  # check if we have a json schema for this function name
-                )
-                _message = litellm.Message(
-                    tool_calls=[
+            text_content = ""
+            tool_calls = []
+            for content in completion_response["content"]:
+                if content["type"] == "text":
+                    text_content += content["text"]
+                ## TOOL CALLING
+                elif content["type"] == "tool_use":
+                    tool_calls.append(
                         {
-                            "id": f"call_{uuid.uuid4()}",
+                            "id": content["id"],
                             "type": "function",
                             "function": {
-                                "name": function_name,
-                                "arguments": json.dumps(function_arguments),
+                                "name": content["name"],
+                                "arguments": json.dumps(content["input"]),
                             },
                         }
-                    ],
-                    content=None,
-                )
-                model_response.choices[0].message = _message  # type: ignore
-                model_response._hidden_params["original_response"] = (
-                    text_content  # allow user to access raw anthropic tool calling response
-                )
-            else:
-                model_response.choices[0].message.content = text_content  # type: ignore
+                    )
+
+            _message = litellm.Message(
+                tool_calls=tool_calls,
+                content=text_content or None,
+            )
+            model_response.choices[0].message = _message  # type: ignore
+            model_response._hidden_params["original_response"] = completion_response[
+                "content"
+            ]  # allow user to access raw anthropic tool calling response
+
             model_response.choices[0].finish_reason = map_finish_reason(
                 completion_response["stop_reason"]
             )
 
         print_verbose(f"_is_function_call: {_is_function_call}; stream: {stream}")
-        if _is_function_call == True and stream is not None and stream == True:
-            print_verbose(f"INSIDE ANTHROPIC STREAMING TOOL CALLING CONDITION BLOCK")
+        if _is_function_call and stream:
+            print_verbose("INSIDE ANTHROPIC STREAMING TOOL CALLING CONDITION BLOCK")
             # return an iterator
             streaming_model_response = ModelResponse(stream=True)
             streaming_model_response.choices[0].finish_reason = model_response.choices[
```
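On the parsing side, the old code scraped `<invoke>` tags out of the model's text with `extract_between_tags` and `parse_xml_params`; the new code walks the structured `content` blocks Anthropic returns and converts each `tool_use` block into an OpenAI-format tool call, reusing the API-provided tool-call `id` instead of minting a `call_{uuid}`. A self-contained sketch of that loop over an invented sample response:

```python
import json

# Invented sample of an Anthropic messages response body with one text
# block and one tool_use block, shaped the way the diff expects.
completion_response = {
    "content": [
        {"type": "text", "text": "Let me check the weather."},
        {
            "type": "tool_use",
            "id": "toolu_01A",
            "name": "get_weather",
            "input": {"city": "Paris"},
        },
    ],
    "stop_reason": "tool_use",
}

text_content = ""
tool_calls = []
for content in completion_response["content"]:
    if content["type"] == "text":
        text_content += content["text"]
    elif content["type"] == "tool_use":
        # Re-shape the block into OpenAI's tool_call format.
        tool_calls.append(
            {
                "id": content["id"],
                "type": "function",
                "function": {
                    "name": content["name"],
                    "arguments": json.dumps(content["input"]),
                },
            }
        )

print(text_content)  # "Let me check the weather."
print(tool_calls[0]["function"]["arguments"])  # '{"city": "Paris"}'
```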
```diff
@@ -318,7 +309,7 @@ def completion(
                 model_response=streaming_model_response
             )
             print_verbose(
-                f"Returns anthropic CustomStreamWrapper with 'cached_response' streaming object"
+                "Returns anthropic CustomStreamWrapper with 'cached_response' streaming object"
             )
             return CustomStreamWrapper(
                 completion_stream=completion_stream,
```
```diff
@@ -337,7 +328,7 @@ def completion(
         usage = Usage(
             prompt_tokens=prompt_tokens,
             completion_tokens=completion_tokens,
-            total_tokens=prompt_tokens + completion_tokens,
+            total_tokens=total_tokens,
         )
         model_response.usage = usage
         return model_response
```
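The `Usage` block now takes a precomputed `total_tokens` rather than re-adding the two counts at the call site. The surrounding lines are not part of this diff, but a plausible derivation, assuming the counts come from Anthropic's reported `usage` block, would look like:

```python
# Hypothetical sketch; the diff does not show where these variables are set.
usage_block = completion_response.get("usage", {})
prompt_tokens = usage_block.get("input_tokens", 0)
completion_tokens = usage_block.get("output_tokens", 0)
total_tokens = prompt_tokens + completion_tokens
```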
Changed file 2 — the shared `map_finish_reason()` helper:

```diff
@@ -207,6 +207,8 @@ def map_finish_reason(
         return "stop"
     elif finish_reason == "max_tokens":  # anthropic
         return "length"
+    elif finish_reason == "tool_use":  # anthropic
+        return "tool_calls"
     return finish_reason
```
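With the new branch, Anthropic's `tool_use` stop reason surfaces to callers as OpenAI's `tool_calls` finish reason, matching the values OpenAI clients already handle. A condensed illustration of the mapping (the real helper covers many more providers and reasons; the `stop_sequence` branch is assumed from context):

```python
def map_finish_reason(finish_reason: str) -> str:
    # Condensed sketch; only the Anthropic-relevant branches are shown.
    if finish_reason == "stop_sequence":  # anthropic
        return "stop"
    elif finish_reason == "max_tokens":  # anthropic
        return "length"
    elif finish_reason == "tool_use":  # anthropic
        return "tool_calls"
    return finish_reason

assert map_finish_reason("tool_use") == "tool_calls"
assert map_finish_reason("max_tokens") == "length"
```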