Mirror of https://github.com/BerriAI/litellm.git
fix(bedrock.py): support anthropic messages api on bedrock (claude-3)
Parent commit: 0ac652a771
This commit: 478307d4cf
15 changed files with 381 additions and 307 deletions. The hunks below are from bedrock.py.
@@ -5,7 +5,13 @@ import time
 from typing import Callable, Optional, Any, Union, List
 import litellm
 from litellm.utils import ModelResponse, get_secret, Usage, ImageResponse
-from .prompt_templates.factory import prompt_factory, custom_prompt
+from .prompt_templates.factory import (
+    prompt_factory,
+    custom_prompt,
+    construct_tool_use_system_prompt,
+    extract_between_tags,
+    parse_xml_params,
+)
 import httpx
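The three new factory imports carry the XML-based tool-calling flow: `construct_tool_use_system_prompt` builds a system-prompt description of the tools, and `extract_between_tags` / `parse_xml_params` later pull the model's tool invocation back out of its XML output. As a rough sketch of that round trip (the helper bodies below are stand-ins inferred from the names, not the actual factory implementations):

```python
import re
from typing import List

# Hypothetical stand-ins for the factory helpers; the real signatures
# in prompt_templates/factory.py may differ.
def extract_between_tags(tag: str, string: str) -> List[str]:
    # Pull the inner text of every <tag>...</tag> block.
    return re.findall(f"<{tag}>(.+?)</{tag}>", string, re.DOTALL)

def parse_xml_params(xml: str) -> dict:
    # Turn <name>value</name> pairs into a plain dict.
    return dict(re.findall(r"<(\w+)>(.*?)</\1>", xml, re.DOTALL))

completion = "<invoke><tool_name>get_weather</tool_name><city>SF</city></invoke>"
invoke = extract_between_tags("invoke", completion)[0]
print(parse_xml_params(invoke))  # {'tool_name': 'get_weather', 'city': 'SF'}
```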
@@ -81,7 +87,7 @@ class AmazonAnthropicClaude3Config:
     """
 
     max_tokens: Optional[int] = litellm.max_tokens
-    anthropic_version: Optional[str] = None
+    anthropic_version: Optional[str] = "bedrock-2023-05-31"
 
     def __init__(
         self,
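Bedrock's Claude 3 messages API requires an `anthropic_version` of `"bedrock-2023-05-31"` in the request body, so defaulting it here spares callers from setting it themselves. A minimal sketch of the payload Bedrock expects, calling Bedrock directly through boto3 (model ID and region are placeholders):

```python
import json
import boto3  # assumes AWS credentials are configured

client = boto3.client("bedrock-runtime", region_name="us-east-1")
body = {
    "anthropic_version": "bedrock-2023-05-31",  # required by the messages API
    "max_tokens": 256,
    "messages": [{"role": "user", "content": "Hello"}],
}
resp = client.invoke_model(
    modelId="anthropic.claude-3-sonnet-20240229-v1:0",
    body=json.dumps(body),
)
print(json.loads(resp["body"].read()))
```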
@@ -111,6 +117,15 @@ class AmazonAnthropicClaude3Config:
             and v is not None
         }
 
+    def get_supported_openai_params(self):
+        return ["max_tokens"]
+
+    def map_openai_params(self, non_default_params: dict, optional_params: dict):
+        for param, value in non_default_params.items():
+            if param == "max_tokens":
+                optional_params["max_tokens"] = value
+        return optional_params
+
 
 class AmazonAnthropicConfig:
     """
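At this point only `max_tokens` is translated for Claude 3; any other OpenAI-style parameter falls through the loop unmapped. A quick illustration (assuming the config class instantiates with no arguments, as its `__init__` defaults suggest):

```python
import litellm

config = litellm.AmazonAnthropicClaude3Config()
params = config.map_openai_params(
    non_default_params={"max_tokens": 256, "temperature": 0.2},
    optional_params={},
)
print(params)  # {'max_tokens': 256} -- 'temperature' is not mapped here
```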
@@ -165,6 +180,25 @@ class AmazonAnthropicConfig:
             and v is not None
         }
 
+    def get_supported_openai_params(
+        self,
+    ):
+        return ["max_tokens", "temperature", "stop", "top_p", "stream"]
+
+    def map_openai_params(self, non_default_params: dict, optional_params: dict):
+        for param, value in non_default_params.items():
+            if param == "max_tokens":
+                optional_params["max_tokens_to_sample"] = value
+            if param == "temperature":
+                optional_params["temperature"] = value
+            if param == "top_p":
+                optional_params["top_p"] = value
+            if param == "stop":
+                optional_params["stop_sequences"] = value
+            if param == "stream" and value == True:
+                optional_params["stream"] = value
+        return optional_params
+
 
 class AmazonCohereConfig:
     """
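For the older text-completion Claude models, the mapping also renames parameters to Bedrock's field names: `max_tokens` becomes `max_tokens_to_sample` and `stop` becomes `stop_sequences`. For example:

```python
import litellm

config = litellm.AmazonAnthropicConfig()
params = config.map_openai_params(
    non_default_params={"max_tokens": 100, "stop": ["\n\nHuman:"]},
    optional_params={},
)
print(params)  # {'max_tokens_to_sample': 100, 'stop_sequences': ['\n\nHuman:']}
```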
@@ -664,7 +698,20 @@ def completion(
         inference_params = copy.deepcopy(optional_params)
         stream = inference_params.pop("stream", False)
         if provider == "anthropic":
-            if model == "anthropic.claude-3":
+            if model.startswith("anthropic.claude-3"):
+                # Separate system prompt from rest of message
+                system_prompt_idx: Optional[int] = None
+                for idx, message in enumerate(messages):
+                    if message["role"] == "system":
+                        inference_params["system"] = message["content"]
+                        system_prompt_idx = idx
+                        break
+                if system_prompt_idx is not None:
+                    messages.pop(system_prompt_idx)
+                # Format rest of message according to anthropic guidelines
+                messages = prompt_factory(
+                    model=model, messages=messages, custom_llm_provider="anthropic"
+                )
                 ## LOAD CONFIG
                 config = litellm.AmazonAnthropicClaude3Config.get_config()
                 for k, v in config.items():
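The switch from `==` to `startswith` matters because real Bedrock model IDs carry version suffixes (e.g. `anthropic.claude-3-sonnet-20240229-v1:0`), so the old equality check never matched. The system-prompt separation is needed because the messages API takes the system prompt as a top-level `system` field rather than as a `role: "system"` entry inside `messages`; rerun standalone, the same logic behaves like this:

```python
from typing import Optional

# Standalone rerun of the system-prompt separation above.
messages = [
    {"role": "system", "content": "You are terse."},
    {"role": "user", "content": "Hi"},
]
inference_params: dict = {}

system_prompt_idx: Optional[int] = None
for idx, message in enumerate(messages):
    if message["role"] == "system":
        inference_params["system"] = message["content"]  # moved to top-level field
        system_prompt_idx = idx
        break
if system_prompt_idx is not None:
    messages.pop(system_prompt_idx)

print(inference_params)  # {'system': 'You are terse.'}
print(messages)          # [{'role': 'user', 'content': 'Hi'}]
```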
@@ -672,7 +719,17 @@ def completion(
                     k not in inference_params
                 ): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
                     inference_params[k] = v
-                data = json.dumps({"prompt": prompt, **inference_params})
+                ## Handle Tool Calling
+                if "tools" in inference_params:
+                    tool_calling_system_prompt = construct_tool_use_system_prompt(
+                        tools=inference_params["tools"]
+                    )
+                    inference_params["system"] = (
+                        inference_params.get("system", "\n")
+                        + tool_calling_system_prompt
+                    )  # add the anthropic tool calling prompt to the system prompt
+                    inference_params.pop("tools")
+                data = json.dumps({"messages": messages, **inference_params})
             else:
                 ## LOAD CONFIG
                 config = litellm.AmazonAnthropicConfig.get_config()
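This is the core of the fix: for Claude 3 the request body is now serialized with a `messages` array rather than a flat `prompt` string, and OpenAI-style `tools` are folded into the system prompt instead of being sent as a separate field. From the caller's side nothing changes; a sketch of the end-to-end call (the model ID is an example, and the tool schema follows the usual OpenAI format):

```python
import litellm

response = litellm.completion(
    model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
    messages=[{"role": "user", "content": "What's the weather in SF?"}],
    tools=[{
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }],
)
print(response.choices[0].message)
```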
@@ -838,8 +895,12 @@ def completion(
         if provider == "ai21":
             outputText = response_body.get("completions")[0].get("data").get("text")
         elif provider == "anthropic":
-            outputText = response_body["completion"]
-            model_response["finish_reason"] = response_body["stop_reason"]
+            if model.startswith("anthropic.claude-3"):
+                outputText = response_body.get("content")[0].get("text", None)
+                model_response["finish_reason"] = response_body["stop_reason"]
+            else:
+                outputText = response_body["completion"]
+                model_response["finish_reason"] = response_body["stop_reason"]
         elif provider == "cohere":
             outputText = response_body["generations"][0]["text"]
         elif provider == "meta":
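The response parsing forks because the two APIs return different shapes: the messages API wraps text in a list of `content` blocks, while the legacy API returns a flat `completion` string. Illustrative bodies for both:

```python
import json

# Messages API (Claude 3): text lives inside a list of content blocks.
claude3_body = json.loads(
    '{"content": [{"type": "text", "text": "Hi!"}], "stop_reason": "end_turn"}'
)
print(claude3_body["content"][0]["text"])  # Hi!

# Legacy text-completions API: a flat completion string.
legacy_body = json.loads(
    '{"completion": " Hi!", "stop_reason": "stop_sequence"}'
)
print(legacy_body["completion"])  #  Hi!
```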