mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
fix: cohere tool results
This commit is contained in:
parent
f74a43aa78
commit
0db7fa3fd8
1 changed files with 61 additions and 19 deletions
|
@ -3,8 +3,14 @@ import requests, traceback
|
|||
import json, re, xml.etree.ElementTree as ET
|
||||
from jinja2 import Template, exceptions, meta, BaseLoader
|
||||
from jinja2.sandbox import ImmutableSandboxedEnvironment
|
||||
from typing import Optional, Any
|
||||
from typing import List
|
||||
from typing import (
|
||||
Any,
|
||||
List,
|
||||
Mapping,
|
||||
MutableMapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
)
|
||||
import litellm
|
||||
|
||||
|
||||
|
@ -430,8 +436,10 @@ def format_prompt_togetherai(messages, prompt_format, chat_template):
|
|||
prompt = default_pt(messages)
|
||||
return prompt
|
||||
|
||||
|
||||
### IBM Granite
|
||||
|
||||
|
||||
def ibm_granite_pt(messages: list):
|
||||
"""
|
||||
IBM's Granite models uses the template:
|
||||
|
@ -442,21 +450,22 @@ def ibm_granite_pt(messages: list):
|
|||
return custom_prompt(
|
||||
messages=messages,
|
||||
role_dict={
|
||||
'system': {
|
||||
'pre_message': '<|system|>\n',
|
||||
'post_message': '\n',
|
||||
"system": {
|
||||
"pre_message": "<|system|>\n",
|
||||
"post_message": "\n",
|
||||
},
|
||||
"user": {
|
||||
"pre_message": "<|user|>\n",
|
||||
"post_message": "\n",
|
||||
},
|
||||
"assistant": {
|
||||
"pre_message": "<|assistant|>\n",
|
||||
"post_message": "\n",
|
||||
},
|
||||
'user': {
|
||||
'pre_message': '<|user|>\n',
|
||||
'post_message': '\n',
|
||||
},
|
||||
'assistant': {
|
||||
'pre_message': '<|assistant|>\n',
|
||||
'post_message': '\n',
|
||||
}
|
||||
}
|
||||
).strip()
|
||||
|
||||
|
||||
### ANTHROPIC ###
|
||||
|
||||
|
||||
|
@ -1043,6 +1052,30 @@ def get_system_prompt(messages):
|
|||
return system_prompt, messages
|
||||
|
||||
|
||||
def convert_to_documents(
|
||||
observations: Any,
|
||||
) -> List[MutableMapping]:
|
||||
"""Converts observations into a 'document' dict"""
|
||||
documents: List[MutableMapping] = []
|
||||
if isinstance(observations, str):
|
||||
# strings are turned into a key/value pair and a key of 'output' is added.
|
||||
observations = [{"output": observations}]
|
||||
elif isinstance(observations, Mapping):
|
||||
# single mappings are transformed into a list to simplify the rest of the code.
|
||||
observations = [observations]
|
||||
elif not isinstance(observations, Sequence):
|
||||
# all other types are turned into a key/value pair within a list
|
||||
observations = [{"output": observations}]
|
||||
|
||||
for doc in observations:
|
||||
if not isinstance(doc, Mapping):
|
||||
# types that aren't Mapping are turned into a key/value pair.
|
||||
doc = {"output": doc}
|
||||
documents.append(doc)
|
||||
|
||||
return documents
|
||||
|
||||
|
||||
def convert_openai_message_to_cohere_tool_result(message):
|
||||
"""
|
||||
OpenAI message with a tool result looks like:
|
||||
|
@ -1084,7 +1117,7 @@ def convert_openai_message_to_cohere_tool_result(message):
|
|||
"parameters": {"location": "San Francisco, CA"},
|
||||
"generation_id": tool_call_id,
|
||||
},
|
||||
"outputs": [content],
|
||||
"outputs": convert_to_documents(content),
|
||||
}
|
||||
return cohere_tool_result
|
||||
|
||||
|
@ -1097,7 +1130,7 @@ def cohere_message_pt(messages: list):
|
|||
if message["role"] == "tool":
|
||||
tool_result = convert_openai_message_to_cohere_tool_result(message)
|
||||
tool_results.append(tool_result)
|
||||
else:
|
||||
elif message.get("content"):
|
||||
prompt += message["content"] + "\n\n"
|
||||
prompt = prompt.rstrip()
|
||||
return prompt, tool_results
|
||||
|
@ -1396,9 +1429,18 @@ def prompt_factory(
|
|||
# https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-3/
|
||||
return custom_prompt(
|
||||
role_dict={
|
||||
"system": {"pre_message": "<|start_header_id|>system<|end_header_id|>\n", "post_message": "<|eot_id|>"},
|
||||
"user": {"pre_message": "<|start_header_id|>user<|end_header_id|>\n", "post_message": "<|eot_id|>"},
|
||||
"assistant": {"pre_message": "<|start_header_id|>assistant<|end_header_id|>\n", "post_message": "<|eot_id|>"},
|
||||
"system": {
|
||||
"pre_message": "<|start_header_id|>system<|end_header_id|>\n",
|
||||
"post_message": "<|eot_id|>",
|
||||
},
|
||||
"user": {
|
||||
"pre_message": "<|start_header_id|>user<|end_header_id|>\n",
|
||||
"post_message": "<|eot_id|>",
|
||||
},
|
||||
"assistant": {
|
||||
"pre_message": "<|start_header_id|>assistant<|end_header_id|>\n",
|
||||
"post_message": "<|eot_id|>",
|
||||
},
|
||||
},
|
||||
messages=messages,
|
||||
initial_prompt_value="<|begin_of_text|>",
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue