fix: cohere tool results

commit 0db7fa3fd8 (parent f74a43aa78)
Author: alisalim17
Date: 2024-04-29 14:20:24 +04:00

@@ -3,8 +3,14 @@ import requests, traceback
 import json, re, xml.etree.ElementTree as ET
 from jinja2 import Template, exceptions, meta, BaseLoader
 from jinja2.sandbox import ImmutableSandboxedEnvironment
-from typing import Optional, Any
-from typing import List
+from typing import (
+    Any,
+    List,
+    Mapping,
+    MutableMapping,
+    Optional,
+    Sequence,
+)
 import litellm
@@ -430,8 +436,10 @@ def format_prompt_togetherai(messages, prompt_format, chat_template):
         prompt = default_pt(messages)
     return prompt
 
 
 ### IBM Granite
 def ibm_granite_pt(messages: list):
     """
     IBM's Granite models uses the template:
@ -440,23 +448,24 @@ def ibm_granite_pt(messages: list):
See: https://www.ibm.com/docs/en/watsonx-as-a-service?topic=solutions-supported-foundation-models See: https://www.ibm.com/docs/en/watsonx-as-a-service?topic=solutions-supported-foundation-models
""" """
return custom_prompt( return custom_prompt(
messages=messages, messages=messages,
role_dict={ role_dict={
'system': { "system": {
'pre_message': '<|system|>\n', "pre_message": "<|system|>\n",
'post_message': '\n', "post_message": "\n",
}, },
'user': { "user": {
'pre_message': '<|user|>\n', "pre_message": "<|user|>\n",
'post_message': '\n', "post_message": "\n",
}, },
'assistant': { "assistant": {
'pre_message': '<|assistant|>\n', "pre_message": "<|assistant|>\n",
'post_message': '\n', "post_message": "\n",
} },
} },
).strip() ).strip()
### ANTHROPIC ### ### ANTHROPIC ###
@@ -1043,6 +1052,30 @@ def get_system_prompt(messages):
     return system_prompt, messages
 
 
+def convert_to_documents(
+    observations: Any,
+) -> List[MutableMapping]:
+    """Converts observations into a 'document' dict"""
+    documents: List[MutableMapping] = []
+    if isinstance(observations, str):
+        # strings are turned into a key/value pair and a key of 'output' is added.
+        observations = [{"output": observations}]
+    elif isinstance(observations, Mapping):
+        # single mappings are transformed into a list to simplify the rest of the code.
+        observations = [observations]
+    elif not isinstance(observations, Sequence):
+        # all other types are turned into a key/value pair within a list
+        observations = [{"output": observations}]
+
+    for doc in observations:
+        if not isinstance(doc, Mapping):
+            # types that aren't Mapping are turned into a key/value pair.
+            doc = {"output": doc}
+        documents.append(doc)
+
+    return documents
+
+
 def convert_openai_message_to_cohere_tool_result(message):
     """
     OpenAI message with a tool result looks like:
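Note: a quick sketch of how the new helper normalizes each observation shape (the input values below are illustrative, not from the commit):

    convert_to_documents("72F and sunny")            # -> [{"output": "72F and sunny"}]
    convert_to_documents({"temp": "72F"})            # -> [{"temp": "72F"}]
    convert_to_documents(["72F", {"wind": "5mph"}])  # -> [{"output": "72F"}, {"wind": "5mph"}]
    convert_to_documents(72)                         # -> [{"output": 72}]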
@@ -1084,7 +1117,7 @@ def convert_openai_message_to_cohere_tool_result(message):
             "parameters": {"location": "San Francisco, CA"},
             "generation_id": tool_call_id,
         },
-        "outputs": [content],
+        "outputs": convert_to_documents(content),
     }
     return cohere_tool_result
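Note: this is the substantive fix. An OpenAI-format tool message carries its result as a plain string, so the old code produced a list containing a bare string; routing it through convert_to_documents yields the list-of-dicts shape that Cohere's tool_results outputs field takes. A before/after sketch with a hypothetical value:

    content = "72F and sunny"
    [content]                      # before: ["72F and sunny"]
    convert_to_documents(content)  # after:  [{"output": "72F and sunny"}]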
@@ -1097,7 +1130,7 @@ def cohere_message_pt(messages: list):
         if message["role"] == "tool":
             tool_result = convert_openai_message_to_cohere_tool_result(message)
             tool_results.append(tool_result)
-        else:
+        elif message.get("content"):
             prompt += message["content"] + "\n\n"
     prompt = prompt.rstrip()
     return prompt, tool_results
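Note: assistant messages that only request a tool call typically have content set to None, so the old unconditional else: branch could raise a TypeError on None + "\n\n"; the elif guard now skips such messages. A minimal illustration (the message shape is an assumption based on the OpenAI format):

    message = {"role": "assistant", "content": None, "tool_calls": []}
    prompt = ""
    if message["role"] == "tool":
        pass  # handled by convert_openai_message_to_cohere_tool_result
    elif message.get("content"):
        prompt += message["content"] + "\n\n"  # skipped when content is None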
@ -1396,9 +1429,18 @@ def prompt_factory(
# https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-3/ # https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-3/
return custom_prompt( return custom_prompt(
role_dict={ role_dict={
"system": {"pre_message": "<|start_header_id|>system<|end_header_id|>\n", "post_message": "<|eot_id|>"}, "system": {
"user": {"pre_message": "<|start_header_id|>user<|end_header_id|>\n", "post_message": "<|eot_id|>"}, "pre_message": "<|start_header_id|>system<|end_header_id|>\n",
"assistant": {"pre_message": "<|start_header_id|>assistant<|end_header_id|>\n", "post_message": "<|eot_id|>"}, "post_message": "<|eot_id|>",
},
"user": {
"pre_message": "<|start_header_id|>user<|end_header_id|>\n",
"post_message": "<|eot_id|>",
},
"assistant": {
"pre_message": "<|start_header_id|>assistant<|end_header_id|>\n",
"post_message": "<|eot_id|>",
},
}, },
messages=messages, messages=messages,
initial_prompt_value="<|begin_of_text|>", initial_prompt_value="<|begin_of_text|>",