mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 11:43:54 +00:00
(Feat) - Allow calling Nova models on /bedrock/invoke/
(#8397)
* add nova to BEDROCK_INVOKE_PROVIDERS_LITERAL * BedrockInvokeNovaRequest * nova + invoke config * add AmazonInvokeNovaConfig * AmazonInvokeNovaConfig * run transform_request for invoke/nova models * AmazonInvokeNovaConfig * rename invoke tests * fix linting error * TestBedrockInvokeNovaJson * TestBedrockInvokeNovaJson * add converse_chunk_parser * test_nova_invoke_remove_empty_system_messages * test_nova_invoke_streaming_chunk_parsing
This commit is contained in:
parent
fc01b304a1
commit
0d9e641034
7 changed files with 276 additions and 31 deletions
|
@ -360,7 +360,7 @@ BEDROCK_CONVERSE_MODELS = [
|
||||||
"meta.llama3-2-90b-instruct-v1:0",
|
"meta.llama3-2-90b-instruct-v1:0",
|
||||||
]
|
]
|
||||||
BEDROCK_INVOKE_PROVIDERS_LITERAL = Literal[
|
BEDROCK_INVOKE_PROVIDERS_LITERAL = Literal[
|
||||||
"cohere", "anthropic", "mistral", "amazon", "meta", "llama", "ai21"
|
"cohere", "anthropic", "mistral", "amazon", "meta", "llama", "ai21", "nova"
|
||||||
]
|
]
|
||||||
####### COMPLETION MODELS ###################
|
####### COMPLETION MODELS ###################
|
||||||
open_ai_chat_completion_models: List = []
|
open_ai_chat_completion_models: List = []
|
||||||
|
@ -863,6 +863,9 @@ from .llms.bedrock.common_utils import (
|
||||||
from .llms.bedrock.chat.invoke_transformations.amazon_ai21_transformation import (
|
from .llms.bedrock.chat.invoke_transformations.amazon_ai21_transformation import (
|
||||||
AmazonAI21Config,
|
AmazonAI21Config,
|
||||||
)
|
)
|
||||||
|
from .llms.bedrock.chat.invoke_transformations.amazon_nova_transformation import (
|
||||||
|
AmazonInvokeNovaConfig,
|
||||||
|
)
|
||||||
from .llms.bedrock.chat.invoke_transformations.anthropic_claude2_transformation import (
|
from .llms.bedrock.chat.invoke_transformations.anthropic_claude2_transformation import (
|
||||||
AmazonAnthropicConfig,
|
AmazonAnthropicConfig,
|
||||||
)
|
)
|
||||||
|
|
|
@ -1342,7 +1342,7 @@ class AWSEventStreamDecoder:
|
||||||
text = chunk_data.get("completions")[0].get("data").get("text") # type: ignore
|
text = chunk_data.get("completions")[0].get("data").get("text") # type: ignore
|
||||||
is_finished = True
|
is_finished = True
|
||||||
finish_reason = "stop"
|
finish_reason = "stop"
|
||||||
######## converse bedrock.anthropic mappings ###############
|
######## /bedrock/converse mappings ###############
|
||||||
elif (
|
elif (
|
||||||
"contentBlockIndex" in chunk_data
|
"contentBlockIndex" in chunk_data
|
||||||
or "stopReason" in chunk_data
|
or "stopReason" in chunk_data
|
||||||
|
@ -1350,6 +1350,11 @@ class AWSEventStreamDecoder:
|
||||||
or "trace" in chunk_data
|
or "trace" in chunk_data
|
||||||
):
|
):
|
||||||
return self.converse_chunk_parser(chunk_data=chunk_data)
|
return self.converse_chunk_parser(chunk_data=chunk_data)
|
||||||
|
######### /bedrock/invoke nova mappings ###############
|
||||||
|
elif "contentBlockDelta" in chunk_data:
|
||||||
|
# when using /bedrock/invoke/nova, the chunk_data is nested under "contentBlockDelta"
|
||||||
|
_chunk_data = chunk_data.get("contentBlockDelta", None)
|
||||||
|
return self.converse_chunk_parser(chunk_data=_chunk_data)
|
||||||
######## bedrock.mistral mappings ###############
|
######## bedrock.mistral mappings ###############
|
||||||
elif "outputs" in chunk_data:
|
elif "outputs" in chunk_data:
|
||||||
if (
|
if (
|
||||||
|
|
|
@ -0,0 +1,70 @@
|
||||||
|
"""
|
||||||
|
Handles transforming requests for `bedrock/invoke/{nova} models`
|
||||||
|
|
||||||
|
Inherits from `AmazonConverseConfig`
|
||||||
|
|
||||||
|
Nova + Invoke API Tutorial: https://docs.aws.amazon.com/nova/latest/userguide/using-invoke-api.html
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import litellm
|
||||||
|
from litellm.types.llms.bedrock import BedrockInvokeNovaRequest
|
||||||
|
from litellm.types.llms.openai import AllMessageValues
|
||||||
|
|
||||||
|
|
||||||
|
class AmazonInvokeNovaConfig(litellm.AmazonConverseConfig):
    """
    Config for sending `nova` requests to `/bedrock/invoke/`

    Reuses the Converse request transformation from the parent class, then
    adapts the payload for the invoke API: an empty `system` list is stripped
    (the invoke API rejects it) and any keys not declared on
    `BedrockInvokeNovaRequest` are filtered out.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """Build a Converse-style request, then narrow it to the invoke/nova shape."""
        converse_request = super().transform_request(
            model=model,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            headers=headers,
        )
        nova_request = BedrockInvokeNovaRequest(**converse_request)
        self._remove_empty_system_messages(nova_request)
        return self._filter_allowed_fields(nova_request)

    def _filter_allowed_fields(
        self, bedrock_invoke_nova_request: BedrockInvokeNovaRequest
    ) -> dict:
        """
        Filter out fields that are not allowed in the `BedrockInvokeNovaRequest` dataclass.
        """
        permitted = BedrockInvokeNovaRequest.__annotations__.keys()
        return {
            key: value
            for key, value in bedrock_invoke_nova_request.items()
            if key in permitted
        }

    def _remove_empty_system_messages(
        self, bedrock_invoke_nova_request: BedrockInvokeNovaRequest
    ) -> None:
        """
        In-place remove empty `system` messages from the request.

        /bedrock/invoke/ does not allow empty `system` messages.
        """
        system_blocks = bedrock_invoke_nova_request.get("system", None)
        if isinstance(system_blocks, list) and not system_blocks:
            bedrock_invoke_nova_request.pop("system", None)
|
|
@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union, cast, get_a
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
|
from litellm._logging import verbose_logger
|
||||||
from litellm.litellm_core_utils.core_helpers import map_finish_reason
|
from litellm.litellm_core_utils.core_helpers import map_finish_reason
|
||||||
from litellm.litellm_core_utils.logging_utils import track_llm_api_timing
|
from litellm.litellm_core_utils.logging_utils import track_llm_api_timing
|
||||||
from litellm.litellm_core_utils.prompt_templates.factory import (
|
from litellm.litellm_core_utils.prompt_templates.factory import (
|
||||||
|
@ -166,7 +167,7 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
|
||||||
|
|
||||||
return dict(request.headers)
|
return dict(request.headers)
|
||||||
|
|
||||||
def transform_request( # noqa: PLR0915
|
def transform_request(
|
||||||
self,
|
self,
|
||||||
model: str,
|
model: str,
|
||||||
messages: List[AllMessageValues],
|
messages: List[AllMessageValues],
|
||||||
|
@ -224,6 +225,14 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
|
||||||
litellm_params=litellm_params,
|
litellm_params=litellm_params,
|
||||||
headers=headers,
|
headers=headers,
|
||||||
)
|
)
|
||||||
|
elif provider == "nova":
|
||||||
|
return litellm.AmazonInvokeNovaConfig().transform_request(
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
optional_params=optional_params,
|
||||||
|
litellm_params=litellm_params,
|
||||||
|
headers=headers,
|
||||||
|
)
|
||||||
elif provider == "ai21":
|
elif provider == "ai21":
|
||||||
## LOAD CONFIG
|
## LOAD CONFIG
|
||||||
config = litellm.AmazonAI21Config.get_config()
|
config = litellm.AmazonAI21Config.get_config()
|
||||||
|
@ -297,6 +306,10 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
|
||||||
raise BedrockError(
|
raise BedrockError(
|
||||||
message=raw_response.text, status_code=raw_response.status_code
|
message=raw_response.text, status_code=raw_response.status_code
|
||||||
)
|
)
|
||||||
|
verbose_logger.debug(
|
||||||
|
"bedrock invoke response % s",
|
||||||
|
json.dumps(completion_response, indent=4, default=str),
|
||||||
|
)
|
||||||
provider = self.get_bedrock_invoke_provider(model)
|
provider = self.get_bedrock_invoke_provider(model)
|
||||||
outputText: Optional[str] = None
|
outputText: Optional[str] = None
|
||||||
try:
|
try:
|
||||||
|
@ -322,6 +335,18 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
json_mode=json_mode,
|
json_mode=json_mode,
|
||||||
)
|
)
|
||||||
|
elif provider == "nova":
|
||||||
|
return litellm.AmazonInvokeNovaConfig().transform_response(
|
||||||
|
model=model,
|
||||||
|
raw_response=raw_response,
|
||||||
|
model_response=model_response,
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
request_data=request_data,
|
||||||
|
messages=messages,
|
||||||
|
optional_params=optional_params,
|
||||||
|
litellm_params=litellm_params,
|
||||||
|
encoding=encoding,
|
||||||
|
)
|
||||||
elif provider == "ai21":
|
elif provider == "ai21":
|
||||||
outputText = (
|
outputText = (
|
||||||
completion_response.get("completions")[0].get("data").get("text")
|
completion_response.get("completions")[0].get("data").get("text")
|
||||||
|
@ -503,6 +528,7 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
|
||||||
1. model=invoke/anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic`
|
1. model=invoke/anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic`
|
||||||
2. model=anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic`
|
2. model=anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic`
|
||||||
3. model=llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n -> Returns `llama`
|
3. model=llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n -> Returns `llama`
|
||||||
|
4. model=us.amazon.nova-pro-v1:0 -> Returns `nova`
|
||||||
"""
|
"""
|
||||||
if model.startswith("invoke/"):
|
if model.startswith("invoke/"):
|
||||||
model = model.replace("invoke/", "", 1)
|
model = model.replace("invoke/", "", 1)
|
||||||
|
@ -515,6 +541,10 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
|
||||||
provider = AmazonInvokeConfig._get_provider_from_model_path(model)
|
provider = AmazonInvokeConfig._get_provider_from_model_path(model)
|
||||||
if provider is not None:
|
if provider is not None:
|
||||||
return provider
|
return provider
|
||||||
|
|
||||||
|
# check if provider == "nova"
|
||||||
|
if "nova" in model:
|
||||||
|
return "nova"
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
|
@ -184,6 +184,18 @@ class RequestObject(CommonRequestObject, total=False):
|
||||||
messages: Required[List[MessageBlock]]
|
messages: Required[List[MessageBlock]]
|
||||||
|
|
||||||
|
|
||||||
|
class BedrockInvokeNovaRequest(TypedDict, total=False):
    """
    Request object for sending `nova` requests to `/bedrock/invoke/`

    All keys are optional (total=False). Unlike the Converse `RequestObject`,
    this shape deliberately declares only the fields the invoke API accepts;
    anything else (e.g. additionalModelRequestFields) is filtered out before
    sending — see AmazonInvokeNovaConfig._filter_allowed_fields.
    """

    # Conversation turns, in Converse message-block format.
    messages: List[MessageBlock]
    # Generation parameters (e.g. temperature).
    inferenceConfig: InferenceConfig
    # System prompt blocks; must be omitted entirely when empty.
    system: List[SystemContentBlock]
    # Tool/function-calling configuration.
    toolConfig: ToolConfigBlock
    # Optional guardrail settings.
    guardrailConfig: Optional[GuardrailConfigBlock]
|
||||||
|
|
||||||
|
|
||||||
class GenericStreamingChunk(TypedDict):
|
class GenericStreamingChunk(TypedDict):
|
||||||
text: Required[str]
|
text: Required[str]
|
||||||
tool_use: Optional[ChatCompletionToolCallChunk]
|
tool_use: Optional[ChatCompletionToolCallChunk]
|
||||||
|
|
|
@ -1,28 +0,0 @@
|
||||||
from base_llm_unit_tests import BaseLLMChatTest
|
|
||||||
import pytest
|
|
||||||
import sys
|
|
||||||
import os
|
|
||||||
|
|
||||||
sys.path.insert(
|
|
||||||
0, os.path.abspath("../..")
|
|
||||||
) # Adds the parent directory to the system path
|
|
||||||
import litellm
|
|
||||||
|
|
||||||
|
|
||||||
class TestBedrockInvokeClaudeJson(BaseLLMChatTest):
    # Runs the shared BaseLLMChatTest suite against Claude 3.5 Sonnet via
    # bedrock /invoke/; an autouse fixture restricts it to the JSON tests.

    def get_base_completion_call_args(self) -> dict:
        # Debug logging makes the request/response transformations visible.
        litellm._turn_on_debug()
        return {
            "model": "bedrock/invoke/anthropic.claude-3-5-sonnet-20240620-v1:0",
        }

    def test_tool_call_no_arguments(self, tool_call_no_arguments):
        """Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
        pass

    @pytest.fixture(autouse=True)
    def skip_non_json_tests(self, request):
        # Only the JSON-mode tests from the base suite are relevant here.
        if not "json" in request.function.__name__.lower():
            pytest.skip(
                f"Skipping non-JSON test: {request.function.__name__} does not contain 'json'"
            )
|
|
153
tests/llm_translation/test_bedrock_invoke_tests.py
Normal file
153
tests/llm_translation/test_bedrock_invoke_tests.py
Normal file
|
@ -0,0 +1,153 @@
|
||||||
|
from base_llm_unit_tests import BaseLLMChatTest
|
||||||
|
import pytest
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
sys.path.insert(
|
||||||
|
0, os.path.abspath("../..")
|
||||||
|
) # Adds the parent directory to the system path
|
||||||
|
import litellm
|
||||||
|
from litellm.types.llms.bedrock import BedrockInvokeNovaRequest
|
||||||
|
|
||||||
|
|
||||||
|
class TestBedrockInvokeClaudeJson(BaseLLMChatTest):
    """Runs the shared chat suite against Claude 3.5 Sonnet via bedrock /invoke/.

    Only the JSON-mode tests from the base suite are executed — see the
    autouse fixture below.
    """

    def get_base_completion_call_args(self) -> dict:
        """Base kwargs (the invoke/claude model) used by every suite test."""
        litellm._turn_on_debug()
        return {
            "model": "bedrock/invoke/anthropic.claude-3-5-sonnet-20240620-v1:0",
        }

    def test_tool_call_no_arguments(self, tool_call_no_arguments):
        """Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
        pass

    @pytest.fixture(autouse=True)
    def skip_non_json_tests(self, request):
        """Skip every base-suite test whose name does not mention 'json'."""
        # `x not in y` over `not x in y` (PEP 8 / flake8 E713).
        if "json" not in request.function.__name__.lower():
            pytest.skip(
                f"Skipping non-JSON test: {request.function.__name__} does not contain 'json'"
            )
|
||||||
|
|
||||||
|
|
||||||
|
class TestBedrockInvokeNovaJson(BaseLLMChatTest):
    """Runs the shared chat suite against Nova Micro via bedrock /invoke/.

    Only the JSON-mode tests from the base suite are executed — see the
    autouse fixture below.
    """

    def get_base_completion_call_args(self) -> dict:
        """Base kwargs (the invoke/nova model) used by every suite test."""
        litellm._turn_on_debug()
        return {
            "model": "bedrock/invoke/us.amazon.nova-micro-v1:0",
        }

    def test_tool_call_no_arguments(self, tool_call_no_arguments):
        """Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
        pass

    @pytest.fixture(autouse=True)
    def skip_non_json_tests(self, request):
        """Skip every base-suite test whose name does not mention 'json'."""
        # `x not in y` over `not x in y` (PEP 8 / flake8 E713).
        if "json" not in request.function.__name__.lower():
            pytest.skip(
                f"Skipping non-JSON test: {request.function.__name__} does not contain 'json'"
            )
|
||||||
|
|
||||||
|
|
||||||
|
def test_nova_invoke_remove_empty_system_messages():
    """Test that _remove_empty_system_messages removes empty system list."""
    request = BedrockInvokeNovaRequest(
        messages=[{"content": [{"text": "Hello"}], "role": "user"}],
        system=[],
        inferenceConfig={"temperature": 0.7},
    )
    config = litellm.AmazonInvokeNovaConfig()

    config._remove_empty_system_messages(request)

    # The empty system list is stripped in place; everything else survives.
    assert "system" not in request
    for surviving_key in ("messages", "inferenceConfig"):
        assert surviving_key in request
|
||||||
|
|
||||||
|
|
||||||
|
def test_nova_invoke_filter_allowed_fields():
    """
    Test that _filter_allowed_fields only keeps fields defined in BedrockInvokeNovaRequest.

    Nova Invoke does not allow `additionalModelRequestFields` and `additionalModelResponseFieldPaths` in the request body.
    This test ensures that these fields are not included in the request body.
    """
    raw_request = {
        "messages": [{"content": [{"text": "Hello"}], "role": "user"}],
        "system": [{"text": "System prompt"}],
        "inferenceConfig": {"temperature": 0.7},
        "additionalModelRequestFields": {"this": "should be removed"},
        "additionalModelResponseFieldPaths": ["this", "should", "be", "removed"],
    }

    filtered = litellm.AmazonInvokeNovaConfig()._filter_allowed_fields(
        BedrockInvokeNovaRequest(**raw_request)
    )

    # Disallowed keys are dropped...
    assert "additionalModelRequestFields" not in filtered
    assert "additionalModelResponseFieldPaths" not in filtered
    # ...while every declared field is preserved.
    for kept in ("messages", "system", "inferenceConfig"):
        assert kept in filtered
|
||||||
|
|
||||||
|
|
||||||
|
def test_nova_invoke_streaming_chunk_parsing():
    """
    Test that the AWSEventStreamDecoder correctly handles Nova's /bedrock/invoke/ streaming format
    where content is nested under 'contentBlockDelta'.
    """
    from litellm.llms.bedrock.chat.invoke_handler import AWSEventStreamDecoder

    # Initialize the decoder with a Nova model
    decoder = AWSEventStreamDecoder(model="bedrock/invoke/us.amazon.nova-micro-v1:0")

    # Test case 1: Text content in contentBlockDelta
    nova_text_chunk = {
        "contentBlockDelta": {
            "delta": {"text": "Hello, how can I help?"},
            "contentBlockIndex": 0,
        }
    }
    result = decoder._chunk_parser(nova_text_chunk)
    assert result["text"] == "Hello, how can I help?"
    assert result["index"] == 0
    assert not result["is_finished"]
    assert result["tool_use"] is None

    # Test case 2: Tool use start in contentBlockDelta
    nova_tool_start_chunk = {
        "contentBlockDelta": {
            "start": {"toolUse": {"name": "get_weather", "toolUseId": "tool_1"}},
            "contentBlockIndex": 1,
        }
    }
    result = decoder._chunk_parser(nova_tool_start_chunk)
    assert result["text"] == ""
    assert result["index"] == 1
    assert result["tool_use"] is not None
    assert result["tool_use"]["type"] == "function"
    assert result["tool_use"]["function"]["name"] == "get_weather"
    assert result["tool_use"]["id"] == "tool_1"

    # Test case 3: Tool use arguments in contentBlockDelta
    nova_tool_args_chunk = {
        "contentBlockDelta": {
            "delta": {"toolUse": {"input": '{"location": "New York"}'}},
            "contentBlockIndex": 2,
        }
    }
    result = decoder._chunk_parser(nova_tool_args_chunk)
    assert result["text"] == ""
    assert result["index"] == 2
    assert result["tool_use"] is not None
    assert result["tool_use"]["function"]["arguments"] == '{"location": "New York"}'

    # Test case 4: Stop reason in contentBlockDelta
    nova_stop_chunk = {
        "contentBlockDelta": {
            "stopReason": "tool_use",
        }
    }
    result = decoder._chunk_parser(nova_stop_chunk)
    # NOTE: removed a leftover debug `print(result)` — it only added noise to test output.
    assert result["is_finished"] is True
    assert result["finish_reason"] == "tool_calls"
|
Loading…
Add table
Add a link
Reference in a new issue