mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
(Feat) - Add /bedrock/invoke
support for all Anthropic models (#8383)
* use anthropic transformation for bedrock/invoke * use anthropic transforms for bedrock invoke claude * TestBedrockInvokeClaudeJson * add AmazonAnthropicClaudeStreamDecoder * pass bedrock_invoke_provider to make_call * fix _get_base_bedrock_model * fix get_bedrock_route * fix bedrock routing * fixes for bedrock invoke * test_all_model_configs * fix AWSEventStreamDecoder linting * fix code qa * test_bedrock_get_base_model * test_get_model_info_bedrock_models * test_bedrock_base_model_helper * test_bedrock_route_detection
This commit is contained in:
parent
1dd3713f1a
commit
b242c66a3b
15 changed files with 386 additions and 262 deletions
|
@ -34,6 +34,17 @@ class BaseLLMModelInfo(ABC):
|
|||
def get_api_base(api_base: Optional[str] = None) -> Optional[str]:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def get_base_model(model: str) -> Optional[str]:
|
||||
"""
|
||||
Returns the base model name from the given model name.
|
||||
|
||||
Some providers like bedrock - can receive model=`invoke/anthropic.claude-3-opus-20240229-v1:0` or `converse/anthropic.claude-3-opus-20240229-v1:0`
|
||||
This function will return `anthropic.claude-3-opus-20240229-v1:0`
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
def _dict_to_response_format_helper(
|
||||
response_format: dict, ref_template: Optional[str] = None
|
||||
|
|
|
@ -33,14 +33,7 @@ from litellm.types.llms.openai import (
|
|||
from litellm.types.utils import ModelResponse, Usage
|
||||
from litellm.utils import add_dummy_tool, has_tool_call_blocks
|
||||
|
||||
from ..common_utils import (
|
||||
AmazonBedrockGlobalConfig,
|
||||
BedrockError,
|
||||
get_bedrock_tool_name,
|
||||
)
|
||||
|
||||
global_config = AmazonBedrockGlobalConfig()
|
||||
all_global_regions = global_config.get_all_regions()
|
||||
from ..common_utils import BedrockError, BedrockModelInfo, get_bedrock_tool_name
|
||||
|
||||
|
||||
class AmazonConverseConfig(BaseConfig):
|
||||
|
@ -104,7 +97,7 @@ class AmazonConverseConfig(BaseConfig):
|
|||
]
|
||||
|
||||
## Filter out 'cross-region' from model name
|
||||
base_model = self._get_base_model(model)
|
||||
base_model = BedrockModelInfo.get_base_model(model)
|
||||
|
||||
if (
|
||||
base_model.startswith("anthropic")
|
||||
|
@ -341,9 +334,9 @@ class AmazonConverseConfig(BaseConfig):
|
|||
if "top_k" in inference_params:
|
||||
inference_params["topK"] = inference_params.pop("top_k")
|
||||
return InferenceConfig(**inference_params)
|
||||
|
||||
|
||||
def _handle_top_k_value(self, model: str, inference_params: dict) -> dict:
|
||||
base_model = self._get_base_model(model)
|
||||
base_model = BedrockModelInfo.get_base_model(model)
|
||||
|
||||
val_top_k = None
|
||||
if "topK" in inference_params:
|
||||
|
@ -352,11 +345,11 @@ class AmazonConverseConfig(BaseConfig):
|
|||
val_top_k = inference_params.pop("top_k")
|
||||
|
||||
if val_top_k:
|
||||
if (base_model.startswith("anthropic")):
|
||||
if base_model.startswith("anthropic"):
|
||||
return {"top_k": val_top_k}
|
||||
if base_model.startswith("amazon.nova"):
|
||||
return {'inferenceConfig': {"topK": val_top_k}}
|
||||
|
||||
return {"inferenceConfig": {"topK": val_top_k}}
|
||||
|
||||
return {}
|
||||
|
||||
def _transform_request_helper(
|
||||
|
@ -393,15 +386,25 @@ class AmazonConverseConfig(BaseConfig):
|
|||
) + ["top_k"]
|
||||
supported_tool_call_params = ["tools", "tool_choice"]
|
||||
supported_guardrail_params = ["guardrailConfig"]
|
||||
total_supported_params = supported_converse_params + supported_tool_call_params + supported_guardrail_params
|
||||
total_supported_params = (
|
||||
supported_converse_params
|
||||
+ supported_tool_call_params
|
||||
+ supported_guardrail_params
|
||||
)
|
||||
inference_params.pop("json_mode", None) # used for handling json_schema
|
||||
|
||||
# keep supported params in 'inference_params', and set all model-specific params in 'additional_request_params'
|
||||
additional_request_params = {k: v for k, v in inference_params.items() if k not in total_supported_params}
|
||||
inference_params = {k: v for k, v in inference_params.items() if k in total_supported_params}
|
||||
additional_request_params = {
|
||||
k: v for k, v in inference_params.items() if k not in total_supported_params
|
||||
}
|
||||
inference_params = {
|
||||
k: v for k, v in inference_params.items() if k in total_supported_params
|
||||
}
|
||||
|
||||
# Only set the topK value in for models that support it
|
||||
additional_request_params.update(self._handle_top_k_value(model, inference_params))
|
||||
additional_request_params.update(
|
||||
self._handle_top_k_value(model, inference_params)
|
||||
)
|
||||
|
||||
bedrock_tools: List[ToolBlock] = _bedrock_tools_pt(
|
||||
inference_params.pop("tools", [])
|
||||
|
@ -679,41 +682,6 @@ class AmazonConverseConfig(BaseConfig):
|
|||
|
||||
return model_response
|
||||
|
||||
def _supported_cross_region_inference_region(self) -> List[str]:
|
||||
"""
|
||||
Abbreviations of regions AWS Bedrock supports for cross region inference
|
||||
"""
|
||||
return ["us", "eu", "apac"]
|
||||
|
||||
def _get_base_model(self, model: str) -> str:
|
||||
"""
|
||||
Get the base model from the given model name.
|
||||
|
||||
Handle model names like - "us.meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1"
|
||||
AND "meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1"
|
||||
"""
|
||||
|
||||
if model.startswith("bedrock/"):
|
||||
model = model.split("/", 1)[1]
|
||||
|
||||
if model.startswith("converse/"):
|
||||
model = model.split("/", 1)[1]
|
||||
|
||||
potential_region = model.split(".", 1)[0]
|
||||
|
||||
alt_potential_region = model.split("/", 1)[
|
||||
0
|
||||
] # in model cost map we store regional information like `/us-west-2/bedrock-model`
|
||||
|
||||
if potential_region in self._supported_cross_region_inference_region():
|
||||
return model.split(".", 1)[1]
|
||||
elif (
|
||||
alt_potential_region in all_global_regions and len(model.split("/", 1)) > 1
|
||||
):
|
||||
return model.split("/", 1)[1]
|
||||
|
||||
return model
|
||||
|
||||
def get_error_class(
|
||||
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
|
||||
) -> BaseLLMException:
|
||||
|
|
|
@ -40,6 +40,9 @@ from litellm.litellm_core_utils.prompt_templates.factory import (
|
|||
parse_xml_params,
|
||||
prompt_factory,
|
||||
)
|
||||
from litellm.llms.anthropic.chat.handler import (
|
||||
ModelResponseIterator as AnthropicModelResponseIterator,
|
||||
)
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
AsyncHTTPHandler,
|
||||
HTTPHandler,
|
||||
|
@ -177,6 +180,7 @@ async def make_call(
|
|||
logging_obj: Logging,
|
||||
fake_stream: bool = False,
|
||||
json_mode: Optional[bool] = False,
|
||||
bedrock_invoke_provider: Optional[litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL] = None,
|
||||
):
|
||||
try:
|
||||
if client is None:
|
||||
|
@ -214,6 +218,14 @@ async def make_call(
|
|||
completion_stream: Any = MockResponseIterator(
|
||||
model_response=model_response, json_mode=json_mode
|
||||
)
|
||||
elif bedrock_invoke_provider == "anthropic":
|
||||
decoder: AWSEventStreamDecoder = AmazonAnthropicClaudeStreamDecoder(
|
||||
model=model,
|
||||
sync_stream=False,
|
||||
)
|
||||
completion_stream = decoder.aiter_bytes(
|
||||
response.aiter_bytes(chunk_size=1024)
|
||||
)
|
||||
else:
|
||||
decoder = AWSEventStreamDecoder(model=model)
|
||||
completion_stream = decoder.aiter_bytes(
|
||||
|
@ -248,6 +260,7 @@ def make_sync_call(
|
|||
logging_obj: Logging,
|
||||
fake_stream: bool = False,
|
||||
json_mode: Optional[bool] = False,
|
||||
bedrock_invoke_provider: Optional[litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL] = None,
|
||||
):
|
||||
try:
|
||||
if client is None:
|
||||
|
@ -283,6 +296,12 @@ def make_sync_call(
|
|||
completion_stream: Any = MockResponseIterator(
|
||||
model_response=model_response, json_mode=json_mode
|
||||
)
|
||||
elif bedrock_invoke_provider == "anthropic":
|
||||
decoder: AWSEventStreamDecoder = AmazonAnthropicClaudeStreamDecoder(
|
||||
model=model,
|
||||
sync_stream=True,
|
||||
)
|
||||
completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024))
|
||||
else:
|
||||
decoder = AWSEventStreamDecoder(model=model)
|
||||
completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024))
|
||||
|
@ -1323,7 +1342,7 @@ class AWSEventStreamDecoder:
|
|||
text = chunk_data.get("completions")[0].get("data").get("text") # type: ignore
|
||||
is_finished = True
|
||||
finish_reason = "stop"
|
||||
######## bedrock.anthropic mappings ###############
|
||||
######## converse bedrock.anthropic mappings ###############
|
||||
elif (
|
||||
"contentBlockIndex" in chunk_data
|
||||
or "stopReason" in chunk_data
|
||||
|
@ -1429,6 +1448,27 @@ class AWSEventStreamDecoder:
|
|||
return chunk.decode() # type: ignore[no-any-return]
|
||||
|
||||
|
||||
class AmazonAnthropicClaudeStreamDecoder(AWSEventStreamDecoder):
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
sync_stream: bool,
|
||||
) -> None:
|
||||
"""
|
||||
Child class of AWSEventStreamDecoder that handles the streaming response from the Anthropic family of models
|
||||
|
||||
The only difference between AWSEventStreamDecoder and AmazonAnthropicClaudeStreamDecoder is the `chunk_parser` method
|
||||
"""
|
||||
super().__init__(model=model)
|
||||
self.anthropic_model_response_iterator = AnthropicModelResponseIterator(
|
||||
streaming_response=None,
|
||||
sync_stream=sync_stream,
|
||||
)
|
||||
|
||||
def _chunk_parser(self, chunk_data: dict) -> GChunk:
|
||||
return self.anthropic_model_response_iterator.chunk_parser(chunk=chunk_data)
|
||||
|
||||
|
||||
class MockResponseIterator: # for returning ai21 streaming responses
|
||||
def __init__(self, model_response, json_mode: Optional[bool] = False):
|
||||
self.model_response = model_response
|
||||
|
|
|
@ -1,61 +1,34 @@
|
|||
import types
|
||||
from typing import List, Optional
|
||||
from typing import TYPE_CHECKING, Any, List, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
|
||||
AmazonInvokeConfig,
|
||||
)
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.utils import ModelResponse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class AmazonAnthropicClaude3Config:
|
||||
class AmazonAnthropicClaude3Config(AmazonInvokeConfig):
|
||||
"""
|
||||
Reference:
|
||||
https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=claude
|
||||
https://docs.anthropic.com/claude/docs/models-overview#model-comparison
|
||||
|
||||
Supported Params for the Amazon / Anthropic Claude 3 models:
|
||||
|
||||
- `max_tokens` Required (integer) max tokens. Default is 4096
|
||||
- `anthropic_version` Required (string) version of anthropic for bedrock - e.g. "bedrock-2023-05-31"
|
||||
- `system` Optional (string) the system prompt, conversion from openai format to this is handled in factory.py
|
||||
- `temperature` Optional (float) The amount of randomness injected into the response
|
||||
- `top_p` Optional (float) Use nucleus sampling.
|
||||
- `top_k` Optional (int) Only sample from the top K options for each subsequent token
|
||||
- `stop_sequences` Optional (List[str]) Custom text sequences that cause the model to stop generating
|
||||
"""
|
||||
|
||||
max_tokens: Optional[int] = 4096 # Opus, Sonnet, and Haiku default
|
||||
anthropic_version: Optional[str] = "bedrock-2023-05-31"
|
||||
system: Optional[str] = None
|
||||
temperature: Optional[float] = None
|
||||
top_p: Optional[float] = None
|
||||
top_k: Optional[int] = None
|
||||
stop_sequences: Optional[List[str]] = None
|
||||
anthropic_version: str = "bedrock-2023-05-31"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_tokens: Optional[int] = None,
|
||||
anthropic_version: Optional[str] = None,
|
||||
) -> None:
|
||||
locals_ = locals().copy()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
|
||||
|
||||
def get_supported_openai_params(self):
|
||||
def get_supported_openai_params(self, model: str):
|
||||
return [
|
||||
"max_tokens",
|
||||
"max_completion_tokens",
|
||||
|
@ -68,7 +41,13 @@ class AmazonAnthropicClaude3Config:
|
|||
"extra_headers",
|
||||
]
|
||||
|
||||
def map_openai_params(self, non_default_params: dict, optional_params: dict):
|
||||
def map_openai_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
):
|
||||
for param, value in non_default_params.items():
|
||||
if param == "max_tokens" or param == "max_completion_tokens":
|
||||
optional_params["max_tokens"] = value
|
||||
|
@ -83,3 +62,53 @@ class AmazonAnthropicClaude3Config:
|
|||
if param == "top_p":
|
||||
optional_params["top_p"] = value
|
||||
return optional_params
|
||||
|
||||
def transform_request(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
headers: dict,
|
||||
) -> dict:
|
||||
_anthropic_request = litellm.AnthropicConfig().transform_request(
|
||||
model=model,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
_anthropic_request.pop("model", None)
|
||||
if "anthropic_version" not in _anthropic_request:
|
||||
_anthropic_request["anthropic_version"] = self.anthropic_version
|
||||
|
||||
return _anthropic_request
|
||||
|
||||
def transform_response(
|
||||
self,
|
||||
model: str,
|
||||
raw_response: httpx.Response,
|
||||
model_response: ModelResponse,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
request_data: dict,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
encoding: Any,
|
||||
api_key: Optional[str] = None,
|
||||
json_mode: Optional[bool] = None,
|
||||
) -> ModelResponse:
|
||||
return litellm.AnthropicConfig().transform_response(
|
||||
model=model,
|
||||
raw_response=raw_response,
|
||||
model_response=model_response,
|
||||
logging_obj=logging_obj,
|
||||
request_data=request_data,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
encoding=encoding,
|
||||
api_key=api_key,
|
||||
json_mode=json_mode,
|
||||
)
|
||||
|
|
|
@ -2,7 +2,6 @@ import copy
|
|||
import json
|
||||
import time
|
||||
import urllib.parse
|
||||
import uuid
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union, cast, get_args
|
||||
|
||||
|
@ -13,11 +12,7 @@ from litellm.litellm_core_utils.core_helpers import map_finish_reason
|
|||
from litellm.litellm_core_utils.logging_utils import track_llm_api_timing
|
||||
from litellm.litellm_core_utils.prompt_templates.factory import (
|
||||
cohere_message_pt,
|
||||
construct_tool_use_system_prompt,
|
||||
contains_tag,
|
||||
custom_prompt,
|
||||
extract_between_tags,
|
||||
parse_xml_params,
|
||||
prompt_factory,
|
||||
)
|
||||
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
|
||||
|
@ -194,7 +189,6 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
|
|||
for k, v in inference_params.items()
|
||||
if k not in self.aws_authentication_params
|
||||
}
|
||||
json_schemas: dict = {}
|
||||
request_data: dict = {}
|
||||
if provider == "cohere":
|
||||
if model.startswith("cohere.command-r"):
|
||||
|
@ -223,57 +217,13 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
|
|||
)
|
||||
request_data = {"prompt": prompt, **inference_params}
|
||||
elif provider == "anthropic":
|
||||
if model.startswith("anthropic.claude-3"):
|
||||
# Separate system prompt from rest of message
|
||||
system_prompt_idx: list[int] = []
|
||||
system_messages: list[str] = []
|
||||
for idx, message in enumerate(messages):
|
||||
if message["role"] == "system" and isinstance(
|
||||
message["content"], str
|
||||
):
|
||||
system_messages.append(message["content"])
|
||||
system_prompt_idx.append(idx)
|
||||
if len(system_prompt_idx) > 0:
|
||||
inference_params["system"] = "\n".join(system_messages)
|
||||
messages = [
|
||||
i for j, i in enumerate(messages) if j not in system_prompt_idx
|
||||
]
|
||||
# Format rest of message according to anthropic guidelines
|
||||
messages = prompt_factory(
|
||||
model=model, messages=messages, custom_llm_provider="anthropic_xml"
|
||||
) # type: ignore
|
||||
## LOAD CONFIG
|
||||
config = litellm.AmazonAnthropicClaude3Config.get_config()
|
||||
for k, v in config.items():
|
||||
if (
|
||||
k not in inference_params
|
||||
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
inference_params[k] = v
|
||||
## Handle Tool Calling
|
||||
if "tools" in inference_params:
|
||||
_is_function_call = True
|
||||
for tool in inference_params["tools"]:
|
||||
json_schemas[tool["function"]["name"]] = tool["function"].get(
|
||||
"parameters", None
|
||||
)
|
||||
tool_calling_system_prompt = construct_tool_use_system_prompt(
|
||||
tools=inference_params["tools"]
|
||||
)
|
||||
inference_params["system"] = (
|
||||
inference_params.get("system", "\n")
|
||||
+ tool_calling_system_prompt
|
||||
) # add the anthropic tool calling prompt to the system prompt
|
||||
inference_params.pop("tools")
|
||||
request_data = {"messages": messages, **inference_params}
|
||||
else:
|
||||
## LOAD CONFIG
|
||||
config = litellm.AmazonAnthropicConfig.get_config()
|
||||
for k, v in config.items():
|
||||
if (
|
||||
k not in inference_params
|
||||
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
inference_params[k] = v
|
||||
request_data = {"prompt": prompt, **inference_params}
|
||||
return litellm.AmazonAnthropicClaude3Config().transform_request(
|
||||
model=model,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
headers=headers,
|
||||
)
|
||||
elif provider == "ai21":
|
||||
## LOAD CONFIG
|
||||
config = litellm.AmazonAI21Config.get_config()
|
||||
|
@ -359,66 +309,19 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
|
|||
completion_response["generations"][0]["finish_reason"]
|
||||
)
|
||||
elif provider == "anthropic":
|
||||
if model.startswith("anthropic.claude-3"):
|
||||
json_schemas: dict = {}
|
||||
_is_function_call = False
|
||||
## Handle Tool Calling
|
||||
if "tools" in optional_params:
|
||||
_is_function_call = True
|
||||
for tool in optional_params["tools"]:
|
||||
json_schemas[tool["function"]["name"]] = tool[
|
||||
"function"
|
||||
].get("parameters", None)
|
||||
outputText = completion_response.get("content")[0].get("text", None)
|
||||
if outputText is not None and contains_tag(
|
||||
"invoke", outputText
|
||||
): # OUTPUT PARSE FUNCTION CALL
|
||||
function_name = extract_between_tags("tool_name", outputText)[0]
|
||||
function_arguments_str = extract_between_tags(
|
||||
"invoke", outputText
|
||||
)[0].strip()
|
||||
function_arguments_str = (
|
||||
f"<invoke>{function_arguments_str}</invoke>"
|
||||
)
|
||||
function_arguments = parse_xml_params(
|
||||
function_arguments_str,
|
||||
json_schema=json_schemas.get(
|
||||
function_name, None
|
||||
), # check if we have a json schema for this function name)
|
||||
)
|
||||
_message = litellm.Message(
|
||||
tool_calls=[
|
||||
{
|
||||
"id": f"call_{uuid.uuid4()}",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": function_name,
|
||||
"arguments": json.dumps(function_arguments),
|
||||
},
|
||||
}
|
||||
],
|
||||
content=None,
|
||||
)
|
||||
model_response.choices[0].message = _message # type: ignore
|
||||
model_response._hidden_params["original_response"] = (
|
||||
outputText # allow user to access raw anthropic tool calling response
|
||||
)
|
||||
model_response.choices[0].finish_reason = map_finish_reason(
|
||||
completion_response.get("stop_reason", "")
|
||||
)
|
||||
_usage = litellm.Usage(
|
||||
prompt_tokens=completion_response["usage"]["input_tokens"],
|
||||
completion_tokens=completion_response["usage"]["output_tokens"],
|
||||
total_tokens=completion_response["usage"]["input_tokens"]
|
||||
+ completion_response["usage"]["output_tokens"],
|
||||
)
|
||||
setattr(model_response, "usage", _usage)
|
||||
else:
|
||||
outputText = completion_response["completion"]
|
||||
|
||||
model_response.choices[0].finish_reason = completion_response[
|
||||
"stop_reason"
|
||||
]
|
||||
return litellm.AmazonAnthropicClaude3Config().transform_response(
|
||||
model=model,
|
||||
raw_response=raw_response,
|
||||
model_response=model_response,
|
||||
logging_obj=logging_obj,
|
||||
request_data=request_data,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
encoding=encoding,
|
||||
api_key=api_key,
|
||||
json_mode=json_mode,
|
||||
)
|
||||
elif provider == "ai21":
|
||||
outputText = (
|
||||
completion_response.get("completions")[0].get("data").get("text")
|
||||
|
@ -536,6 +439,7 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
|
|||
messages=messages,
|
||||
logging_obj=logging_obj,
|
||||
fake_stream=True if "ai21" in api_base else False,
|
||||
bedrock_invoke_provider=self.get_bedrock_invoke_provider(model),
|
||||
),
|
||||
model=model,
|
||||
custom_llm_provider="bedrock",
|
||||
|
@ -569,6 +473,7 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
|
|||
messages=messages,
|
||||
logging_obj=logging_obj,
|
||||
fake_stream=True if "ai21" in api_base else False,
|
||||
bedrock_invoke_provider=self.get_bedrock_invoke_provider(model),
|
||||
),
|
||||
model=model,
|
||||
custom_llm_provider="bedrock",
|
||||
|
@ -594,10 +499,14 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
|
|||
"""
|
||||
Helper function to get the bedrock provider from the model
|
||||
|
||||
handles 2 scenarions:
|
||||
1. model=anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic`
|
||||
2. model=llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n -> Returns `llama`
|
||||
handles 3 scenarions:
|
||||
1. model=invoke/anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic`
|
||||
2. model=anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic`
|
||||
3. model=llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n -> Returns `llama`
|
||||
"""
|
||||
if model.startswith("invoke/"):
|
||||
model = model.replace("invoke/", "", 1)
|
||||
|
||||
_split_model = model.split(".")[0]
|
||||
if _split_model in get_args(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL):
|
||||
return cast(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL, _split_model)
|
||||
|
@ -640,9 +549,9 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
|
|||
else:
|
||||
modelId = model
|
||||
|
||||
modelId = modelId.replace("invoke/", "", 1)
|
||||
if provider == "llama" and "llama/" in modelId:
|
||||
modelId = self._get_model_id_for_llama_like_model(modelId)
|
||||
|
||||
return modelId
|
||||
|
||||
def _get_aws_region_name(self, optional_params: dict) -> str:
|
||||
|
|
|
@ -3,11 +3,12 @@ Common utilities used across bedrock chat/embedding/image generation
|
|||
"""
|
||||
|
||||
import os
|
||||
from typing import List, Optional, Union
|
||||
from typing import List, Literal, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.llms.base_llm.base_utils import BaseLLMModelInfo
|
||||
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
||||
from litellm.secret_managers.main import get_secret
|
||||
|
||||
|
@ -310,3 +311,68 @@ def get_bedrock_tool_name(response_tool_name: str) -> str:
|
|||
response_tool_name
|
||||
]
|
||||
return response_tool_name
|
||||
|
||||
|
||||
class BedrockModelInfo(BaseLLMModelInfo):
|
||||
|
||||
global_config = AmazonBedrockGlobalConfig()
|
||||
all_global_regions = global_config.get_all_regions()
|
||||
|
||||
@staticmethod
|
||||
def get_base_model(model: str) -> str:
|
||||
"""
|
||||
Get the base model from the given model name.
|
||||
|
||||
Handle model names like - "us.meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1"
|
||||
AND "meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1"
|
||||
"""
|
||||
if model.startswith("bedrock/"):
|
||||
model = model.split("/", 1)[1]
|
||||
|
||||
if model.startswith("converse/"):
|
||||
model = model.split("/", 1)[1]
|
||||
|
||||
if model.startswith("invoke/"):
|
||||
model = model.split("/", 1)[1]
|
||||
|
||||
potential_region = model.split(".", 1)[0]
|
||||
|
||||
alt_potential_region = model.split("/", 1)[
|
||||
0
|
||||
] # in model cost map we store regional information like `/us-west-2/bedrock-model`
|
||||
|
||||
if (
|
||||
potential_region
|
||||
in BedrockModelInfo._supported_cross_region_inference_region()
|
||||
):
|
||||
return model.split(".", 1)[1]
|
||||
elif (
|
||||
alt_potential_region in BedrockModelInfo.all_global_regions
|
||||
and len(model.split("/", 1)) > 1
|
||||
):
|
||||
return model.split("/", 1)[1]
|
||||
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def _supported_cross_region_inference_region() -> List[str]:
|
||||
"""
|
||||
Abbreviations of regions AWS Bedrock supports for cross region inference
|
||||
"""
|
||||
return ["us", "eu", "apac"]
|
||||
|
||||
@staticmethod
|
||||
def get_bedrock_route(model: str) -> Literal["converse", "invoke", "converse_like"]:
|
||||
"""
|
||||
Get the bedrock route for the given model.
|
||||
"""
|
||||
base_model = BedrockModelInfo.get_base_model(model)
|
||||
if "invoke/" in model:
|
||||
return "invoke"
|
||||
elif "converse_like" in model:
|
||||
return "converse_like"
|
||||
elif "converse/" in model:
|
||||
return "converse"
|
||||
elif base_model in litellm.bedrock_converse_models:
|
||||
return "converse"
|
||||
return "invoke"
|
||||
|
|
|
@ -344,6 +344,10 @@ class OpenAIGPTConfig(BaseLLMModelInfo, BaseConfig):
|
|||
or "https://api.openai.com/v1"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_base_model(model: str) -> str:
|
||||
return model
|
||||
|
||||
def get_model_response_iterator(
|
||||
self,
|
||||
streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
|
||||
|
|
|
@ -29,3 +29,7 @@ class TopazModelInfo(BaseLLMModelInfo):
|
|||
return (
|
||||
api_base or get_secret_str("TOPAZ_API_BASE") or "https://api.topazlabs.com"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_base_model(model: str) -> str:
|
||||
return model
|
||||
|
|
|
@ -68,6 +68,7 @@ from litellm.litellm_core_utils.prompt_templates.common_utils import (
|
|||
get_content_from_model_response,
|
||||
)
|
||||
from litellm.llms.base_llm.chat.transformation import BaseConfig
|
||||
from litellm.llms.bedrock.common_utils import BedrockModelInfo
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||
from litellm.realtime_api.main import _realtime_health_check
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
|
@ -2628,11 +2629,8 @@ def completion( # type: ignore # noqa: PLR0915
|
|||
aws_bedrock_client.meta.region_name
|
||||
)
|
||||
|
||||
base_model = litellm.AmazonConverseConfig()._get_base_model(model)
|
||||
|
||||
if base_model in litellm.bedrock_converse_models or model.startswith(
|
||||
"converse/"
|
||||
):
|
||||
bedrock_route = BedrockModelInfo.get_bedrock_route(model)
|
||||
if bedrock_route == "converse":
|
||||
model = model.replace("converse/", "")
|
||||
response = bedrock_converse_chat_completion.completion(
|
||||
model=model,
|
||||
|
@ -2651,7 +2649,7 @@ def completion( # type: ignore # noqa: PLR0915
|
|||
client=client,
|
||||
api_base=api_base,
|
||||
)
|
||||
elif "converse_like" in model:
|
||||
elif bedrock_route == "converse_like":
|
||||
model = model.replace("converse_like/", "")
|
||||
response = base_llm_http_handler.completion(
|
||||
model=model,
|
||||
|
|
|
@ -86,10 +86,10 @@ from litellm.litellm_core_utils.llm_request_utils import _ensure_extra_body_is_s
|
|||
from litellm.litellm_core_utils.llm_response_utils.convert_dict_to_response import (
|
||||
LiteLLMResponseObjectHandler,
|
||||
_handle_invalid_parallel_tool_calls,
|
||||
_parse_content_for_reasoning,
|
||||
convert_to_model_response_object,
|
||||
convert_to_streaming_response,
|
||||
convert_to_streaming_response_async,
|
||||
_parse_content_for_reasoning,
|
||||
)
|
||||
from litellm.litellm_core_utils.llm_response_utils.get_api_base import get_api_base
|
||||
from litellm.litellm_core_utils.llm_response_utils.get_formatted_prompt import (
|
||||
|
@ -111,6 +111,7 @@ from litellm.litellm_core_utils.token_counter import (
|
|||
calculate_img_tokens,
|
||||
get_modified_max_tokens,
|
||||
)
|
||||
from litellm.llms.bedrock.common_utils import BedrockModelInfo
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||
from litellm.router_utils.get_retry_from_policy import (
|
||||
get_num_retries_from_retry_policy,
|
||||
|
@ -3189,8 +3190,8 @@ def get_optional_params( # noqa: PLR0915
|
|||
),
|
||||
)
|
||||
elif custom_llm_provider == "bedrock":
|
||||
base_model = litellm.AmazonConverseConfig()._get_base_model(model)
|
||||
if base_model in litellm.bedrock_converse_models:
|
||||
bedrock_route = BedrockModelInfo.get_bedrock_route(model)
|
||||
if bedrock_route == "converse" or bedrock_route == "converse_like":
|
||||
optional_params = litellm.AmazonConverseConfig().map_openai_params(
|
||||
model=model,
|
||||
non_default_params=non_default_params,
|
||||
|
@ -3203,15 +3204,20 @@ def get_optional_params( # noqa: PLR0915
|
|||
messages=messages,
|
||||
)
|
||||
|
||||
elif "anthropic" in model:
|
||||
if "aws_bedrock_client" in passed_params: # deprecated boto3.invoke route.
|
||||
if model.startswith("anthropic.claude-3"):
|
||||
optional_params = (
|
||||
litellm.AmazonAnthropicClaude3Config().map_openai_params(
|
||||
non_default_params=non_default_params,
|
||||
optional_params=optional_params,
|
||||
)
|
||||
elif "anthropic" in model and bedrock_route == "invoke":
|
||||
if model.startswith("anthropic.claude-3"):
|
||||
optional_params = (
|
||||
litellm.AmazonAnthropicClaude3Config().map_openai_params(
|
||||
non_default_params=non_default_params,
|
||||
optional_params=optional_params,
|
||||
model=model,
|
||||
drop_params=(
|
||||
drop_params
|
||||
if drop_params is not None and isinstance(drop_params, bool)
|
||||
else False
|
||||
),
|
||||
)
|
||||
)
|
||||
else:
|
||||
optional_params = litellm.AmazonAnthropicConfig().map_openai_params(
|
||||
non_default_params=non_default_params,
|
||||
|
@ -3972,8 +3978,16 @@ def _strip_stable_vertex_version(model_name) -> str:
|
|||
return re.sub(r"-\d+$", "", model_name)
|
||||
|
||||
|
||||
def _strip_bedrock_region(model_name) -> str:
|
||||
return litellm.AmazonConverseConfig()._get_base_model(model_name)
|
||||
def _get_base_bedrock_model(model_name) -> str:
|
||||
"""
|
||||
Get the base model from the given model name.
|
||||
|
||||
Handle model names like - "us.meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1"
|
||||
AND "meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1"
|
||||
"""
|
||||
from litellm.llms.bedrock.common_utils import BedrockModelInfo
|
||||
|
||||
return BedrockModelInfo.get_base_model(model_name)
|
||||
|
||||
|
||||
def _strip_openai_finetune_model_name(model_name: str) -> str:
|
||||
|
@ -3994,8 +4008,8 @@ def _strip_openai_finetune_model_name(model_name: str) -> str:
|
|||
|
||||
def _strip_model_name(model: str, custom_llm_provider: Optional[str]) -> str:
|
||||
if custom_llm_provider and custom_llm_provider == "bedrock":
|
||||
strip_bedrock_region = _strip_bedrock_region(model_name=model)
|
||||
return strip_bedrock_region
|
||||
stripped_bedrock_model = _get_base_bedrock_model(model_name=model)
|
||||
return stripped_bedrock_model
|
||||
elif custom_llm_provider and (
|
||||
custom_llm_provider == "vertex_ai" or custom_llm_provider == "gemini"
|
||||
):
|
||||
|
@ -6066,24 +6080,23 @@ class ProviderConfigManager:
|
|||
elif litellm.LlmProviders.PETALS == provider:
|
||||
return litellm.PetalsConfig()
|
||||
elif litellm.LlmProviders.BEDROCK == provider:
|
||||
base_model = litellm.AmazonConverseConfig()._get_base_model(model)
|
||||
bedrock_provider = litellm.BedrockLLM.get_bedrock_invoke_provider(model)
|
||||
if (
|
||||
base_model in litellm.bedrock_converse_models
|
||||
or "converse_like" in model
|
||||
):
|
||||
bedrock_route = BedrockModelInfo.get_bedrock_route(model)
|
||||
bedrock_invoke_provider = litellm.BedrockLLM.get_bedrock_invoke_provider(
|
||||
model
|
||||
)
|
||||
if bedrock_route == "converse" or bedrock_route == "converse_like":
|
||||
return litellm.AmazonConverseConfig()
|
||||
elif bedrock_provider == "amazon": # amazon titan llms
|
||||
elif bedrock_invoke_provider == "amazon": # amazon titan llms
|
||||
return litellm.AmazonTitanConfig()
|
||||
elif (
|
||||
bedrock_provider == "meta" or bedrock_provider == "llama"
|
||||
bedrock_invoke_provider == "meta" or bedrock_invoke_provider == "llama"
|
||||
): # amazon / meta llms
|
||||
return litellm.AmazonLlamaConfig()
|
||||
elif bedrock_provider == "ai21": # ai21 llms
|
||||
elif bedrock_invoke_provider == "ai21": # ai21 llms
|
||||
return litellm.AmazonAI21Config()
|
||||
elif bedrock_provider == "cohere": # cohere models on bedrock
|
||||
elif bedrock_invoke_provider == "cohere": # cohere models on bedrock
|
||||
return litellm.AmazonCohereConfig()
|
||||
elif bedrock_provider == "mistral": # mistral models on bedrock
|
||||
elif bedrock_invoke_provider == "mistral": # mistral models on bedrock
|
||||
return litellm.AmazonMistralConfig()
|
||||
else:
|
||||
return litellm.AmazonInvokeConfig()
|
||||
|
|
|
@ -1265,7 +1265,9 @@ def test_bedrock_cross_region_inference(model):
|
|||
],
|
||||
)
|
||||
def test_bedrock_get_base_model(model, expected_base_model):
|
||||
assert litellm.AmazonConverseConfig()._get_base_model(model) == expected_base_model
|
||||
from litellm.llms.bedrock.common_utils import BedrockModelInfo
|
||||
|
||||
assert BedrockModelInfo.get_base_model(model) == expected_base_model
|
||||
|
||||
|
||||
from litellm.litellm_core_utils.prompt_templates.factory import (
|
||||
|
@ -1982,9 +1984,49 @@ def test_bedrock_mapped_converse_models():
|
|||
|
||||
|
||||
def test_bedrock_base_model_helper():
|
||||
from litellm.llms.bedrock.common_utils import BedrockModelInfo
|
||||
|
||||
model = "us.amazon.nova-pro-v1:0"
|
||||
litellm.AmazonConverseConfig()._get_base_model(model)
|
||||
assert model == "us.amazon.nova-pro-v1:0"
|
||||
base_model = BedrockModelInfo.get_base_model(model)
|
||||
assert base_model == "amazon.nova-pro-v1:0"
|
||||
|
||||
assert (
|
||||
BedrockModelInfo.get_base_model(
|
||||
"invoke/anthropic.claude-3-5-sonnet-20241022-v2:0"
|
||||
)
|
||||
== "anthropic.claude-3-5-sonnet-20241022-v2:0"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model,expected_route",
|
||||
[
|
||||
# Test explicit route prefixes
|
||||
("invoke/anthropic.claude-3-sonnet-20240229-v1:0", "invoke"),
|
||||
("converse/anthropic.claude-3-sonnet-20240229-v1:0", "converse"),
|
||||
("converse_like/anthropic.claude-3-sonnet-20240229-v1:0", "converse_like"),
|
||||
# Test models in BEDROCK_CONVERSE_MODELS list
|
||||
("anthropic.claude-3-5-haiku-20241022-v1:0", "converse"),
|
||||
("anthropic.claude-v2", "converse"),
|
||||
("meta.llama3-70b-instruct-v1:0", "converse"),
|
||||
("mistral.mistral-large-2407-v1:0", "converse"),
|
||||
# Test models with region prefixes
|
||||
("us.anthropic.claude-3-sonnet-20240229-v1:0", "converse"),
|
||||
("us.meta.llama3-70b-instruct-v1:0", "converse"),
|
||||
# Test default case (should return "invoke")
|
||||
("amazon.titan-text-express-v1", "invoke"),
|
||||
("cohere.command-text-v14", "invoke"),
|
||||
("cohere.command-r-v1:0", "invoke"),
|
||||
],
|
||||
)
|
||||
def test_bedrock_route_detection(model, expected_route):
|
||||
"""Test all scenarios for BedrockModelInfo.get_bedrock_route"""
|
||||
from litellm.llms.bedrock.common_utils import BedrockModelInfo
|
||||
|
||||
route = BedrockModelInfo.get_bedrock_route(model)
|
||||
assert (
|
||||
route == expected_route
|
||||
), f"Expected route '{expected_route}' for model '{model}', but got '{route}'"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
|
28
tests/llm_translation/test_bedrock_invoke_claude_json.py
Normal file
28
tests/llm_translation/test_bedrock_invoke_claude_json.py
Normal file
|
@ -0,0 +1,28 @@
|
|||
from base_llm_unit_tests import BaseLLMChatTest
|
||||
import pytest
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
import litellm
|
||||
|
||||
|
||||
class TestBedrockInvokeClaudeJson(BaseLLMChatTest):
|
||||
def get_base_completion_call_args(self) -> dict:
|
||||
litellm._turn_on_debug()
|
||||
return {
|
||||
"model": "bedrock/invoke/anthropic.claude-3-5-sonnet-20240620-v1:0",
|
||||
}
|
||||
|
||||
def test_tool_call_no_arguments(self, tool_call_no_arguments):
|
||||
"""Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
|
||||
pass
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def skip_non_json_tests(self, request):
|
||||
if not "json" in request.function.__name__.lower():
|
||||
pytest.skip(
|
||||
f"Skipping non-JSON test: {request.function.__name__} does not contain 'json'"
|
||||
)
|
|
@ -305,12 +305,16 @@ def test_all_model_configs():
|
|||
|
||||
assert (
|
||||
"max_completion_tokens"
|
||||
in AmazonAnthropicClaude3Config().get_supported_openai_params()
|
||||
in AmazonAnthropicClaude3Config().get_supported_openai_params(
|
||||
model="anthropic.claude-3-sonnet-20240229-v1:0"
|
||||
)
|
||||
)
|
||||
|
||||
assert AmazonAnthropicClaude3Config().map_openai_params(
|
||||
non_default_params={"max_completion_tokens": 10},
|
||||
optional_params={},
|
||||
model="anthropic.claude-3-sonnet-20240229-v1:0",
|
||||
drop_params=False,
|
||||
) == {"max_tokens": 10}
|
||||
|
||||
assert (
|
||||
|
|
|
@ -208,3 +208,11 @@ def test_nova_bedrock_converse():
|
|||
)
|
||||
assert custom_llm_provider == "bedrock"
|
||||
assert model == "amazon.nova-micro-v1:0"
|
||||
|
||||
|
||||
def test_bedrock_invoke_anthropic():
|
||||
model, custom_llm_provider, dynamic_api_key, api_base = litellm.get_llm_provider(
|
||||
model="bedrock/invoke/anthropic.claude-3-5-sonnet-20240620-v1:0",
|
||||
)
|
||||
assert custom_llm_provider == "bedrock"
|
||||
assert model == "invoke/anthropic.claude-3-5-sonnet-20240620-v1:0"
|
||||
|
|
|
@ -321,7 +321,7 @@ def test_get_model_info_bedrock_models():
|
|||
"""
|
||||
Check for drift in base model info for bedrock models and regional model info for bedrock models.
|
||||
"""
|
||||
from litellm import AmazonConverseConfig
|
||||
from litellm.llms.bedrock.common_utils import BedrockModelInfo
|
||||
|
||||
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
|
||||
litellm.model_cost = litellm.get_model_cost_map(url="")
|
||||
|
@ -337,7 +337,7 @@ def test_get_model_info_bedrock_models():
|
|||
if any(commitment in k for commitment in potential_commitments):
|
||||
for commitment in potential_commitments:
|
||||
k = k.replace(f"{commitment}/", "")
|
||||
base_model = AmazonConverseConfig()._get_base_model(k)
|
||||
base_model = BedrockModelInfo.get_base_model(k)
|
||||
base_model_info = litellm.model_cost[base_model]
|
||||
for base_model_key, base_model_value in base_model_info.items():
|
||||
if base_model_key.startswith("supports_"):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue