Litellm dev 01 29 2025 p2 (#8102)

* docs: cleanup doc

* feat(bedrock/): initial commit adding bedrock/converse_like/<model> route support

allows routing to a Converse-like endpoint

Resolves https://github.com/BerriAI/litellm/issues/8085

* feat(bedrock/chat/converse_transformation.py): make converse config base config compatible

enables new 'converse_like' route

* feat(converse_transformation.py): enables using the proxy with a Converse-like API endpoint

Resolves https://github.com/BerriAI/litellm/issues/8085
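
A minimal usage sketch of the new route (the model name and api_base come from the test added in this commit; the surrounding client code is otherwise illustrative):

import litellm

# Route a request through any endpoint that speaks the Bedrock Converse API.
# The `converse_like/` prefix is stripped before dispatch, and api_key is sent
# as a Bearer token (see validate_environment in the diff below).
response = litellm.completion(
    model="bedrock/converse_like/us.amazon.nova-pro-v1:0",
    messages=[{"role": "user", "content": "Tell me a joke"}],
    api_base="https://some-api-url/models",
    api_key="Token",
)
print(response.choices[0].message.content)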
Krish Dholakia, 2025-01-29 20:53:37 -08:00, committed by GitHub
parent a57fad1e29
commit dad24f2b52
12 changed files with 182 additions and 51 deletions


@@ -2,7 +2,7 @@ import Image from '@theme/IdealImage';
 import Tabs from '@theme/Tabs';
 import TabItem from '@theme/TabItem';

-# Control Model Access with SSO (Azure AD/Keycloak/etc.)
+# Control Model Access with OIDC (Azure AD/Keycloak/etc.)

 :::info


@@ -0,0 +1,5 @@
+"""
+Uses base_llm_http_handler to call the 'converse like' endpoint.
+
+Relevant issue: https://github.com/BerriAI/litellm/issues/8085
+"""


@@ -0,0 +1,3 @@
+"""
+Uses `converse_transformation.py` to transform the messages to the format required by Bedrock Converse.
+"""


@@ -5,7 +5,7 @@ Translating between OpenAI's `/chat/completion` format and Amazon's `/converse`
 import copy
 import time
 import types
-from typing import List, Literal, Optional, Tuple, Union, overload
+from typing import Callable, List, Literal, Optional, Tuple, Union, cast, overload

 import httpx
@@ -17,6 +17,7 @@ from litellm.litellm_core_utils.prompt_templates.factory import (
     _bedrock_converse_messages_pt,
     _bedrock_tools_pt,
 )
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
 from litellm.types.llms.bedrock import *
 from litellm.types.llms.openai import (
     AllMessageValues,
@@ -42,7 +43,7 @@ global_config = AmazonBedrockGlobalConfig()
 all_global_regions = global_config.get_all_regions()

-class AmazonConverseConfig:
+class AmazonConverseConfig(BaseConfig):
     """
     Reference - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html
     #2 - https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html#conversation-inference-supported-models-features
@@ -193,9 +194,9 @@ class AmazonConverseConfig:
     def map_openai_params(
         self,
-        model: str,
         non_default_params: dict,
         optional_params: dict,
+        model: str,
         drop_params: bool,
         messages: Optional[List[AllMessageValues]] = None,
     ) -> dict:
@@ -254,25 +255,6 @@ class AmazonConverseConfig:
             if _tool_choice_value is not None:
                 optional_params["tool_choice"] = _tool_choice_value

-        ## VALIDATE REQUEST
-        """
-        Bedrock doesn't support tool calling without `tools=` param specified.
-        """
-        if (
-            "tools" not in non_default_params
-            and messages is not None
-            and has_tool_call_blocks(messages)
-        ):
-            if litellm.modify_params:
-                optional_params["tools"] = add_dummy_tool(
-                    custom_llm_provider="bedrock_converse"
-                )
-            else:
-                raise litellm.UnsupportedParamsError(
-                    message="Bedrock doesn't support tool calling without `tools=` param specified. Pass `tools=` param OR set `litellm.modify_params = True` // `litellm_settings::modify_params: True` to add dummy tool to the request.",
-                    model="",
-                    llm_provider="bedrock",
-                )
         return optional_params

     @overload
@@ -352,8 +334,32 @@ class AmazonConverseConfig:
         return InferenceConfig(**inference_params)

     def _transform_request_helper(
-        self, system_content_blocks: List[SystemContentBlock], optional_params: dict
+        self,
+        system_content_blocks: List[SystemContentBlock],
+        optional_params: dict,
+        messages: Optional[List[AllMessageValues]] = None,
     ) -> CommonRequestObject:
+        ## VALIDATE REQUEST
+        """
+        Bedrock doesn't support tool calling without `tools=` param specified.
+        """
+        if (
+            "tools" not in optional_params
+            and messages is not None
+            and has_tool_call_blocks(messages)
+        ):
+            if litellm.modify_params:
+                optional_params["tools"] = add_dummy_tool(
+                    custom_llm_provider="bedrock_converse"
+                )
+            else:
+                raise litellm.UnsupportedParamsError(
+                    message="Bedrock doesn't support tool calling without `tools=` param specified. Pass `tools=` param OR set `litellm.modify_params = True` // `litellm_settings::modify_params: True` to add dummy tool to the request.",
+                    model="",
+                    llm_provider="bedrock",
+                )
+
         inference_params = copy.deepcopy(optional_params)
         additional_request_keys = []
         additional_request_params = {}
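
For illustration only (not part of the diff): a sketch of how the relocated validation behaves. The messages below are hypothetical; the call signature matches the helper above.

import litellm
from litellm.llms.bedrock.chat.converse_transformation import AmazonConverseConfig

# An assistant message carrying a tool-call block, but no `tools=` param.
messages = [
    {"role": "user", "content": "What's the weather in SF?"},
    {
        "role": "assistant",
        "tool_calls": [
            {
                "id": "call_1",
                "type": "function",
                "function": {"name": "get_weather", "arguments": "{}"},
            }
        ],
    },
]

litellm.modify_params = False  # default: reject rather than patch the request
try:
    AmazonConverseConfig()._transform_request_helper(
        system_content_blocks=[],
        optional_params={},  # the check now reads optional_params, not non_default_params
        messages=messages,
    )
except litellm.UnsupportedParamsError as e:
    print(e)  # "Bedrock doesn't support tool calling without `tools=` param specified. ..."

# With litellm.modify_params = True, add_dummy_tool() patches the request instead.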
@@ -429,14 +435,12 @@ class AmazonConverseConfig:
     ) -> RequestObject:
         messages, system_content_blocks = self._transform_system_message(messages)
         ## TRANSFORMATION ##
-        # bedrock_messages: List[MessageBlock] = await asyncify(
-        #     _bedrock_converse_messages_pt
-        # )(
-        #     messages=messages,
-        #     model=model,
-        #     llm_provider="bedrock_converse",
-        #     user_continue_message=litellm_params.pop("user_continue_message", None),
-        # )
+        _data: CommonRequestObject = self._transform_request_helper(
+            system_content_blocks=system_content_blocks,
+            optional_params=optional_params,
+            messages=messages,
+        )

         bedrock_messages = (
             await BedrockConverseMessagesProcessor._bedrock_converse_messages_pt_async(
@@ -447,15 +451,28 @@ class AmazonConverseConfig:
             )
         )

-        _data: CommonRequestObject = self._transform_request_helper(
-            system_content_blocks=system_content_blocks,
-            optional_params=optional_params,
-        )
-
         data: RequestObject = {"messages": bedrock_messages, **_data}

         return data

+    def transform_request(
+        self,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        headers: dict,
+    ) -> dict:
+        return cast(
+            dict,
+            self._transform_request(
+                model=model,
+                messages=messages,
+                optional_params=optional_params,
+                litellm_params=litellm_params,
+            ),
+        )
+
     def _transform_request(
         self,
         model: str,
@@ -464,6 +481,13 @@ class AmazonConverseConfig:
         litellm_params: dict,
     ) -> RequestObject:
         messages, system_content_blocks = self._transform_system_message(messages)
+        _data: CommonRequestObject = self._transform_request_helper(
+            system_content_blocks=system_content_blocks,
+            optional_params=optional_params,
+            messages=messages,
+        )
+
         ## TRANSFORMATION ##
         bedrock_messages: List[MessageBlock] = _bedrock_converse_messages_pt(
             messages=messages,
@@ -472,15 +496,38 @@ class AmazonConverseConfig:
             user_continue_message=litellm_params.pop("user_continue_message", None),
         )

-        _data: CommonRequestObject = self._transform_request_helper(
-            system_content_blocks=system_content_blocks,
-            optional_params=optional_params,
-        )
-
         data: RequestObject = {"messages": bedrock_messages, **_data}

         return data

+    def transform_response(
+        self,
+        model: str,
+        raw_response: httpx.Response,
+        model_response: ModelResponse,
+        logging_obj: Logging,
+        request_data: dict,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        encoding: Any,
+        api_key: Optional[str] = None,
+        json_mode: Optional[bool] = None,
+    ) -> ModelResponse:
+        return self._transform_response(
+            model=model,
+            response=raw_response,
+            model_response=model_response,
+            stream=optional_params.get("stream", False),
+            logging_obj=logging_obj,
+            optional_params=optional_params,
+            api_key=api_key,
+            data=request_data,
+            messages=messages,
+            print_verbose=None,
+            encoding=encoding,
+        )
+
     def _transform_response(
         self,
         model: str,
@@ -489,12 +536,12 @@ class AmazonConverseConfig:
         stream: bool,
         logging_obj: Optional[Logging],
         optional_params: dict,
-        api_key: str,
+        api_key: Optional[str],
         data: Union[dict, str],
         messages: List,
-        print_verbose,
+        print_verbose: Optional[Callable],
         encoding,
-    ) -> Union[ModelResponse, CustomStreamWrapper]:
+    ) -> ModelResponse:
         ## LOGGING
         if logging_obj is not None:
             logging_obj.post_call(
@@ -503,7 +550,7 @@ class AmazonConverseConfig:
                 original_response=response.text,
                 additional_args={"complete_input_dict": data},
             )
-        print_verbose(f"raw model_response: {response.text}")
+
         json_mode: Optional[bool] = optional_params.pop("json_mode", None)
         ## RESPONSE OBJECT
         try:
@@ -652,3 +699,25 @@ class AmazonConverseConfig:
             return model.split("/", 1)[1]
         return model
+
+    def get_error_class(
+        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
+    ) -> BaseLLMException:
+        return BedrockError(
+            message=error_message,
+            status_code=status_code,
+            headers=headers,
+        )
+
+    def validate_environment(
+        self,
+        headers: dict,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+    ) -> dict:
+        if api_key:
+            headers["Authorization"] = f"Bearer {api_key}"
+        return headers
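
A short sketch of the new validate_environment behavior (values illustrative): a supplied api_key is forwarded as a Bearer token, which is what lets the proxy authenticate against Converse-like endpoints.

from litellm.llms.bedrock.chat.converse_transformation import AmazonConverseConfig

headers = AmazonConverseConfig().validate_environment(
    headers={},
    model="converse_like/us.amazon.nova-pro-v1:0",  # illustrative model name
    messages=[{"role": "user", "content": "hi"}],
    optional_params={},
    api_key="Token",
)
assert headers["Authorization"] == "Bearer Token"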


@@ -206,6 +206,7 @@ class BaseLLMHTTPHandler:
         headers: Optional[dict] = {},
         client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
     ):
+
         provider_config = ProviderConfigManager.get_provider_chat_config(
             model=model, provider=litellm.LlmProviders(custom_llm_provider)
         )


@@ -2589,6 +2589,25 @@ def completion(  # type: ignore # noqa: PLR0915
                     client=client,
                     api_base=api_base,
                 )
+            elif "converse_like" in model:
+                model = model.replace("converse_like/", "")
+                response = base_llm_http_handler.completion(
+                    model=model,
+                    stream=stream,
+                    messages=messages,
+                    acompletion=acompletion,
+                    api_base=api_base,
+                    model_response=model_response,
+                    optional_params=optional_params,
+                    litellm_params=litellm_params,
+                    custom_llm_provider="bedrock",
+                    timeout=timeout,
+                    headers=headers,
+                    encoding=encoding,
+                    api_key=api_key,
+                    logging_obj=logging,  # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements
+                    client=client,
+                )
             else:
                 model = model.replace("invoke/", "")
                 response = bedrock_chat_completion.completion(
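
A hypothetical helper (not in the commit) mirroring the dispatch rule above: the converse_like/ prefix is stripped, and the bare model ID is handed to base_llm_http_handler with custom_llm_provider="bedrock".

def _strip_converse_like_prefix(model: str) -> str:
    # Mirrors the branch above; the upstream endpoint sees the bare model ID.
    if "converse_like" in model:
        return model.replace("converse_like/", "")
    return model

assert (
    _strip_converse_like_prefix("converse_like/us.amazon.nova-pro-v1:0")
    == "us.amazon.nova-pro-v1:0"
)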

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long


@@ -6021,8 +6021,11 @@ class ProviderConfigManager:
             return litellm.PetalsConfig()
         elif litellm.LlmProviders.BEDROCK == provider:
             base_model = litellm.AmazonConverseConfig()._get_base_model(model)
-            if base_model in litellm.bedrock_converse_models:
-                pass
+            if (
+                base_model in litellm.bedrock_converse_models
+                or "converse_like" in model
+            ):
+                return litellm.AmazonConverseConfig()
             elif "amazon" in model:  # amazon titan llms
                 return litellm.AmazonTitanConfig()
             elif "meta" in model:  # amazon / meta llms


@@ -2506,3 +2506,26 @@ async def test_bedrock_document_understanding(image_url):
     )
     assert response is not None
     assert response.choices[0].message.content != ""
+
+
+def test_bedrock_custom_proxy():
+    from litellm.llms.custom_httpx.http_handler import HTTPHandler
+
+    client = HTTPHandler()
+
+    with patch.object(client, "post") as mock_post:
+        try:
+            response = completion(
+                model="bedrock/converse_like/us.amazon.nova-pro-v1:0",
+                messages=[{"content": "Tell me a joke", "role": "user"}],
+                api_key="Token",
+                client=client,
+                api_base="https://some-api-url/models",
+            )
+        except Exception as e:
+            print(e)
+
+        print(mock_post.call_args.kwargs)
+        mock_post.assert_called_once()
+        assert mock_post.call_args.kwargs["url"] == "https://some-api-url/models"
+        assert mock_post.call_args.kwargs["headers"]["Authorization"] == "Bearer Token"


@@ -369,6 +369,17 @@ def _check_provider_config(config: BaseConfig, provider: LlmProviders):
     assert "_abc_impl" not in config.get_config(), f"Provider {provider} has _abc_impl"


+def test_provider_config_manager_bedrock_converse_like():
+    from litellm.llms.bedrock.chat.converse_transformation import AmazonConverseConfig
+
+    config = ProviderConfigManager.get_provider_chat_config(
+        model="bedrock/converse_like/us.amazon.nova-pro-v1:0",
+        provider=LlmProviders.BEDROCK,
+    )
+    print(f"config: {config}")
+    assert isinstance(config, AmazonConverseConfig)
+
+
 # def test_provider_config_manager():
 #     from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig