Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-27 11:43:54 +00:00
fix(bedrock.py): support custom prompt templates for all providers
Fixes https://github.com/BerriAI/litellm/issues/4239
parent 51596907aa
commit 71a2fabc6a
3 changed files with 43 additions and 43 deletions
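What the fix enables: the `custom_prompt_dict` lookup in `convert_messages_to_prompt` is no longer gated on `provider == "anthropic"/"amazon"`, so a template registered for any Bedrock model is honored. A minimal sketch of registering and using one, assuming litellm's documented `register_prompt_template` API (the model id and template strings are illustrative):

```python
import litellm

# Register a custom prompt template for a Bedrock model (model id and
# pre/post message strings are illustrative). register_prompt_template
# populates litellm's custom_prompt_dict, which convert_messages_to_prompt
# now consults for every provider, not just anthropic/amazon.
litellm.register_prompt_template(
    model="bedrock/cohere.command-text-v14",
    roles={
        "system": {"pre_message": "[SYS]\n", "post_message": "\n[/SYS]\n"},
        "user": {"pre_message": "User: ", "post_message": "\n"},
        "assistant": {"pre_message": "Assistant: ", "post_message": "\n"},
    },
    initial_prompt_value="",
    final_prompt_value="Assistant: ",
)

# Completion calls against this model now apply the registered template.
response = litellm.completion(
    model="bedrock/cohere.command-text-v14",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
)
```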
litellm/llms/bedrock.py

@@ -1,23 +1,27 @@
-import json, copy, types
+import copy
+import json
 import os
+import time
+import types
+import uuid
 from enum import Enum
-import time, uuid
-from typing import Callable, Optional, Any, Union, List
+from typing import Any, Callable, List, Optional, Union
+
+import httpx
+
 import litellm
-from litellm.utils import (
-    get_secret,
-)
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
 from litellm.types.utils import ImageResponse, ModelResponse, Usage
+from litellm.utils import get_secret
+
 from .prompt_templates.factory import (
-    prompt_factory,
-    custom_prompt,
     construct_tool_use_system_prompt,
+    contains_tag,
+    custom_prompt,
     extract_between_tags,
     parse_xml_params,
-    contains_tag,
+    prompt_factory,
 )
-import httpx


 class BedrockError(Exception):
@@ -726,38 +730,31 @@ def init_bedrock_client(

 def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict):
     # handle anthropic prompts and amazon titan prompts
-    if provider == "anthropic" or provider == "amazon":
-        if model in custom_prompt_dict:
-            # check if the model has a registered custom prompt
-            model_prompt_details = custom_prompt_dict[model]
-            prompt = custom_prompt(
-                role_dict=model_prompt_details["roles"],
-                initial_prompt_value=model_prompt_details["initial_prompt_value"],
-                final_prompt_value=model_prompt_details["final_prompt_value"],
-                messages=messages,
-            )
-        else:
+    chat_template_provider = ["anthropic", "amazon", "mistral", "meta"]
+    if model in custom_prompt_dict:
+        # check if the model has a registered custom prompt
+        model_prompt_details = custom_prompt_dict[model]
+        prompt = custom_prompt(
+            role_dict=model_prompt_details["roles"],
+            initial_prompt_value=model_prompt_details["initial_prompt_value"],
+            final_prompt_value=model_prompt_details["final_prompt_value"],
+            messages=messages,
+        )
+    else:
+        if provider in chat_template_provider:
             prompt = prompt_factory(
                 model=model, messages=messages, custom_llm_provider="bedrock"
             )
-    elif provider == "mistral":
-        prompt = prompt_factory(
-            model=model, messages=messages, custom_llm_provider="bedrock"
-        )
-    elif provider == "meta":
-        prompt = prompt_factory(
-            model=model, messages=messages, custom_llm_provider="bedrock"
-        )
-    else:
-        prompt = ""
-        for message in messages:
-            if "role" in message:
-                if message["role"] == "user":
-                    prompt += f"{message['content']}"
-                else:
-                    prompt += f"{message['content']}"
-            else:
-                prompt += f"{message['content']}"
+        else:
+            prompt = ""
+            for message in messages:
+                if "role" in message:
+                    if message["role"] == "user":
+                        prompt += f"{message['content']}"
+                    else:
+                        prompt += f"{message['content']}"
+                else:
+                    prompt += f"{message['content']}"
     return prompt
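The behavioral change, distilled: the registered-template check now runs before any provider dispatch. A standalone sketch of the new ordering (a simplified re-implementation for illustration; `pick_prompt_path` and its return labels are hypothetical, only `chat_template_provider` mirrors the diff):

```python
# Simplified, self-contained illustration of the dispatch order after this
# commit. pick_prompt_path is a hypothetical helper; the real function
# builds the prompt string instead of returning a label.
CHAT_TEMPLATE_PROVIDERS = ["anthropic", "amazon", "mistral", "meta"]


def pick_prompt_path(model: str, provider: str, custom_prompt_dict: dict) -> str:
    if model in custom_prompt_dict:
        # a registered custom template now wins for EVERY provider
        return "custom_prompt"
    if provider in CHAT_TEMPLATE_PROVIDERS:
        # known chat-template providers go through prompt_factory
        return "prompt_factory"
    # everything else falls back to naive message concatenation
    return "concatenation"


registered = {"cohere.command-text-v14": {"roles": {}}}
# Before the fix, provider "cohere" could never reach the template branch;
# now the registration is honored:
assert pick_prompt_path("cohere.command-text-v14", "cohere", registered) == "custom_prompt"
assert pick_prompt_path("anthropic.claude-v2", "anthropic", {}) == "prompt_factory"
assert pick_prompt_path("cohere.command-text-v14", "cohere", {}) == "concatenation"
```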