fix(bedrock.py): support custom prompt templates for all providers

Fixes https://github.com/BerriAI/litellm/issues/4239
Krrish Dholakia 2024-06-17 08:28:46 -07:00
parent 3a35a58859
commit 7dd0151f83
3 changed files with 43 additions and 43 deletions
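
For context: litellm lets callers register a per-model prompt template, which Bedrock's convert_messages_to_prompt previously honored only for the "anthropic" and "amazon" providers. A minimal sketch of registering such a template before calling completion — the model id and template values below are illustrative, not taken from this commit, and the bare model id (no "bedrock/" prefix) is assumed to be the lookup key:

    import litellm

    # Register a custom template for a Bedrock model. This populates
    # litellm.custom_prompt_dict, which convert_messages_to_prompt consults
    # before falling back to prompt_factory.
    litellm.register_prompt_template(
        model="meta.llama2-13b-chat-v1",  # hypothetical model id
        roles={
            "system": {"pre_message": "[INST] <<SYS>>\n", "post_message": "\n<</SYS>>\n [/INST]\n"},
            "user": {"pre_message": "[INST] ", "post_message": " [/INST]\n"},
            "assistant": {"pre_message": "\n", "post_message": "\n"},
        },
        initial_prompt_value="\n",
        final_prompt_value="\n",
    )

    # Requires AWS credentials in the environment.
    response = litellm.completion(
        model="bedrock/meta.llama2-13b-chat-v1",
        messages=[{"role": "user", "content": "Hey, how's it going?"}],
    )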

.pre-commit-config.yaml

@@ -1,12 +1,12 @@
 repos:
 - repo: local
   hooks:
-    # - id: mypy
-    #   name: mypy
-    #   entry: python3 -m mypy --ignore-missing-imports
-    #   language: system
-    #   types: [python]
-    #   files: ^litellm/
+    - id: mypy
+      name: mypy
+      entry: python3 -m mypy --ignore-missing-imports
+      language: system
+      types: [python]
+      files: ^litellm/
     - id: isort
       name: isort
       entry: isort

litellm/llms/bedrock.py

@@ -1,23 +1,27 @@
-import json, copy, types
+import copy
+import json
 import os
+import time
+import types
+import uuid
 from enum import Enum
-import time, uuid
-from typing import Callable, Optional, Any, Union, List
+from typing import Any, Callable, List, Optional, Union
+
+import httpx
+
 import litellm
-from litellm.utils import (
-    get_secret,
-)
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
 from litellm.types.utils import ImageResponse, ModelResponse, Usage
+from litellm.utils import get_secret
+
 from .prompt_templates.factory import (
-    prompt_factory,
-    custom_prompt,
     construct_tool_use_system_prompt,
+    contains_tag,
+    custom_prompt,
     extract_between_tags,
     parse_xml_params,
-    contains_tag,
+    prompt_factory,
 )
-import httpx


 class BedrockError(Exception):
@@ -726,38 +730,31 @@ def init_bedrock_client(
     return client


 def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict):
     # handle anthropic prompts and amazon titan prompts
-    if provider == "anthropic" or provider == "amazon":
-        if model in custom_prompt_dict:
-            # check if the model has a registered custom prompt
-            model_prompt_details = custom_prompt_dict[model]
-            prompt = custom_prompt(
-                role_dict=model_prompt_details["roles"],
-                initial_prompt_value=model_prompt_details["initial_prompt_value"],
-                final_prompt_value=model_prompt_details["final_prompt_value"],
-                messages=messages,
-            )
-        else:
+    chat_template_provider = ["anthropic", "amazon", "mistral", "meta"]
+    if model in custom_prompt_dict:
+        # check if the model has a registered custom prompt
+        model_prompt_details = custom_prompt_dict[model]
+        prompt = custom_prompt(
+            role_dict=model_prompt_details["roles"],
+            initial_prompt_value=model_prompt_details["initial_prompt_value"],
+            final_prompt_value=model_prompt_details["final_prompt_value"],
+            messages=messages,
+        )
+    else:
+        if provider in chat_template_provider:
             prompt = prompt_factory(
                 model=model, messages=messages, custom_llm_provider="bedrock"
             )
-    elif provider == "mistral":
-        prompt = prompt_factory(
-            model=model, messages=messages, custom_llm_provider="bedrock"
-        )
-    elif provider == "meta":
-        prompt = prompt_factory(
-            model=model, messages=messages, custom_llm_provider="bedrock"
-        )
-    else:
-        prompt = ""
-        for message in messages:
-            if "role" in message:
-                if message["role"] == "user":
-                    prompt += f"{message['content']}"
-                else:
+        else:
+            prompt = ""
+            for message in messages:
+                if "role" in message:
+                    if message["role"] == "user":
+                        prompt += f"{message['content']}"
+                    else:
+                        prompt += f"{message['content']}"
+                else:
                     prompt += f"{message['content']}"
-            else:
-                prompt += f"{message['content']}"
     return prompt
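
Net effect of the reshuffled branches: a registered custom prompt now wins for every provider, the four chat-template providers share a single prompt_factory path, and everything else keeps the plain-concatenation fallback. A rough sketch of the new dispatch, calling the converter directly — the model ids and template values are made up, and the import assumes the module layout at this commit:

    from litellm.llms.bedrock import convert_messages_to_prompt

    messages = [{"role": "user", "content": "Hi"}]

    # No registered template: "meta" (like anthropic/amazon/mistral) now
    # routes through prompt_factory.
    prompt = convert_messages_to_prompt(
        model="meta.llama2-13b-chat-v1",
        messages=messages,
        provider="meta",
        custom_prompt_dict={},
    )

    # A registered template takes priority for ANY provider -- including
    # ones like "cohere" that previously always fell through to plain
    # string concatenation.
    template = {
        "roles": {
            "user": {"pre_message": "USER: ", "post_message": "\n"},
            "assistant": {"pre_message": "", "post_message": "\n"},
        },
        "initial_prompt_value": "",
        "final_prompt_value": "ASSISTANT: ",
    }
    prompt = convert_messages_to_prompt(
        model="cohere.command-text-v14",
        messages=messages,
        provider="cohere",
        custom_prompt_dict={"cohere.command-text-v14": template},
    )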

pyproject.toml

@@ -71,6 +71,9 @@ extra_proxy = [
     "pynacl"
 ]

+[tool.isort]
+profile = "black"
+
 [tool.poetry.scripts]
 litellm = 'litellm:run_server'
