fix(bedrock.py): support anthropic messages api on bedrock (claude-3)

Krrish Dholakia 2024-03-04 17:15:35 -08:00
parent 0ac652a771
commit 478307d4cf
15 changed files with 381 additions and 307 deletions

litellm/llms/bedrock.py

@@ -5,7 +5,13 @@ import time
 from typing import Callable, Optional, Any, Union, List
 import litellm
 from litellm.utils import ModelResponse, get_secret, Usage, ImageResponse
-from .prompt_templates.factory import prompt_factory, custom_prompt
+from .prompt_templates.factory import (
+    prompt_factory,
+    custom_prompt,
+    construct_tool_use_system_prompt,
+    extract_between_tags,
+    parse_xml_params,
+)
 import httpx
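
The three new factory imports implement Anthropic's XML-based tool use. A minimal sketch of how they fit together; only construct_tool_use_system_prompt(tools=...) appears in this diff, so the extract_between_tags and parse_xml_params signatures below are assumptions:

    from litellm.llms.prompt_templates.factory import (
        construct_tool_use_system_prompt,
        extract_between_tags,
        parse_xml_params,
    )

    # OpenAI-style tool schema, as accepted by completion(tools=...)
    tools = [{
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get the current weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }]

    # Builds the XML tool-use instructions appended to the system prompt below
    tool_prompt = construct_tool_use_system_prompt(tools=tools)

    # When the model replies with <function_calls>...</function_calls> XML,
    # the helpers recover the call (signatures assumed):
    reply = "<function_calls><invoke><tool_name>get_weather</tool_name></invoke></function_calls>"
    invokes = extract_between_tags("invoke", reply)  # assumed: list of tag bodies
    if invokes:
        params = parse_xml_params(invokes[0])  # assumed: XML body -> dict of args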
@@ -81,7 +87,7 @@ class AmazonAnthropicClaude3Config:
     """

     max_tokens: Optional[int] = litellm.max_tokens
-    anthropic_version: Optional[str] = None
+    anthropic_version: Optional[str] = "bedrock-2023-05-31"

     def __init__(
         self,
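
Bedrock's Claude 3 endpoint speaks the Anthropic Messages API and expects an anthropic_version of "bedrock-2023-05-31" in the request body, hence the new default. Rough shape of the invoke body, with illustrative values:

    import json

    # Rough shape of the Bedrock invoke body for a Claude 3 model
    body = json.dumps({
        "anthropic_version": "bedrock-2023-05-31",
        "max_tokens": 256,
        "messages": [{"role": "user", "content": "Hello"}],
    })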
@@ -111,6 +117,15 @@ class AmazonAnthropicClaude3Config:
             and v is not None
         }
+
+    def get_supported_openai_params(self):
+        return ["max_tokens"]
+
+    def map_openai_params(self, non_default_params: dict, optional_params: dict):
+        for param, value in non_default_params.items():
+            if param == "max_tokens":
+                optional_params["max_tokens"] = value
+        return optional_params


 class AmazonAnthropicConfig:
     """
@@ -165,6 +180,25 @@ class AmazonAnthropicConfig:
             and v is not None
         }
+
+    def get_supported_openai_params(
+        self,
+    ):
+        return ["max_tokens", "temperature", "stop", "top_p", "stream"]
+
+    def map_openai_params(self, non_default_params: dict, optional_params: dict):
+        for param, value in non_default_params.items():
+            if param == "max_tokens":
+                optional_params["max_tokens_to_sample"] = value
+            if param == "temperature":
+                optional_params["temperature"] = value
+            if param == "top_p":
+                optional_params["top_p"] = value
+            if param == "stop":
+                optional_params["stop_sequences"] = value
+            if param == "stream" and value == True:
+                optional_params["stream"] = value
+        return optional_params


 class AmazonCohereConfig:
     """
@@ -664,7 +698,20 @@ def completion(
         inference_params = copy.deepcopy(optional_params)
         stream = inference_params.pop("stream", False)
         if provider == "anthropic":
-            if model == "anthropic.claude-3":
+            if model.startswith("anthropic.claude-3"):
+                # Separate system prompt from rest of message
+                system_prompt_idx: Optional[int] = None
+                for idx, message in enumerate(messages):
+                    if message["role"] == "system":
+                        inference_params["system"] = message["content"]
+                        system_prompt_idx = idx
+                        break
+                if system_prompt_idx is not None:
+                    messages.pop(system_prompt_idx)
+                # Format rest of message according to anthropic guidelines
+                messages = prompt_factory(
+                    model=model, messages=messages, custom_llm_provider="anthropic"
+                )
                 ## LOAD CONFIG
                 config = litellm.AmazonAnthropicClaude3Config.get_config()
                 for k, v in config.items():
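
The Messages API takes the system prompt as a top-level system field rather than as a message, which is why the first system message is hoisted out before formatting. A sketch of the transformation, with an illustrative conversation:

    messages = [
        {"role": "system", "content": "You are a terse assistant."},
        {"role": "user", "content": "Hi"},
    ]
    # After the loop above:
    #   inference_params["system"] == "You are a terse assistant."
    #   messages == [{"role": "user", "content": "Hi"}]
    # prompt_factory(..., custom_llm_provider="anthropic") then reshapes the
    # remainder into Anthropic messages format (exact shape is litellm-internal).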
@@ -672,7 +719,17 @@ def completion(
                         k not in inference_params
                     ): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
                         inference_params[k] = v
-                data = json.dumps({"prompt": prompt, **inference_params})
+                ## Handle Tool Calling
+                if "tools" in inference_params:
+                    tool_calling_system_prompt = construct_tool_use_system_prompt(
+                        tools=inference_params["tools"]
+                    )
+                    inference_params["system"] = (
+                        inference_params.get("system", "\n")
+                        + tool_calling_system_prompt
+                    )  # add the anthropic tool calling prompt to the system prompt
+                    inference_params.pop("tools")
+                data = json.dumps({"messages": messages, **inference_params})
             else:
                 ## LOAD CONFIG
                 config = litellm.AmazonAnthropicConfig.get_config()
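
With the Messages API the payload carries messages (plus the hoisted system field) instead of a flattened prompt string, and since the Bedrock Claude 3 body did not accept a native tools field at the time, tool schemas are folded into the system prompt as XML instructions. Rough shape of the resulting body, values illustrative:

    # Rough shape of the request body this branch produces
    {
        "messages": [{"role": "user", "content": [{"type": "text", "text": "Hi"}]}],
        "system": "You are a terse assistant.\n<tool-use instructions...>",
        "max_tokens": 256,
        "anthropic_version": "bedrock-2023-05-31",
    }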
@@ -838,8 +895,12 @@ def completion(
         if provider == "ai21":
             outputText = response_body.get("completions")[0].get("data").get("text")
         elif provider == "anthropic":
-            outputText = response_body["completion"]
-            model_response["finish_reason"] = response_body["stop_reason"]
+            if model.startswith("anthropic.claude-3"):
+                outputText = response_body.get("content")[0].get("text", None)
+                model_response["finish_reason"] = response_body["stop_reason"]
+            else:
+                outputText = response_body["completion"]
+                model_response["finish_reason"] = response_body["stop_reason"]
         elif provider == "cohere":
             outputText = response_body["generations"][0]["text"]
         elif provider == "meta":
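
The two Anthropic response shapes differ: Claude 3 returns a list of content blocks, while legacy models return a flat completion string. Roughly:

    # Claude 3 (Messages API) response body, roughly:
    #   {"content": [{"type": "text", "text": "Hello!"}], "stop_reason": "end_turn", ...}
    # Legacy Claude (Text Completions) response body, roughly:
    #   {"completion": " Hello!", "stop_reason": "stop_sequence", ...}
    response_body = {"content": [{"type": "text", "text": "Hello!"}], "stop_reason": "end_turn"}
    outputText = response_body.get("content")[0].get("text", None)  # "Hello!"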