fix(bedrock.py): support custom prompt templates for all providers

Fixes https://github.com/BerriAI/litellm/issues/4239
Krrish Dholakia 2024-06-17 08:28:46 -07:00
parent 3a35a58859
commit 7dd0151f83
3 changed files with 43 additions and 43 deletions

.pre-commit-config.yaml

@@ -1,12 +1,12 @@
 repos:
 - repo: local
   hooks:
-  # - id: mypy
-  #   name: mypy
-  #   entry: python3 -m mypy --ignore-missing-imports
-  #   language: system
-  #   types: [python]
-  #   files: ^litellm/
+  - id: mypy
+    name: mypy
+    entry: python3 -m mypy --ignore-missing-imports
+    language: system
+    types: [python]
+    files: ^litellm/
   - id: isort
     name: isort
     entry: isort
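
Note: this re-enables the previously commented-out mypy hook, so every commit now type-checks files under litellm/ with python3 -m mypy --ignore-missing-imports. For a quick local check outside of pre-commit, the same run can be reproduced through mypy's programmatic API (a minimal sketch, assuming mypy is installed in the environment):

    from mypy import api  # mypy's documented programmatic entry point

    # Same flags and target as the pre-commit hook entry above.
    stdout, stderr, exit_status = api.run(["--ignore-missing-imports", "litellm/"])
    print(stdout)
    if exit_status != 0:
        raise SystemExit(stderr or "mypy reported type errors")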

litellm/llms/bedrock.py

@@ -1,23 +1,27 @@
-import json, copy, types
+import copy
+import json
 import os
+import time
+import types
+import uuid
 from enum import Enum
-import time, uuid
-from typing import Callable, Optional, Any, Union, List
+from typing import Any, Callable, List, Optional, Union
+
+import httpx
+
 import litellm
-from litellm.utils import (
-    get_secret,
-)
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
 from litellm.types.utils import ImageResponse, ModelResponse, Usage
+from litellm.utils import get_secret
+
 from .prompt_templates.factory import (
-    prompt_factory,
-    custom_prompt,
     construct_tool_use_system_prompt,
+    contains_tag,
+    custom_prompt,
     extract_between_tags,
     parse_xml_params,
-    contains_tag,
+    prompt_factory,
 )
-import httpx
 
 
 class BedrockError(Exception):
@@ -726,7 +730,7 @@ def init_bedrock_client(
 def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict):
     # handle anthropic prompts and amazon titan prompts
-    if provider == "anthropic" or provider == "amazon":
+    chat_template_provider = ["anthropic", "amazon", "mistral", "meta"]
     if model in custom_prompt_dict:
         # check if the model has a registered custom prompt
         model_prompt_details = custom_prompt_dict[model]
@@ -737,14 +741,7 @@ def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict):
             messages=messages,
         )
     else:
-        prompt = prompt_factory(
-            model=model, messages=messages, custom_llm_provider="bedrock"
-        )
-    elif provider == "mistral":
-        prompt = prompt_factory(
-            model=model, messages=messages, custom_llm_provider="bedrock"
-        )
-    elif provider == "meta":
+        if provider in chat_template_provider:
             prompt = prompt_factory(
                 model=model, messages=messages, custom_llm_provider="bedrock"
             )
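
The behavioral change: the custom_prompt_dict lookup now runs before any provider check, and the old per-provider elif chain collapses into a single membership test against chat_template_provider. In practice, a prompt template registered for a Bedrock mistral or meta model is now honored instead of being bypassed. A minimal sketch of registering such a template with litellm.register_prompt_template (the model id and role markers below are illustrative, not taken from this commit):

    import litellm

    # Illustrative Bedrock model id and role markers; adjust for your model.
    litellm.register_prompt_template(
        model="bedrock/meta.llama3-8b-instruct-v1:0",
        initial_prompt_value="<s>",
        roles={
            "system": {"pre_message": "<<SYS>>\n", "post_message": "\n<</SYS>>\n"},
            "user": {"pre_message": "[INST] ", "post_message": " [/INST]\n"},
            "assistant": {"pre_message": "", "post_message": "\n"},
        },
        final_prompt_value="\n",
    )

This populates litellm.custom_prompt_dict, which is what convert_messages_to_prompt now consults first for every provider.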

pyproject.toml

@@ -71,6 +71,9 @@ extra_proxy = [
     "pynacl"
 ]
 
+[tool.isort]
+profile = "black"
+
 [tool.poetry.scripts]
 litellm = 'litellm:run_server'
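
The new [tool.isort] section pins isort to the black-compatible profile, which is what produced the wholesale import reordering in bedrock.py above. A small sketch of the same normalization through isort's Python API (isort.code; the sample input mimics the old combined-import style):

    import isort

    # Mimics the old combined-import style that the bedrock.py hunk replaced.
    messy = "import json, copy, types\nimport os\n"

    # profile="black" mirrors the new [tool.isort] setting in pyproject.toml;
    # isort.code returns the sorted import block as a string.
    print(isort.code(messy, profile="black"))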