diff --git a/litellm/llms/prompt_templates/factory.py b/litellm/llms/prompt_templates/factory.py
index c51dc89be5..e1fa354c63 100644
--- a/litellm/llms/prompt_templates/factory.py
+++ b/litellm/llms/prompt_templates/factory.py
@@ -3,8 +3,14 @@ import requests, traceback
 import json, re, xml.etree.ElementTree as ET
 from jinja2 import Template, exceptions, meta, BaseLoader
 from jinja2.sandbox import ImmutableSandboxedEnvironment
-from typing import Optional, Any
-from typing import List
+from typing import (
+    Any,
+    List,
+    Mapping,
+    MutableMapping,
+    Optional,
+    Sequence,
+)
 import litellm
 
 
@@ -430,8 +436,10 @@ def format_prompt_togetherai(messages, prompt_format, chat_template):
         prompt = default_pt(messages)
     return prompt
 
+
 ### IBM Granite
 
+
 def ibm_granite_pt(messages: list):
     """
     IBM's Granite models uses the template:
@@ -440,23 +448,24 @@ def ibm_granite_pt(messages: list):
     See: https://www.ibm.com/docs/en/watsonx-as-a-service?topic=solutions-supported-foundation-models
     """
     return custom_prompt(
-        messages=messages, 
+        messages=messages,
         role_dict={
-            'system': {
-                'pre_message': '<|system|>\n',
-                'post_message': '\n',
+            "system": {
+                "pre_message": "<|system|>\n",
+                "post_message": "\n",
             },
-            'user': {
-                'pre_message': '<|user|>\n',
-                'post_message': '\n',
+            "user": {
+                "pre_message": "<|user|>\n",
+                "post_message": "\n",
             },
-            'assistant': {
-                'pre_message': '<|assistant|>\n',
-                'post_message': '\n',
-            }
-        }
+            "assistant": {
+                "pre_message": "<|assistant|>\n",
+                "post_message": "\n",
+            },
+        },
     ).strip()
 
+
 ### ANTHROPIC ###
 
 
@@ -1043,6 +1052,30 @@ def get_system_prompt(messages):
     return system_prompt, messages
 
 
+def convert_to_documents(
+    observations: Any,
+) -> List[MutableMapping]:
+    """Converts observations into a 'document' dict"""
+    documents: List[MutableMapping] = []
+    if isinstance(observations, str):
+        # strings are turned into a key/value pair and a key of 'output' is added.
+        observations = [{"output": observations}]
+    elif isinstance(observations, Mapping):
+        # single mappings are transformed into a list to simplify the rest of the code.
+        observations = [observations]
+    elif not isinstance(observations, Sequence):
+        # all other types are turned into a key/value pair within a list
+        observations = [{"output": observations}]
+
+    for doc in observations:
+        if not isinstance(doc, Mapping):
+            # types that aren't Mapping are turned into a key/value pair.
+            doc = {"output": doc}
+        documents.append(doc)
+
+    return documents
+
+
 def convert_openai_message_to_cohere_tool_result(message):
     """
     OpenAI message with a tool result looks like:
@@ -1084,7 +1117,7 @@
             "parameters": {"location": "San Francisco, CA"},
             "generation_id": tool_call_id,
         },
-        "outputs": [content],
+        "outputs": convert_to_documents(content),
     }
     return cohere_tool_result
 
@@ -1097,7 +1130,7 @@ def cohere_message_pt(messages: list):
         if message["role"] == "tool":
            tool_result = convert_openai_message_to_cohere_tool_result(message)
            tool_results.append(tool_result)
-        else:
+        elif message.get("content"):
            prompt += message["content"] + "\n\n"
    prompt = prompt.rstrip()
    return prompt, tool_results
@@ -1396,9 +1429,18 @@ def prompt_factory(
             # https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-3/
             return custom_prompt(
                 role_dict={
-                    "system": {"pre_message": "<|start_header_id|>system<|end_header_id|>\n", "post_message": "<|eot_id|>"},
-                    "user": {"pre_message": "<|start_header_id|>user<|end_header_id|>\n", "post_message": "<|eot_id|>"},
-                    "assistant": {"pre_message": "<|start_header_id|>assistant<|end_header_id|>\n", "post_message": "<|eot_id|>"},
+                    "system": {
+                        "pre_message": "<|start_header_id|>system<|end_header_id|>\n",
+                        "post_message": "<|eot_id|>",
+                    },
+                    "user": {
+                        "pre_message": "<|start_header_id|>user<|end_header_id|>\n",
+                        "post_message": "<|eot_id|>",
+                    },
+                    "assistant": {
+                        "pre_message": "<|start_header_id|>assistant<|end_header_id|>\n",
+                        "post_message": "<|eot_id|>",
+                    },
                 },
                 messages=messages,
                 initial_prompt_value="<|begin_of_text|>",