Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 18:54:30 +00:00)

(Feat) - Add /bedrock/invoke support for all Anthropic models (#8383)

* use anthropic transformation for bedrock/invoke
* use anthropic transforms for bedrock invoke claude
* TestBedrockInvokeClaudeJson
* add AmazonAnthropicClaudeStreamDecoder
* pass bedrock_invoke_provider to make_call
* fix _get_base_bedrock_model
* fix get_bedrock_route
* fix bedrock routing
* fixes for bedrock invoke
* test_all_model_configs
* fix AWSEventStreamDecoder linting
* fix code qa
* test_bedrock_get_base_model
* test_get_model_info_bedrock_models
* test_bedrock_base_model_helper
* test_bedrock_route_detection

Parent: 1dd3713f1a
Commit: b242c66a3b
15 changed files with 386 additions and 262 deletions

The changes are shown below as unified diffs, grouped by file.
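For context, a rough sketch of what the change enables from the public API, using the model name from the new TestBedrockInvokeClaudeJson test. The surrounding setup (AWS credentials in the environment) and the prompts are assumptions, not part of the commit:

import litellm

# The "invoke/" prefix forces the Bedrock /invoke route and the new
# Anthropic request/response transformations added in this commit.
response = litellm.completion(
    model="bedrock/invoke/anthropic.claude-3-5-sonnet-20240620-v1:0",
    messages=[{"role": "user", "content": "Say hello from Bedrock invoke."}],
)
print(response.choices[0].message.content)

# Streaming on this route is decoded by the new AmazonAnthropicClaudeStreamDecoder.
for chunk in litellm.completion(
    model="bedrock/invoke/anthropic.claude-3-5-sonnet-20240620-v1:0",
    messages=[{"role": "user", "content": "Write one short sentence."}],
    stream=True,
):
    print(chunk.choices[0].delta.content or "", end="")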
BaseLLMModelInfo (litellm/llms/base_llm/base_utils.py) — a new abstract get_base_model() hook:

@@ -34,6 +34,17 @@ class BaseLLMModelInfo(ABC):
     def get_api_base(api_base: Optional[str] = None) -> Optional[str]:
         pass
 
+    @staticmethod
+    @abstractmethod
+    def get_base_model(model: str) -> Optional[str]:
+        """
+        Returns the base model name from the given model name.
+
+        Some providers like bedrock - can receive model=`invoke/anthropic.claude-3-opus-20240229-v1:0` or `converse/anthropic.claude-3-opus-20240229-v1:0`
+        This function will return `anthropic.claude-3-opus-20240229-v1:0`
+        """
+        pass
+
 
 def _dict_to_response_format_helper(
     response_format: dict, ref_template: Optional[str] = None
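Every concrete provider config now has to supply this static method. As a minimal illustration (not code from this commit), a provider without route or region prefixes can simply return the model unchanged, which is what the OpenAI and Topaz configs do further down in this diff; the class name below is hypothetical and the other abstract members of BaseLLMModelInfo are omitted:

from typing import Optional

from litellm.llms.base_llm.base_utils import BaseLLMModelInfo


class ExampleModelInfo(BaseLLMModelInfo):  # hypothetical subclass, for illustration only
    @staticmethod
    def get_base_model(model: str) -> Optional[str]:
        # Nothing to strip for this provider, so the model name is already the base model.
        return model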
Bedrock Converse transformation (AmazonConverseConfig) — base-model parsing moves to BedrockModelInfo:

@@ -33,14 +33,7 @@ from litellm.types.llms.openai import (
 from litellm.types.utils import ModelResponse, Usage
 from litellm.utils import add_dummy_tool, has_tool_call_blocks
 
-from ..common_utils import (
-    AmazonBedrockGlobalConfig,
-    BedrockError,
-    get_bedrock_tool_name,
-)
-
-global_config = AmazonBedrockGlobalConfig()
-all_global_regions = global_config.get_all_regions()
+from ..common_utils import BedrockError, BedrockModelInfo, get_bedrock_tool_name
 
 
 class AmazonConverseConfig(BaseConfig):

@@ -104,7 +97,7 @@ class AmazonConverseConfig(BaseConfig):
         ]
 
         ## Filter out 'cross-region' from model name
-        base_model = self._get_base_model(model)
+        base_model = BedrockModelInfo.get_base_model(model)
 
         if (
             base_model.startswith("anthropic")

@@ -341,9 +334,9 @@ class AmazonConverseConfig(BaseConfig):
         if "top_k" in inference_params:
             inference_params["topK"] = inference_params.pop("top_k")
         return InferenceConfig(**inference_params)
 
     def _handle_top_k_value(self, model: str, inference_params: dict) -> dict:
-        base_model = self._get_base_model(model)
+        base_model = BedrockModelInfo.get_base_model(model)
 
         val_top_k = None
         if "topK" in inference_params:

@@ -352,11 +345,11 @@ class AmazonConverseConfig(BaseConfig):
             val_top_k = inference_params.pop("top_k")
 
         if val_top_k:
-            if (base_model.startswith("anthropic")):
+            if base_model.startswith("anthropic"):
                 return {"top_k": val_top_k}
             if base_model.startswith("amazon.nova"):
-                return {'inferenceConfig': {"topK": val_top_k}}
+                return {"inferenceConfig": {"topK": val_top_k}}
 
         return {}
 
     def _transform_request_helper(

@@ -393,15 +386,25 @@ class AmazonConverseConfig(BaseConfig):
         ) + ["top_k"]
         supported_tool_call_params = ["tools", "tool_choice"]
         supported_guardrail_params = ["guardrailConfig"]
-        total_supported_params = supported_converse_params + supported_tool_call_params + supported_guardrail_params
+        total_supported_params = (
+            supported_converse_params
+            + supported_tool_call_params
+            + supported_guardrail_params
+        )
         inference_params.pop("json_mode", None)  # used for handling json_schema
 
         # keep supported params in 'inference_params', and set all model-specific params in 'additional_request_params'
-        additional_request_params = {k: v for k, v in inference_params.items() if k not in total_supported_params}
-        inference_params = {k: v for k, v in inference_params.items() if k in total_supported_params}
+        additional_request_params = {
+            k: v for k, v in inference_params.items() if k not in total_supported_params
+        }
+        inference_params = {
+            k: v for k, v in inference_params.items() if k in total_supported_params
+        }
 
         # Only set the topK value in for models that support it
-        additional_request_params.update(self._handle_top_k_value(model, inference_params))
+        additional_request_params.update(
+            self._handle_top_k_value(model, inference_params)
+        )
 
         bedrock_tools: List[ToolBlock] = _bedrock_tools_pt(
             inference_params.pop("tools", [])

@@ -679,41 +682,6 @@ class AmazonConverseConfig(BaseConfig):
 
         return model_response
 
-    def _supported_cross_region_inference_region(self) -> List[str]:
-        """
-        Abbreviations of regions AWS Bedrock supports for cross region inference
-        """
-        return ["us", "eu", "apac"]
-
-    def _get_base_model(self, model: str) -> str:
-        """
-        Get the base model from the given model name.
-
-        Handle model names like - "us.meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1"
-        AND "meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1"
-        """
-
-        if model.startswith("bedrock/"):
-            model = model.split("/", 1)[1]
-
-        if model.startswith("converse/"):
-            model = model.split("/", 1)[1]
-
-        potential_region = model.split(".", 1)[0]
-
-        alt_potential_region = model.split("/", 1)[
-            0
-        ]  # in model cost map we store regional information like `/us-west-2/bedrock-model`
-
-        if potential_region in self._supported_cross_region_inference_region():
-            return model.split(".", 1)[1]
-        elif (
-            alt_potential_region in all_global_regions and len(model.split("/", 1)) > 1
-        ):
-            return model.split("/", 1)[1]
-
-        return model
-
     def get_error_class(
         self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
     ) -> BaseLLMException:
Bedrock invoke streaming handler (make_call / make_sync_call and the AWS event-stream decoders):

@@ -40,6 +40,9 @@ from litellm.litellm_core_utils.prompt_templates.factory import (
     parse_xml_params,
     prompt_factory,
 )
+from litellm.llms.anthropic.chat.handler import (
+    ModelResponseIterator as AnthropicModelResponseIterator,
+)
 from litellm.llms.custom_httpx.http_handler import (
     AsyncHTTPHandler,
     HTTPHandler,

@@ -177,6 +180,7 @@ async def make_call(
     logging_obj: Logging,
     fake_stream: bool = False,
     json_mode: Optional[bool] = False,
+    bedrock_invoke_provider: Optional[litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL] = None,
 ):
     try:
         if client is None:

@@ -214,6 +218,14 @@ async def make_call(
             completion_stream: Any = MockResponseIterator(
                 model_response=model_response, json_mode=json_mode
             )
+        elif bedrock_invoke_provider == "anthropic":
+            decoder: AWSEventStreamDecoder = AmazonAnthropicClaudeStreamDecoder(
+                model=model,
+                sync_stream=False,
+            )
+            completion_stream = decoder.aiter_bytes(
+                response.aiter_bytes(chunk_size=1024)
+            )
         else:
             decoder = AWSEventStreamDecoder(model=model)
             completion_stream = decoder.aiter_bytes(

@@ -248,6 +260,7 @@ def make_sync_call(
     logging_obj: Logging,
     fake_stream: bool = False,
     json_mode: Optional[bool] = False,
+    bedrock_invoke_provider: Optional[litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL] = None,
 ):
     try:
         if client is None:

@@ -283,6 +296,12 @@ def make_sync_call(
             completion_stream: Any = MockResponseIterator(
                 model_response=model_response, json_mode=json_mode
             )
+        elif bedrock_invoke_provider == "anthropic":
+            decoder: AWSEventStreamDecoder = AmazonAnthropicClaudeStreamDecoder(
+                model=model,
+                sync_stream=True,
+            )
+            completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024))
         else:
             decoder = AWSEventStreamDecoder(model=model)
             completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024))

@@ -1323,7 +1342,7 @@ class AWSEventStreamDecoder:
             text = chunk_data.get("completions")[0].get("data").get("text")  # type: ignore
             is_finished = True
             finish_reason = "stop"
-        ######## bedrock.anthropic mappings ###############
+        ######## converse bedrock.anthropic mappings ###############
         elif (
             "contentBlockIndex" in chunk_data
             or "stopReason" in chunk_data

@@ -1429,6 +1448,27 @@ class AWSEventStreamDecoder:
             return chunk.decode()  # type: ignore[no-any-return]
 
 
+class AmazonAnthropicClaudeStreamDecoder(AWSEventStreamDecoder):
+    def __init__(
+        self,
+        model: str,
+        sync_stream: bool,
+    ) -> None:
+        """
+        Child class of AWSEventStreamDecoder that handles the streaming response from the Anthropic family of models
+
+        The only difference between AWSEventStreamDecoder and AmazonAnthropicClaudeStreamDecoder is the `chunk_parser` method
+        """
+        super().__init__(model=model)
+        self.anthropic_model_response_iterator = AnthropicModelResponseIterator(
+            streaming_response=None,
+            sync_stream=sync_stream,
+        )
+
+    def _chunk_parser(self, chunk_data: dict) -> GChunk:
+        return self.anthropic_model_response_iterator.chunk_parser(chunk=chunk_data)
+
+
 class MockResponseIterator:  # for returning ai21 streaming responses
     def __init__(self, model_response, json_mode: Optional[bool] = False):
         self.model_response = model_response
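In short, the streaming path now chooses a decoder per invoke provider: Anthropic invoke streams are parsed by the Anthropic chunk parser, everything else keeps the generic AWS event-stream handling. A condensed, illustrative sketch of that selection (the module path is assumed from litellm's usual layout, and the real make_call/make_sync_call also handle fake and mocked streams):

from typing import Optional

from litellm.llms.bedrock.chat.invoke_handler import (
    AmazonAnthropicClaudeStreamDecoder,
    AWSEventStreamDecoder,
)


def pick_decoder(
    model: str,
    sync_stream: bool,
    bedrock_invoke_provider: Optional[str] = None,
) -> AWSEventStreamDecoder:
    # Hypothetical helper mirroring the branch added above, for illustration only.
    if bedrock_invoke_provider == "anthropic":
        return AmazonAnthropicClaudeStreamDecoder(model=model, sync_stream=sync_stream)
    return AWSEventStreamDecoder(model=model)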
Bedrock invoke Anthropic Claude 3 transformation (AmazonAnthropicClaude3Config) — rewritten to subclass AmazonInvokeConfig and delegate to the Anthropic transformations:

@@ -1,61 +1,34 @@
-import types
-from typing import List, Optional
+from typing import TYPE_CHECKING, Any, List, Optional
+
+import httpx
+
+import litellm
+from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
+    AmazonInvokeConfig,
+)
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.utils import ModelResponse
+
+if TYPE_CHECKING:
+    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
+
+    LiteLLMLoggingObj = _LiteLLMLoggingObj
+else:
+    LiteLLMLoggingObj = Any
 
 
-class AmazonAnthropicClaude3Config:
+class AmazonAnthropicClaude3Config(AmazonInvokeConfig):
     """
     Reference:
         https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=claude
         https://docs.anthropic.com/claude/docs/models-overview#model-comparison
 
     Supported Params for the Amazon / Anthropic Claude 3 models:
 
-    - `anthropic_version` Required (string) version of anthropic for bedrock - e.g. "bedrock-2023-05-31"
-    - `system` Optional (string) the system prompt, conversion from openai format to this is handled in factory.py
-    - `temperature` Optional (float) The amount of randomness injected into the response
-    - `top_p` Optional (float) Use nucleus sampling.
-    - `top_k` Optional (int) Only sample from the top K options for each subsequent token
-    - `stop_sequences` Optional (List[str]) Custom text sequences that cause the model to stop generating
+    - `max_tokens` Required (integer) max tokens. Default is 4096
     """
 
-    max_tokens: Optional[int] = 4096  # Opus, Sonnet, and Haiku default
-    anthropic_version: Optional[str] = "bedrock-2023-05-31"
-    system: Optional[str] = None
-    temperature: Optional[float] = None
-    top_p: Optional[float] = None
-    top_k: Optional[int] = None
-    stop_sequences: Optional[List[str]] = None
-
-    def __init__(
-        self,
-        max_tokens: Optional[int] = None,
-        anthropic_version: Optional[str] = None,
-    ) -> None:
-        locals_ = locals().copy()
-        for key, value in locals_.items():
-            if key != "self" and value is not None:
-                setattr(self.__class__, key, value)
-
-    @classmethod
-    def get_config(cls):
-        return {
-            k: v
-            for k, v in cls.__dict__.items()
-            if not k.startswith("__")
-            and not isinstance(
-                v,
-                (
-                    types.FunctionType,
-                    types.BuiltinFunctionType,
-                    classmethod,
-                    staticmethod,
-                ),
-            )
-            and v is not None
-        }
-
-    def get_supported_openai_params(self):
+    anthropic_version: str = "bedrock-2023-05-31"
+
+    def get_supported_openai_params(self, model: str):
         return [
             "max_tokens",
             "max_completion_tokens",

@@ -68,7 +41,13 @@ class AmazonAnthropicClaude3Config:
             "extra_headers",
         ]
 
-    def map_openai_params(self, non_default_params: dict, optional_params: dict):
+    def map_openai_params(
+        self,
+        non_default_params: dict,
+        optional_params: dict,
+        model: str,
+        drop_params: bool,
+    ):
         for param, value in non_default_params.items():
             if param == "max_tokens" or param == "max_completion_tokens":
                 optional_params["max_tokens"] = value

@@ -83,3 +62,53 @@ class AmazonAnthropicClaude3Config:
             if param == "top_p":
                 optional_params["top_p"] = value
         return optional_params
+
+    def transform_request(
+        self,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        headers: dict,
+    ) -> dict:
+        _anthropic_request = litellm.AnthropicConfig().transform_request(
+            model=model,
+            messages=messages,
+            optional_params=optional_params,
+            litellm_params=litellm_params,
+            headers=headers,
+        )
+
+        _anthropic_request.pop("model", None)
+        if "anthropic_version" not in _anthropic_request:
+            _anthropic_request["anthropic_version"] = self.anthropic_version
+
+        return _anthropic_request
+
+    def transform_response(
+        self,
+        model: str,
+        raw_response: httpx.Response,
+        model_response: ModelResponse,
+        logging_obj: LiteLLMLoggingObj,
+        request_data: dict,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        encoding: Any,
+        api_key: Optional[str] = None,
+        json_mode: Optional[bool] = None,
+    ) -> ModelResponse:
+        return litellm.AnthropicConfig().transform_response(
+            model=model,
+            raw_response=raw_response,
+            model_response=model_response,
+            logging_obj=logging_obj,
+            request_data=request_data,
+            messages=messages,
+            optional_params=optional_params,
+            litellm_params=litellm_params,
+            encoding=encoding,
+            api_key=api_key,
+            json_mode=json_mode,
+        )
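The practical effect of the rewritten config: a Bedrock invoke request body for Claude is now just an Anthropic messages-API payload with the model field removed and anthropic_version pinned to "bedrock-2023-05-31". A hedged sketch of calling it directly (argument values are illustrative, and the commented result only reflects what the code above implies):

import litellm

config = litellm.AmazonAnthropicClaude3Config()
request_body = config.transform_request(
    model="anthropic.claude-3-5-sonnet-20240620-v1:0",
    messages=[{"role": "user", "content": "Hello"}],
    optional_params={"max_tokens": 256},
    litellm_params={},
    headers={},
)
# Per the transformation above, request_body should be an Anthropic-style payload
# without a "model" key and with "anthropic_version": "bedrock-2023-05-31" added,
# e.g. {"messages": [...], "max_tokens": 256, "anthropic_version": "bedrock-2023-05-31"}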
Bedrock invoke transformation (litellm/llms/bedrock/chat/invoke_transformations/base_invoke_transformation.py, AmazonInvokeConfig) — the hand-rolled Claude handling is replaced by delegation to AmazonAnthropicClaude3Config:

@@ -2,7 +2,6 @@ import copy
 import json
 import time
 import urllib.parse
-import uuid
 from functools import partial
 from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union, cast, get_args
 

@@ -13,11 +12,7 @@ from litellm.litellm_core_utils.core_helpers import map_finish_reason
 from litellm.litellm_core_utils.logging_utils import track_llm_api_timing
 from litellm.litellm_core_utils.prompt_templates.factory import (
     cohere_message_pt,
-    construct_tool_use_system_prompt,
-    contains_tag,
     custom_prompt,
-    extract_between_tags,
-    parse_xml_params,
     prompt_factory,
 )
 from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException

@@ -194,7 +189,6 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
             for k, v in inference_params.items()
             if k not in self.aws_authentication_params
         }
-        json_schemas: dict = {}
         request_data: dict = {}
         if provider == "cohere":
             if model.startswith("cohere.command-r"):

@@ -223,57 +217,13 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
             )
             request_data = {"prompt": prompt, **inference_params}
         elif provider == "anthropic":
-            if model.startswith("anthropic.claude-3"):
-                # Separate system prompt from rest of message
-                system_prompt_idx: list[int] = []
-                system_messages: list[str] = []
-                for idx, message in enumerate(messages):
-                    if message["role"] == "system" and isinstance(
-                        message["content"], str
-                    ):
-                        system_messages.append(message["content"])
-                        system_prompt_idx.append(idx)
-                if len(system_prompt_idx) > 0:
-                    inference_params["system"] = "\n".join(system_messages)
-                    messages = [
-                        i for j, i in enumerate(messages) if j not in system_prompt_idx
-                    ]
-                # Format rest of message according to anthropic guidelines
-                messages = prompt_factory(
-                    model=model, messages=messages, custom_llm_provider="anthropic_xml"
-                )  # type: ignore
-                ## LOAD CONFIG
-                config = litellm.AmazonAnthropicClaude3Config.get_config()
-                for k, v in config.items():
-                    if (
-                        k not in inference_params
-                    ):  # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
-                        inference_params[k] = v
-                ## Handle Tool Calling
-                if "tools" in inference_params:
-                    _is_function_call = True
-                    for tool in inference_params["tools"]:
-                        json_schemas[tool["function"]["name"]] = tool["function"].get(
-                            "parameters", None
-                        )
-                    tool_calling_system_prompt = construct_tool_use_system_prompt(
-                        tools=inference_params["tools"]
-                    )
-                    inference_params["system"] = (
-                        inference_params.get("system", "\n")
-                        + tool_calling_system_prompt
-                    )  # add the anthropic tool calling prompt to the system prompt
-                    inference_params.pop("tools")
-                request_data = {"messages": messages, **inference_params}
-            else:
-                ## LOAD CONFIG
-                config = litellm.AmazonAnthropicConfig.get_config()
-                for k, v in config.items():
-                    if (
-                        k not in inference_params
-                    ):  # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
-                        inference_params[k] = v
-                request_data = {"prompt": prompt, **inference_params}
+            return litellm.AmazonAnthropicClaude3Config().transform_request(
+                model=model,
+                messages=messages,
+                optional_params=optional_params,
+                litellm_params=litellm_params,
+                headers=headers,
+            )
         elif provider == "ai21":
             ## LOAD CONFIG
             config = litellm.AmazonAI21Config.get_config()

@@ -359,66 +309,19 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
                 completion_response["generations"][0]["finish_reason"]
             )
         elif provider == "anthropic":
-            if model.startswith("anthropic.claude-3"):
-                json_schemas: dict = {}
-                _is_function_call = False
-                ## Handle Tool Calling
-                if "tools" in optional_params:
-                    _is_function_call = True
-                    for tool in optional_params["tools"]:
-                        json_schemas[tool["function"]["name"]] = tool[
-                            "function"
-                        ].get("parameters", None)
-                outputText = completion_response.get("content")[0].get("text", None)
-                if outputText is not None and contains_tag(
-                    "invoke", outputText
-                ):  # OUTPUT PARSE FUNCTION CALL
-                    function_name = extract_between_tags("tool_name", outputText)[0]
-                    function_arguments_str = extract_between_tags(
-                        "invoke", outputText
-                    )[0].strip()
-                    function_arguments_str = (
-                        f"<invoke>{function_arguments_str}</invoke>"
-                    )
-                    function_arguments = parse_xml_params(
-                        function_arguments_str,
-                        json_schema=json_schemas.get(
-                            function_name, None
-                        ),  # check if we have a json schema for this function name)
-                    )
-                    _message = litellm.Message(
-                        tool_calls=[
-                            {
-                                "id": f"call_{uuid.uuid4()}",
-                                "type": "function",
-                                "function": {
-                                    "name": function_name,
-                                    "arguments": json.dumps(function_arguments),
-                                },
-                            }
-                        ],
-                        content=None,
-                    )
-                    model_response.choices[0].message = _message  # type: ignore
-                    model_response._hidden_params["original_response"] = (
-                        outputText  # allow user to access raw anthropic tool calling response
-                    )
-                model_response.choices[0].finish_reason = map_finish_reason(
-                    completion_response.get("stop_reason", "")
-                )
-                _usage = litellm.Usage(
-                    prompt_tokens=completion_response["usage"]["input_tokens"],
-                    completion_tokens=completion_response["usage"]["output_tokens"],
-                    total_tokens=completion_response["usage"]["input_tokens"]
-                    + completion_response["usage"]["output_tokens"],
-                )
-                setattr(model_response, "usage", _usage)
-            else:
-                outputText = completion_response["completion"]
-
-                model_response.choices[0].finish_reason = completion_response[
-                    "stop_reason"
-                ]
+            return litellm.AmazonAnthropicClaude3Config().transform_response(
+                model=model,
+                raw_response=raw_response,
+                model_response=model_response,
+                logging_obj=logging_obj,
+                request_data=request_data,
+                messages=messages,
+                optional_params=optional_params,
+                litellm_params=litellm_params,
+                encoding=encoding,
+                api_key=api_key,
+                json_mode=json_mode,
+            )
         elif provider == "ai21":
             outputText = (
                 completion_response.get("completions")[0].get("data").get("text")

@@ -536,6 +439,7 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
                 messages=messages,
                 logging_obj=logging_obj,
                 fake_stream=True if "ai21" in api_base else False,
+                bedrock_invoke_provider=self.get_bedrock_invoke_provider(model),
             ),
             model=model,
             custom_llm_provider="bedrock",

@@ -569,6 +473,7 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
                 messages=messages,
                 logging_obj=logging_obj,
                 fake_stream=True if "ai21" in api_base else False,
+                bedrock_invoke_provider=self.get_bedrock_invoke_provider(model),
             ),
             model=model,
             custom_llm_provider="bedrock",

@@ -594,10 +499,14 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
         """
         Helper function to get the bedrock provider from the model
 
-        handles 2 scenarions:
-        1. model=anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic`
-        2. model=llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n -> Returns `llama`
+        handles 3 scenarions:
+        1. model=invoke/anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic`
+        2. model=anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic`
+        3. model=llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n -> Returns `llama`
         """
+        if model.startswith("invoke/"):
+            model = model.replace("invoke/", "", 1)
+
         _split_model = model.split(".")[0]
         if _split_model in get_args(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL):
             return cast(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL, _split_model)

@@ -640,9 +549,9 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
         else:
             modelId = model
 
+        modelId = modelId.replace("invoke/", "", 1)
         if provider == "llama" and "llama/" in modelId:
             modelId = self._get_model_id_for_llama_like_model(modelId)
 
         return modelId
 
     def _get_aws_region_name(self, optional_params: dict) -> str:
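The provider detection described in that docstring is what feeds bedrock_invoke_provider into the streaming calls above. A small sketch of the three documented scenarios (instantiating AmazonInvokeConfig with no arguments matches how it is used elsewhere in this commit):

import litellm

invoke_config = litellm.AmazonInvokeConfig()

assert (
    invoke_config.get_bedrock_invoke_provider(
        "invoke/anthropic.claude-3-5-sonnet-20240620-v1:0"
    )
    == "anthropic"
)
assert (
    invoke_config.get_bedrock_invoke_provider(
        "anthropic.claude-3-5-sonnet-20240620-v1:0"
    )
    == "anthropic"
)
assert (
    invoke_config.get_bedrock_invoke_provider(
        "llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n"
    )
    == "llama"
)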
Bedrock common utilities (litellm/llms/bedrock/common_utils.py) — new BedrockModelInfo helper:

@@ -3,11 +3,12 @@ Common utilities used across bedrock chat/embedding/image generation
 """
 
 import os
-from typing import List, Optional, Union
+from typing import List, Literal, Optional, Union
 
 import httpx
 
 import litellm
+from litellm.llms.base_llm.base_utils import BaseLLMModelInfo
 from litellm.llms.base_llm.chat.transformation import BaseLLMException
 from litellm.secret_managers.main import get_secret

@@ -310,3 +311,68 @@ def get_bedrock_tool_name(response_tool_name: str) -> str:
             response_tool_name
         ]
     return response_tool_name
+
+
+class BedrockModelInfo(BaseLLMModelInfo):
+
+    global_config = AmazonBedrockGlobalConfig()
+    all_global_regions = global_config.get_all_regions()
+
+    @staticmethod
+    def get_base_model(model: str) -> str:
+        """
+        Get the base model from the given model name.
+
+        Handle model names like - "us.meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1"
+        AND "meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1"
+        """
+        if model.startswith("bedrock/"):
+            model = model.split("/", 1)[1]
+
+        if model.startswith("converse/"):
+            model = model.split("/", 1)[1]
+
+        if model.startswith("invoke/"):
+            model = model.split("/", 1)[1]
+
+        potential_region = model.split(".", 1)[0]
+
+        alt_potential_region = model.split("/", 1)[
+            0
+        ]  # in model cost map we store regional information like `/us-west-2/bedrock-model`
+
+        if (
+            potential_region
+            in BedrockModelInfo._supported_cross_region_inference_region()
+        ):
+            return model.split(".", 1)[1]
+        elif (
+            alt_potential_region in BedrockModelInfo.all_global_regions
+            and len(model.split("/", 1)) > 1
+        ):
+            return model.split("/", 1)[1]
+
+        return model
+
+    @staticmethod
+    def _supported_cross_region_inference_region() -> List[str]:
+        """
+        Abbreviations of regions AWS Bedrock supports for cross region inference
+        """
+        return ["us", "eu", "apac"]
+
+    @staticmethod
+    def get_bedrock_route(model: str) -> Literal["converse", "invoke", "converse_like"]:
+        """
+        Get the bedrock route for the given model.
+        """
+        base_model = BedrockModelInfo.get_base_model(model)
+        if "invoke/" in model:
+            return "invoke"
+        elif "converse_like" in model:
+            return "converse_like"
+        elif "converse/" in model:
+            return "converse"
+        elif base_model in litellm.bedrock_converse_models:
+            return "converse"
+        return "invoke"
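A quick sketch of how the two new helpers behave, using inputs and expected values taken from the tests added later in this commit:

from litellm.llms.bedrock.common_utils import BedrockModelInfo

# Region and route prefixes are stripped to recover the base model name.
assert BedrockModelInfo.get_base_model("us.amazon.nova-pro-v1:0") == "amazon.nova-pro-v1:0"
assert (
    BedrockModelInfo.get_base_model("invoke/anthropic.claude-3-5-sonnet-20241022-v2:0")
    == "anthropic.claude-3-5-sonnet-20241022-v2:0"
)

# Route detection: explicit prefixes win, known Converse models map to "converse",
# and everything else falls back to "invoke".
assert (
    BedrockModelInfo.get_bedrock_route("invoke/anthropic.claude-3-sonnet-20240229-v1:0")
    == "invoke"
)
assert (
    BedrockModelInfo.get_bedrock_route("us.anthropic.claude-3-sonnet-20240229-v1:0")
    == "converse"
)
assert BedrockModelInfo.get_bedrock_route("amazon.titan-text-express-v1") == "invoke"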
OpenAI config (OpenAIGPTConfig) — trivial implementation of the new abstract method:

@@ -344,6 +344,10 @@ class OpenAIGPTConfig(BaseLLMModelInfo, BaseConfig):
             or "https://api.openai.com/v1"
         )
 
+    @staticmethod
+    def get_base_model(model: str) -> str:
+        return model
+
     def get_model_response_iterator(
         self,
         streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
Topaz model info (TopazModelInfo) — same trivial implementation:

@@ -29,3 +29,7 @@ class TopazModelInfo(BaseLLMModelInfo):
         return (
             api_base or get_secret_str("TOPAZ_API_BASE") or "https://api.topazlabs.com"
         )
+
+    @staticmethod
+    def get_base_model(model: str) -> str:
+        return model
The completion() entrypoint — Bedrock requests are now dispatched by the detected route:

@@ -68,6 +68,7 @@ from litellm.litellm_core_utils.prompt_templates.common_utils import (
     get_content_from_model_response,
 )
 from litellm.llms.base_llm.chat.transformation import BaseConfig
+from litellm.llms.bedrock.common_utils import BedrockModelInfo
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.realtime_api.main import _realtime_health_check
 from litellm.secret_managers.main import get_secret_str

@@ -2628,11 +2629,8 @@ def completion(  # type: ignore # noqa: PLR0915
                     aws_bedrock_client.meta.region_name
                 )
 
-            base_model = litellm.AmazonConverseConfig()._get_base_model(model)
-
-            if base_model in litellm.bedrock_converse_models or model.startswith(
-                "converse/"
-            ):
+            bedrock_route = BedrockModelInfo.get_bedrock_route(model)
+            if bedrock_route == "converse":
                 model = model.replace("converse/", "")
                 response = bedrock_converse_chat_completion.completion(
                     model=model,

@@ -2651,7 +2649,7 @@ def completion(  # type: ignore # noqa: PLR0915
                     client=client,
                     api_base=api_base,
                 )
-            elif "converse_like" in model:
+            elif bedrock_route == "converse_like":
                 model = model.replace("converse_like/", "")
                 response = base_llm_http_handler.completion(
                     model=model,
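From the caller's perspective, the route can also be chosen explicitly with a prefix, since get_bedrock_route honors "converse/" and "invoke/" before falling back to the Converse model list. A hedged usage sketch (model IDs and credential setup are assumptions):

import litellm

# Force the Converse API for a model.
litellm.completion(
    model="bedrock/converse/anthropic.claude-3-5-sonnet-20240620-v1:0",
    messages=[{"role": "user", "content": "hi"}],
)

# Force the classic invoke API for the same model.
litellm.completion(
    model="bedrock/invoke/anthropic.claude-3-5-sonnet-20240620-v1:0",
    messages=[{"role": "user", "content": "hi"}],
)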
litellm utils — get_optional_params, the Bedrock model-name helpers, and ProviderConfigManager now use BedrockModelInfo:

@@ -86,10 +86,10 @@ from litellm.litellm_core_utils.llm_request_utils import _ensure_extra_body_is_s
 from litellm.litellm_core_utils.llm_response_utils.convert_dict_to_response import (
     LiteLLMResponseObjectHandler,
     _handle_invalid_parallel_tool_calls,
+    _parse_content_for_reasoning,
     convert_to_model_response_object,
     convert_to_streaming_response,
     convert_to_streaming_response_async,
-    _parse_content_for_reasoning,
 )
 from litellm.litellm_core_utils.llm_response_utils.get_api_base import get_api_base
 from litellm.litellm_core_utils.llm_response_utils.get_formatted_prompt import (

@@ -111,6 +111,7 @@ from litellm.litellm_core_utils.token_counter import (
     calculate_img_tokens,
     get_modified_max_tokens,
 )
+from litellm.llms.bedrock.common_utils import BedrockModelInfo
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.router_utils.get_retry_from_policy import (
     get_num_retries_from_retry_policy,

@@ -3189,8 +3190,8 @@ def get_optional_params(  # noqa: PLR0915
             ),
         )
     elif custom_llm_provider == "bedrock":
-        base_model = litellm.AmazonConverseConfig()._get_base_model(model)
-        if base_model in litellm.bedrock_converse_models:
+        bedrock_route = BedrockModelInfo.get_bedrock_route(model)
+        if bedrock_route == "converse" or bedrock_route == "converse_like":
             optional_params = litellm.AmazonConverseConfig().map_openai_params(
                 model=model,
                 non_default_params=non_default_params,

@@ -3203,15 +3204,20 @@ def get_optional_params(  # noqa: PLR0915
                 messages=messages,
             )
 
-        elif "anthropic" in model:
-            if "aws_bedrock_client" in passed_params:  # deprecated boto3.invoke route.
-                if model.startswith("anthropic.claude-3"):
-                    optional_params = (
-                        litellm.AmazonAnthropicClaude3Config().map_openai_params(
-                            non_default_params=non_default_params,
-                            optional_params=optional_params,
-                        )
-                    )
+        elif "anthropic" in model and bedrock_route == "invoke":
+            if model.startswith("anthropic.claude-3"):
+                optional_params = (
+                    litellm.AmazonAnthropicClaude3Config().map_openai_params(
+                        non_default_params=non_default_params,
+                        optional_params=optional_params,
+                        model=model,
+                        drop_params=(
+                            drop_params
+                            if drop_params is not None and isinstance(drop_params, bool)
+                            else False
+                        ),
+                    )
+                )
             else:
                 optional_params = litellm.AmazonAnthropicConfig().map_openai_params(
                     non_default_params=non_default_params,

@@ -3972,8 +3978,16 @@ def _strip_stable_vertex_version(model_name) -> str:
     return re.sub(r"-\d+$", "", model_name)
 
 
-def _strip_bedrock_region(model_name) -> str:
-    return litellm.AmazonConverseConfig()._get_base_model(model_name)
+def _get_base_bedrock_model(model_name) -> str:
+    """
+    Get the base model from the given model name.
+
+    Handle model names like - "us.meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1"
+    AND "meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1"
+    """
+    from litellm.llms.bedrock.common_utils import BedrockModelInfo
+
+    return BedrockModelInfo.get_base_model(model_name)
 
 
 def _strip_openai_finetune_model_name(model_name: str) -> str:

@@ -3994,8 +4008,8 @@ def _strip_openai_finetune_model_name(model_name: str) -> str:
 
 def _strip_model_name(model: str, custom_llm_provider: Optional[str]) -> str:
     if custom_llm_provider and custom_llm_provider == "bedrock":
-        strip_bedrock_region = _strip_bedrock_region(model_name=model)
-        return strip_bedrock_region
+        stripped_bedrock_model = _get_base_bedrock_model(model_name=model)
+        return stripped_bedrock_model
     elif custom_llm_provider and (
         custom_llm_provider == "vertex_ai" or custom_llm_provider == "gemini"
     ):

@@ -6066,24 +6080,23 @@ class ProviderConfigManager:
         elif litellm.LlmProviders.PETALS == provider:
             return litellm.PetalsConfig()
         elif litellm.LlmProviders.BEDROCK == provider:
-            base_model = litellm.AmazonConverseConfig()._get_base_model(model)
-            bedrock_provider = litellm.BedrockLLM.get_bedrock_invoke_provider(model)
-            if (
-                base_model in litellm.bedrock_converse_models
-                or "converse_like" in model
-            ):
+            bedrock_route = BedrockModelInfo.get_bedrock_route(model)
+            bedrock_invoke_provider = litellm.BedrockLLM.get_bedrock_invoke_provider(
+                model
+            )
+            if bedrock_route == "converse" or bedrock_route == "converse_like":
                 return litellm.AmazonConverseConfig()
-            elif bedrock_provider == "amazon":  # amazon titan llms
+            elif bedrock_invoke_provider == "amazon":  # amazon titan llms
                 return litellm.AmazonTitanConfig()
             elif (
-                bedrock_provider == "meta" or bedrock_provider == "llama"
+                bedrock_invoke_provider == "meta" or bedrock_invoke_provider == "llama"
             ):  # amazon / meta llms
                 return litellm.AmazonLlamaConfig()
-            elif bedrock_provider == "ai21":  # ai21 llms
+            elif bedrock_invoke_provider == "ai21":  # ai21 llms
                 return litellm.AmazonAI21Config()
-            elif bedrock_provider == "cohere":  # cohere models on bedrock
+            elif bedrock_invoke_provider == "cohere":  # cohere models on bedrock
                 return litellm.AmazonCohereConfig()
-            elif bedrock_provider == "mistral":  # mistral models on bedrock
+            elif bedrock_invoke_provider == "mistral":  # mistral models on bedrock
                 return litellm.AmazonMistralConfig()
             else:
                 return litellm.AmazonInvokeConfig()
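One user-visible consequence of the signature change threaded through here: map_openai_params on the Claude 3 invoke config now receives the model name and a drop_params flag, like the other BaseConfig implementations. A small sketch whose expected result matches the assertion in test_all_model_configs below:

import litellm

mapped = litellm.AmazonAnthropicClaude3Config().map_openai_params(
    non_default_params={"max_completion_tokens": 10},
    optional_params={},
    model="anthropic.claude-3-sonnet-20240229-v1:0",
    drop_params=False,
)
assert mapped == {"max_tokens": 10}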
Bedrock completion tests — base-model and route-detection coverage:

@@ -1265,7 +1265,9 @@ def test_bedrock_cross_region_inference(model):
     ],
 )
 def test_bedrock_get_base_model(model, expected_base_model):
-    assert litellm.AmazonConverseConfig()._get_base_model(model) == expected_base_model
+    from litellm.llms.bedrock.common_utils import BedrockModelInfo
+
+    assert BedrockModelInfo.get_base_model(model) == expected_base_model
 
 
 from litellm.litellm_core_utils.prompt_templates.factory import (

@@ -1982,9 +1984,49 @@ def test_bedrock_mapped_converse_models():
 
 
 def test_bedrock_base_model_helper():
+    from litellm.llms.bedrock.common_utils import BedrockModelInfo
+
     model = "us.amazon.nova-pro-v1:0"
-    litellm.AmazonConverseConfig()._get_base_model(model)
-    assert model == "us.amazon.nova-pro-v1:0"
+    base_model = BedrockModelInfo.get_base_model(model)
+    assert base_model == "amazon.nova-pro-v1:0"
+
+    assert (
+        BedrockModelInfo.get_base_model(
+            "invoke/anthropic.claude-3-5-sonnet-20241022-v2:0"
+        )
+        == "anthropic.claude-3-5-sonnet-20241022-v2:0"
+    )
+
+
+@pytest.mark.parametrize(
+    "model,expected_route",
+    [
+        # Test explicit route prefixes
+        ("invoke/anthropic.claude-3-sonnet-20240229-v1:0", "invoke"),
+        ("converse/anthropic.claude-3-sonnet-20240229-v1:0", "converse"),
+        ("converse_like/anthropic.claude-3-sonnet-20240229-v1:0", "converse_like"),
+        # Test models in BEDROCK_CONVERSE_MODELS list
+        ("anthropic.claude-3-5-haiku-20241022-v1:0", "converse"),
+        ("anthropic.claude-v2", "converse"),
+        ("meta.llama3-70b-instruct-v1:0", "converse"),
+        ("mistral.mistral-large-2407-v1:0", "converse"),
+        # Test models with region prefixes
+        ("us.anthropic.claude-3-sonnet-20240229-v1:0", "converse"),
+        ("us.meta.llama3-70b-instruct-v1:0", "converse"),
+        # Test default case (should return "invoke")
+        ("amazon.titan-text-express-v1", "invoke"),
+        ("cohere.command-text-v14", "invoke"),
+        ("cohere.command-r-v1:0", "invoke"),
+    ],
+)
+def test_bedrock_route_detection(model, expected_route):
+    """Test all scenarios for BedrockModelInfo.get_bedrock_route"""
+    from litellm.llms.bedrock.common_utils import BedrockModelInfo
+
+    route = BedrockModelInfo.get_bedrock_route(model)
+    assert (
+        route == expected_route
+    ), f"Expected route '{expected_route}' for model '{model}', but got '{route}'"
+
+
 @pytest.mark.parametrize(
tests/llm_translation/test_bedrock_invoke_claude_json.py (new file, 28 lines):

@@ -0,0 +1,28 @@
+from base_llm_unit_tests import BaseLLMChatTest
+import pytest
+import sys
+import os
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import litellm
+
+
+class TestBedrockInvokeClaudeJson(BaseLLMChatTest):
+    def get_base_completion_call_args(self) -> dict:
+        litellm._turn_on_debug()
+        return {
+            "model": "bedrock/invoke/anthropic.claude-3-5-sonnet-20240620-v1:0",
+        }
+
+    def test_tool_call_no_arguments(self, tool_call_no_arguments):
+        """Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
+        pass
+
+    @pytest.fixture(autouse=True)
+    def skip_non_json_tests(self, request):
+        if not "json" in request.function.__name__.lower():
+            pytest.skip(
+                f"Skipping non-JSON test: {request.function.__name__} does not contain 'json'"
+            )
|
@ -305,12 +305,16 @@ def test_all_model_configs():
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
"max_completion_tokens"
|
"max_completion_tokens"
|
||||||
in AmazonAnthropicClaude3Config().get_supported_openai_params()
|
in AmazonAnthropicClaude3Config().get_supported_openai_params(
|
||||||
|
model="anthropic.claude-3-sonnet-20240229-v1:0"
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
assert AmazonAnthropicClaude3Config().map_openai_params(
|
assert AmazonAnthropicClaude3Config().map_openai_params(
|
||||||
non_default_params={"max_completion_tokens": 10},
|
non_default_params={"max_completion_tokens": 10},
|
||||||
optional_params={},
|
optional_params={},
|
||||||
|
model="anthropic.claude-3-sonnet-20240229-v1:0",
|
||||||
|
drop_params=False,
|
||||||
) == {"max_tokens": 10}
|
) == {"max_tokens": 10}
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
|
|
|
get_llm_provider tests — the invoke/ prefix is preserved on the returned model name:

@@ -208,3 +208,11 @@ def test_nova_bedrock_converse():
     )
     assert custom_llm_provider == "bedrock"
     assert model == "amazon.nova-micro-v1:0"
+
+
+def test_bedrock_invoke_anthropic():
+    model, custom_llm_provider, dynamic_api_key, api_base = litellm.get_llm_provider(
+        model="bedrock/invoke/anthropic.claude-3-5-sonnet-20240620-v1:0",
+    )
+    assert custom_llm_provider == "bedrock"
+    assert model == "invoke/anthropic.claude-3-5-sonnet-20240620-v1:0"
Model info tests (test_get_model_info_bedrock_models) — use BedrockModelInfo:

@@ -321,7 +321,7 @@ def test_get_model_info_bedrock_models():
     """
     Check for drift in base model info for bedrock models and regional model info for bedrock models.
     """
-    from litellm import AmazonConverseConfig
+    from litellm.llms.bedrock.common_utils import BedrockModelInfo
 
     os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
     litellm.model_cost = litellm.get_model_cost_map(url="")

@@ -337,7 +337,7 @@ def test_get_model_info_bedrock_models():
         if any(commitment in k for commitment in potential_commitments):
             for commitment in potential_commitments:
                 k = k.replace(f"{commitment}/", "")
-        base_model = AmazonConverseConfig()._get_base_model(k)
+        base_model = BedrockModelInfo.get_base_model(k)
         base_model_info = litellm.model_cost[base_model]
         for base_model_key, base_model_value in base_model_info.items():
             if base_model_key.startswith("supports_"):