diff --git a/litellm/__init__.py b/litellm/__init__.py
index c49b3214b9..b8de8a4298 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -360,7 +360,7 @@ BEDROCK_CONVERSE_MODELS = [
     "meta.llama3-2-90b-instruct-v1:0",
 ]
 BEDROCK_INVOKE_PROVIDERS_LITERAL = Literal[
-    "cohere", "anthropic", "mistral", "amazon", "meta", "llama", "ai21"
+    "cohere", "anthropic", "mistral", "amazon", "meta", "llama", "ai21", "nova"
 ]
 ####### COMPLETION MODELS ###################
 open_ai_chat_completion_models: List = []
@@ -863,6 +863,9 @@ from .llms.bedrock.common_utils import (
 from .llms.bedrock.chat.invoke_transformations.amazon_ai21_transformation import (
     AmazonAI21Config,
 )
+from .llms.bedrock.chat.invoke_transformations.amazon_nova_transformation import (
+    AmazonInvokeNovaConfig,
+)
 from .llms.bedrock.chat.invoke_transformations.anthropic_claude2_transformation import (
     AmazonAnthropicConfig,
 )
diff --git a/litellm/llms/bedrock/chat/invoke_handler.py b/litellm/llms/bedrock/chat/invoke_handler.py
index db419aa110..43fdc061e7 100644
--- a/litellm/llms/bedrock/chat/invoke_handler.py
+++ b/litellm/llms/bedrock/chat/invoke_handler.py
@@ -1342,7 +1342,7 @@ class AWSEventStreamDecoder:
             text = chunk_data.get("completions")[0].get("data").get("text")  # type: ignore
             is_finished = True
             finish_reason = "stop"
-        ######## converse bedrock.anthropic mappings ###############
+        ######## /bedrock/converse mappings ###############
         elif (
             "contentBlockIndex" in chunk_data
             or "stopReason" in chunk_data
@@ -1350,6 +1350,11 @@
             or "trace" in chunk_data
         ):
             return self.converse_chunk_parser(chunk_data=chunk_data)
+        ######## /bedrock/invoke nova mappings ###############
+        elif "contentBlockDelta" in chunk_data:
+            # when using /bedrock/invoke/nova, the chunk_data is nested under "contentBlockDelta"
+            _chunk_data = chunk_data.get("contentBlockDelta", None)
+            return self.converse_chunk_parser(chunk_data=_chunk_data)
         ######## bedrock.mistral mappings ###############
         elif "outputs" in chunk_data:
             if (
diff --git a/litellm/llms/bedrock/chat/invoke_transformations/amazon_nova_transformation.py b/litellm/llms/bedrock/chat/invoke_transformations/amazon_nova_transformation.py
new file mode 100644
index 0000000000..9d41beceff
--- /dev/null
+++ b/litellm/llms/bedrock/chat/invoke_transformations/amazon_nova_transformation.py
@@ -0,0 +1,70 @@
+"""
+Handles transforming requests for `bedrock/invoke/{nova}` models
+
+Inherits from `AmazonConverseConfig`
+
+Nova + Invoke API Tutorial: https://docs.aws.amazon.com/nova/latest/userguide/using-invoke-api.html
+"""
+
+from typing import List
+
+import litellm
+from litellm.types.llms.bedrock import BedrockInvokeNovaRequest
+from litellm.types.llms.openai import AllMessageValues
+
+
+class AmazonInvokeNovaConfig(litellm.AmazonConverseConfig):
+    """
+    Config for sending `nova` requests to `/bedrock/invoke/`
+    """
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+    def transform_request(
+        self,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        headers: dict,
+    ) -> dict:
+        _transformed_nova_request = super().transform_request(
+            model=model,
+            messages=messages,
+            optional_params=optional_params,
+            litellm_params=litellm_params,
+            headers=headers,
+        )
+        _bedrock_invoke_nova_request = BedrockInvokeNovaRequest(
+            **_transformed_nova_request
+        )
+        self._remove_empty_system_messages(_bedrock_invoke_nova_request)
+        bedrock_invoke_nova_request = self._filter_allowed_fields(
+            _bedrock_invoke_nova_request
+        )
+        return bedrock_invoke_nova_request
+
+    def _filter_allowed_fields(
+        self, bedrock_invoke_nova_request: BedrockInvokeNovaRequest
+    ) -> dict:
+        """
+        Filter out fields that are not allowed in the `BedrockInvokeNovaRequest` TypedDict.
+        """
+        allowed_fields = set(BedrockInvokeNovaRequest.__annotations__.keys())
+        return {
+            k: v for k, v in bedrock_invoke_nova_request.items() if k in allowed_fields
+        }
+
+    def _remove_empty_system_messages(
+        self, bedrock_invoke_nova_request: BedrockInvokeNovaRequest
+    ) -> None:
+        """
+        In-place remove empty `system` messages from the request.
+
+        /bedrock/invoke/ does not allow empty `system` messages.
+        """
+        _system_message = bedrock_invoke_nova_request.get("system", None)
+        if isinstance(_system_message, list) and len(_system_message) == 0:
+            bedrock_invoke_nova_request.pop("system", None)
+        return
diff --git a/litellm/llms/bedrock/chat/invoke_transformations/base_invoke_transformation.py b/litellm/llms/bedrock/chat/invoke_transformations/base_invoke_transformation.py
index 6284a7ab08..66ef1296c8 100644
--- a/litellm/llms/bedrock/chat/invoke_transformations/base_invoke_transformation.py
+++ b/litellm/llms/bedrock/chat/invoke_transformations/base_invoke_transformation.py
@@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union, cast, get_a
 import httpx
 
 import litellm
+from litellm._logging import verbose_logger
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
 from litellm.litellm_core_utils.logging_utils import track_llm_api_timing
 from litellm.litellm_core_utils.prompt_templates.factory import (
@@ -166,7 +167,7 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
 
         return dict(request.headers)
 
-    def transform_request(  # noqa: PLR0915
+    def transform_request(
         self,
         model: str,
         messages: List[AllMessageValues],
@@ -224,6 +225,14 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
                 litellm_params=litellm_params,
                 headers=headers,
             )
+        elif provider == "nova":
+            return litellm.AmazonInvokeNovaConfig().transform_request(
+                model=model,
+                messages=messages,
+                optional_params=optional_params,
+                litellm_params=litellm_params,
+                headers=headers,
+            )
         elif provider == "ai21":
             ## LOAD CONFIG
             config = litellm.AmazonAI21Config.get_config()
@@ -297,6 +306,10 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
             raise BedrockError(
                 message=raw_response.text, status_code=raw_response.status_code
             )
+        verbose_logger.debug(
+            "bedrock invoke response %s",
+            json.dumps(completion_response, indent=4, default=str),
+        )
         provider = self.get_bedrock_invoke_provider(model)
         outputText: Optional[str] = None
         try:
@@ -322,6 +335,18 @@
                 api_key=api_key,
                 json_mode=json_mode,
             )
+        elif provider == "nova":
+            return litellm.AmazonInvokeNovaConfig().transform_response(
+                model=model,
+                raw_response=raw_response,
+                model_response=model_response,
+                logging_obj=logging_obj,
+                request_data=request_data,
+                messages=messages,
+                optional_params=optional_params,
+                litellm_params=litellm_params,
+                encoding=encoding,
+            )
         elif provider == "ai21":
             outputText = (
                 completion_response.get("completions")[0].get("data").get("text")
@@ -503,6 +528,7 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
         1. model=invoke/anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic`
         2. model=anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic`
         3. model=llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n -> Returns `llama`
+        4. model=us.amazon.nova-pro-v1:0 -> Returns `nova`
         """
         if model.startswith("invoke/"):
             model = model.replace("invoke/", "", 1)
@@ -515,6 +541,10 @@
         provider = AmazonInvokeConfig._get_provider_from_model_path(model)
         if provider is not None:
             return provider
+
+        # check if provider == "nova"
+        if "nova" in model:
+            return "nova"
         return None
 
     @staticmethod
diff --git a/litellm/types/llms/bedrock.py b/litellm/types/llms/bedrock.py
index cc3a27cda7..70b5769185 100644
--- a/litellm/types/llms/bedrock.py
+++ b/litellm/types/llms/bedrock.py
@@ -184,6 +184,18 @@ class RequestObject(CommonRequestObject, total=False):
     messages: Required[List[MessageBlock]]
 
 
+class BedrockInvokeNovaRequest(TypedDict, total=False):
+    """
+    Request object for sending `nova` requests to `/bedrock/invoke/`
+    """
+
+    messages: List[MessageBlock]
+    inferenceConfig: InferenceConfig
+    system: List[SystemContentBlock]
+    toolConfig: ToolConfigBlock
+    guardrailConfig: Optional[GuardrailConfigBlock]
+
+
 class GenericStreamingChunk(TypedDict):
     text: Required[str]
     tool_use: Optional[ChatCompletionToolCallChunk]
diff --git a/tests/llm_translation/test_bedrock_invoke_claude_json.py b/tests/llm_translation/test_bedrock_invoke_claude_json.py
deleted file mode 100644
index 2e943ed682..0000000000
--- a/tests/llm_translation/test_bedrock_invoke_claude_json.py
+++ /dev/null
@@ -1,28 +0,0 @@
-from base_llm_unit_tests import BaseLLMChatTest
-import pytest
-import sys
-import os
-
-sys.path.insert(
-    0, os.path.abspath("../..")
-)  # Adds the parent directory to the system path
-import litellm
-
-
-class TestBedrockInvokeClaudeJson(BaseLLMChatTest):
-    def get_base_completion_call_args(self) -> dict:
-        litellm._turn_on_debug()
-        return {
-            "model": "bedrock/invoke/anthropic.claude-3-5-sonnet-20240620-v1:0",
-        }
-
-    def test_tool_call_no_arguments(self, tool_call_no_arguments):
-        """Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
-        pass
-
-    @pytest.fixture(autouse=True)
-    def skip_non_json_tests(self, request):
-        if not "json" in request.function.__name__.lower():
-            pytest.skip(
-                f"Skipping non-JSON test: {request.function.__name__} does not contain 'json'"
-            )
diff --git a/tests/llm_translation/test_bedrock_invoke_tests.py b/tests/llm_translation/test_bedrock_invoke_tests.py
new file mode 100644
index 0000000000..ca12ee0492
--- /dev/null
+++ b/tests/llm_translation/test_bedrock_invoke_tests.py
@@ -0,0 +1,153 @@
+from base_llm_unit_tests import BaseLLMChatTest
+import pytest
+import sys
+import os
+
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import litellm
+from litellm.types.llms.bedrock import BedrockInvokeNovaRequest
+
+
+class TestBedrockInvokeClaudeJson(BaseLLMChatTest):
+    def get_base_completion_call_args(self) -> dict:
+        litellm._turn_on_debug()
+        return {
+            "model": "bedrock/invoke/anthropic.claude-3-5-sonnet-20240620-v1:0",
+        }
+
+    def test_tool_call_no_arguments(self, tool_call_no_arguments):
+        """Test that tool calls with no arguments are translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
+        pass
+
+    @pytest.fixture(autouse=True)
+    def skip_non_json_tests(self, request):
+        if "json" not in request.function.__name__.lower():
+            pytest.skip(
+                f"Skipping non-JSON test: {request.function.__name__} does not contain 'json'"
+            )
+
+
+class TestBedrockInvokeNovaJson(BaseLLMChatTest):
+    def get_base_completion_call_args(self) -> dict:
+        litellm._turn_on_debug()
+        return {
+            "model": "bedrock/invoke/us.amazon.nova-micro-v1:0",
+        }
+
+    def test_tool_call_no_arguments(self, tool_call_no_arguments):
+        """Test that tool calls with no arguments are translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
+        pass
+
+    @pytest.fixture(autouse=True)
+    def skip_non_json_tests(self, request):
+        if "json" not in request.function.__name__.lower():
+            pytest.skip(
+                f"Skipping non-JSON test: {request.function.__name__} does not contain 'json'"
+            )
+
+
+def test_nova_invoke_remove_empty_system_messages():
+    """Test that _remove_empty_system_messages removes an empty system list."""
+    input_request = BedrockInvokeNovaRequest(
+        messages=[{"content": [{"text": "Hello"}], "role": "user"}],
+        system=[],
+        inferenceConfig={"temperature": 0.7},
+    )
+
+    litellm.AmazonInvokeNovaConfig()._remove_empty_system_messages(input_request)
+
+    assert "system" not in input_request
+    assert "messages" in input_request
+    assert "inferenceConfig" in input_request
+
+
+def test_nova_invoke_filter_allowed_fields():
+    """
+    Test that _filter_allowed_fields only keeps fields defined in BedrockInvokeNovaRequest.
+
+    Nova Invoke does not allow `additionalModelRequestFields` and `additionalModelResponseFieldPaths` in the request body.
+    This test ensures that these fields are not included in the request body.
+    """
+    _input_request = {
+        "messages": [{"content": [{"text": "Hello"}], "role": "user"}],
+        "system": [{"text": "System prompt"}],
+        "inferenceConfig": {"temperature": 0.7},
+        "additionalModelRequestFields": {"this": "should be removed"},
+        "additionalModelResponseFieldPaths": ["this", "should", "be", "removed"],
+    }
+
+    input_request = BedrockInvokeNovaRequest(**_input_request)
+
+    result = litellm.AmazonInvokeNovaConfig()._filter_allowed_fields(input_request)
+
+    assert "additionalModelRequestFields" not in result
+    assert "additionalModelResponseFieldPaths" not in result
+    assert "messages" in result
+    assert "system" in result
+    assert "inferenceConfig" in result
+
+
+def test_nova_invoke_streaming_chunk_parsing():
+    """
+    Test that the AWSEventStreamDecoder correctly handles Nova's /bedrock/invoke/ streaming format,
+    where content is nested under 'contentBlockDelta'.
+    """
+    from litellm.llms.bedrock.chat.invoke_handler import AWSEventStreamDecoder
+
+    # Initialize the decoder with a Nova model
+    decoder = AWSEventStreamDecoder(model="bedrock/invoke/us.amazon.nova-micro-v1:0")
+
+    # Test case 1: Text content in contentBlockDelta
+    nova_text_chunk = {
+        "contentBlockDelta": {
+            "delta": {"text": "Hello, how can I help?"},
+            "contentBlockIndex": 0,
+        }
+    }
+    result = decoder._chunk_parser(nova_text_chunk)
+    assert result["text"] == "Hello, how can I help?"
+    assert result["index"] == 0
+    assert not result["is_finished"]
+    assert result["tool_use"] is None
+
+    # Test case 2: Tool use start in contentBlockDelta
+    nova_tool_start_chunk = {
+        "contentBlockDelta": {
+            "start": {"toolUse": {"name": "get_weather", "toolUseId": "tool_1"}},
+            "contentBlockIndex": 1,
+        }
+    }
+    result = decoder._chunk_parser(nova_tool_start_chunk)
+    assert result["text"] == ""
+    assert result["index"] == 1
+    assert result["tool_use"] is not None
+    assert result["tool_use"]["type"] == "function"
+    assert result["tool_use"]["function"]["name"] == "get_weather"
+    assert result["tool_use"]["id"] == "tool_1"
+
+    # Test case 3: Tool use arguments in contentBlockDelta
+    nova_tool_args_chunk = {
+        "contentBlockDelta": {
+            "delta": {"toolUse": {"input": '{"location": "New York"}'}},
+            "contentBlockIndex": 2,
+        }
+    }
+    result = decoder._chunk_parser(nova_tool_args_chunk)
+    assert result["text"] == ""
+    assert result["index"] == 2
+    assert result["tool_use"] is not None
+    assert result["tool_use"]["function"]["arguments"] == '{"location": "New York"}'
+
+    # Test case 4: Stop reason in contentBlockDelta
+    nova_stop_chunk = {
+        "contentBlockDelta": {
+            "stopReason": "tool_use",
+        }
+    }
+    result = decoder._chunk_parser(nova_stop_chunk)
+    print(result)
+    assert result["is_finished"] is True
+    assert result["finish_reason"] == "tool_calls"
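
For anyone smoke-testing this change locally, here is a minimal usage sketch (not part of the patch). It assumes AWS credentials and a region are already configured in the environment, and it reuses the Nova model ID from the new tests:

```python
import litellm

# Non-streaming: AmazonInvokeConfig.get_bedrock_invoke_provider() detects "nova"
# and routes the request through AmazonInvokeNovaConfig, which builds a
# Converse-shaped body and filters it down to the fields Nova's Invoke API accepts.
response = litellm.completion(
    model="bedrock/invoke/us.amazon.nova-micro-v1:0",
    messages=[{"role": "user", "content": "Hello"}],
)
print(response.choices[0].message.content)

# Streaming: AWSEventStreamDecoder unwraps the "contentBlockDelta" envelope and
# reuses the converse chunk parser, so chunks arrive in the usual OpenAI shape.
stream = litellm.completion(
    model="bedrock/invoke/us.amazon.nova-micro-v1:0",
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,
)
for chunk in stream:
    print(chunk.choices[0].delta.content or "", end="")
```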