fix: cohere tool results

This commit is contained in:
alisalim17 2024-04-29 14:20:24 +04:00
parent f74a43aa78
commit 0db7fa3fd8

View file

@ -3,8 +3,14 @@ import requests, traceback
import json, re, xml.etree.ElementTree as ET
from jinja2 import Template, exceptions, meta, BaseLoader
from jinja2.sandbox import ImmutableSandboxedEnvironment
from typing import Optional, Any
from typing import List
from typing import (
Any,
List,
Mapping,
MutableMapping,
Optional,
Sequence,
)
import litellm
@ -430,8 +436,10 @@ def format_prompt_togetherai(messages, prompt_format, chat_template):
prompt = default_pt(messages)
return prompt
### IBM Granite
def ibm_granite_pt(messages: list):
"""
IBM's Granite models uses the template:
@ -442,21 +450,22 @@ def ibm_granite_pt(messages: list):
return custom_prompt(
messages=messages,
role_dict={
'system': {
'pre_message': '<|system|>\n',
'post_message': '\n',
"system": {
"pre_message": "<|system|>\n",
"post_message": "\n",
},
'user': {
'pre_message': '<|user|>\n',
'post_message': '\n',
"user": {
"pre_message": "<|user|>\n",
"post_message": "\n",
},
'assistant': {
'pre_message': '<|assistant|>\n',
'post_message': '\n',
}
}
"assistant": {
"pre_message": "<|assistant|>\n",
"post_message": "\n",
},
},
).strip()
### ANTHROPIC ###
@ -1043,6 +1052,30 @@ def get_system_prompt(messages):
return system_prompt, messages
def convert_to_documents(
observations: Any,
) -> List[MutableMapping]:
"""Converts observations into a 'document' dict"""
documents: List[MutableMapping] = []
if isinstance(observations, str):
# strings are turned into a key/value pair and a key of 'output' is added.
observations = [{"output": observations}]
elif isinstance(observations, Mapping):
# single mappings are transformed into a list to simplify the rest of the code.
observations = [observations]
elif not isinstance(observations, Sequence):
# all other types are turned into a key/value pair within a list
observations = [{"output": observations}]
for doc in observations:
if not isinstance(doc, Mapping):
# types that aren't Mapping are turned into a key/value pair.
doc = {"output": doc}
documents.append(doc)
return documents
def convert_openai_message_to_cohere_tool_result(message):
"""
OpenAI message with a tool result looks like:
@ -1084,7 +1117,7 @@ def convert_openai_message_to_cohere_tool_result(message):
"parameters": {"location": "San Francisco, CA"},
"generation_id": tool_call_id,
},
"outputs": [content],
"outputs": convert_to_documents(content),
}
return cohere_tool_result
@ -1097,7 +1130,7 @@ def cohere_message_pt(messages: list):
if message["role"] == "tool":
tool_result = convert_openai_message_to_cohere_tool_result(message)
tool_results.append(tool_result)
else:
elif message.get("content"):
prompt += message["content"] + "\n\n"
prompt = prompt.rstrip()
return prompt, tool_results
@ -1396,9 +1429,18 @@ def prompt_factory(
# https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-3/
return custom_prompt(
role_dict={
"system": {"pre_message": "<|start_header_id|>system<|end_header_id|>\n", "post_message": "<|eot_id|>"},
"user": {"pre_message": "<|start_header_id|>user<|end_header_id|>\n", "post_message": "<|eot_id|>"},
"assistant": {"pre_message": "<|start_header_id|>assistant<|end_header_id|>\n", "post_message": "<|eot_id|>"},
"system": {
"pre_message": "<|start_header_id|>system<|end_header_id|>\n",
"post_message": "<|eot_id|>",
},
"user": {
"pre_message": "<|start_header_id|>user<|end_header_id|>\n",
"post_message": "<|eot_id|>",
},
"assistant": {
"pre_message": "<|start_header_id|>assistant<|end_header_id|>\n",
"post_message": "<|eot_id|>",
},
},
messages=messages,
initial_prompt_value="<|begin_of_text|>",