Litellm dev 01 29 2025 p2 (#8102)

* docs: cleanup doc

* feat(bedrock/): initial commit adding bedrock/converse_like/<model> route support

Allows routing to a Converse-like endpoint.

Resolves https://github.com/BerriAI/litellm/issues/8085

* feat(bedrock/chat/converse_transformation.py): make converse config base config compatible

enables new 'converse_like' route

* feat(converse_transformation.py): enables using the proxy with converse like api endpoint

Resolves https://github.com/BerriAI/litellm/issues/8085
This commit is contained in:
Krish Dholakia 2025-01-29 20:53:37 -08:00 committed by GitHub
parent a57fad1e29
commit dad24f2b52
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 182 additions and 51 deletions

View file

@ -2,7 +2,7 @@ import Image from '@theme/IdealImage';
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# Control Model Access with SSO (Azure AD/Keycloak/etc.)
# Control Model Access with OIDC (Azure AD/Keycloak/etc.)
:::info

View file

@ -0,0 +1,5 @@
"""
Uses base_llm_http_handler to call the 'converse like' endpoint.
Relevant issue: https://github.com/BerriAI/litellm/issues/8085
"""

View file

@ -0,0 +1,3 @@
"""
Uses `converse_transformation.py` to transform the messages to the format required by Bedrock Converse.
"""

View file

@ -5,7 +5,7 @@ Translating between OpenAI's `/chat/completion` format and Amazon's `/converse`
import copy
import time
import types
from typing import List, Literal, Optional, Tuple, Union, overload
from typing import Callable, List, Literal, Optional, Tuple, Union, cast, overload
import httpx
@ -17,6 +17,7 @@ from litellm.litellm_core_utils.prompt_templates.factory import (
_bedrock_converse_messages_pt,
_bedrock_tools_pt,
)
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.types.llms.bedrock import *
from litellm.types.llms.openai import (
AllMessageValues,
@ -42,7 +43,7 @@ global_config = AmazonBedrockGlobalConfig()
all_global_regions = global_config.get_all_regions()
class AmazonConverseConfig:
class AmazonConverseConfig(BaseConfig):
"""
Reference - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html
#2 - https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html#conversation-inference-supported-models-features
@ -193,9 +194,9 @@ class AmazonConverseConfig:
def map_openai_params(
self,
model: str,
non_default_params: dict,
optional_params: dict,
model: str,
drop_params: bool,
messages: Optional[List[AllMessageValues]] = None,
) -> dict:
@ -254,25 +255,6 @@ class AmazonConverseConfig:
if _tool_choice_value is not None:
optional_params["tool_choice"] = _tool_choice_value
## VALIDATE REQUEST
"""
Bedrock doesn't support tool calling without `tools=` param specified.
"""
if (
"tools" not in non_default_params
and messages is not None
and has_tool_call_blocks(messages)
):
if litellm.modify_params:
optional_params["tools"] = add_dummy_tool(
custom_llm_provider="bedrock_converse"
)
else:
raise litellm.UnsupportedParamsError(
message="Bedrock doesn't support tool calling without `tools=` param specified. Pass `tools=` param OR set `litellm.modify_params = True` // `litellm_settings::modify_params: True` to add dummy tool to the request.",
model="",
llm_provider="bedrock",
)
return optional_params
@overload
@ -352,8 +334,32 @@ class AmazonConverseConfig:
return InferenceConfig(**inference_params)
def _transform_request_helper(
self, system_content_blocks: List[SystemContentBlock], optional_params: dict
self,
system_content_blocks: List[SystemContentBlock],
optional_params: dict,
messages: Optional[List[AllMessageValues]] = None,
) -> CommonRequestObject:
## VALIDATE REQUEST
"""
Bedrock doesn't support tool calling without `tools=` param specified.
"""
if (
"tools" not in optional_params
and messages is not None
and has_tool_call_blocks(messages)
):
if litellm.modify_params:
optional_params["tools"] = add_dummy_tool(
custom_llm_provider="bedrock_converse"
)
else:
raise litellm.UnsupportedParamsError(
message="Bedrock doesn't support tool calling without `tools=` param specified. Pass `tools=` param OR set `litellm.modify_params = True` // `litellm_settings::modify_params: True` to add dummy tool to the request.",
model="",
llm_provider="bedrock",
)
inference_params = copy.deepcopy(optional_params)
additional_request_keys = []
additional_request_params = {}
@ -429,14 +435,12 @@ class AmazonConverseConfig:
) -> RequestObject:
messages, system_content_blocks = self._transform_system_message(messages)
## TRANSFORMATION ##
# bedrock_messages: List[MessageBlock] = await asyncify(
# _bedrock_converse_messages_pt
# )(
# messages=messages,
# model=model,
# llm_provider="bedrock_converse",
# user_continue_message=litellm_params.pop("user_continue_message", None),
# )
_data: CommonRequestObject = self._transform_request_helper(
system_content_blocks=system_content_blocks,
optional_params=optional_params,
messages=messages,
)
bedrock_messages = (
await BedrockConverseMessagesProcessor._bedrock_converse_messages_pt_async(
@ -447,15 +451,28 @@ class AmazonConverseConfig:
)
)
_data: CommonRequestObject = self._transform_request_helper(
system_content_blocks=system_content_blocks,
optional_params=optional_params,
)
data: RequestObject = {"messages": bedrock_messages, **_data}
return data
def transform_request(
    self,
    model: str,
    messages: List[AllMessageValues],
    optional_params: dict,
    litellm_params: dict,
    headers: dict,
) -> dict:
    """Adapt the BaseConfig interface to the internal `_transform_request`.

    `headers` is accepted for interface compatibility with BaseConfig but is
    not consumed by the underlying Bedrock Converse transformation.
    """
    request_object = self._transform_request(
        model=model,
        messages=messages,
        optional_params=optional_params,
        litellm_params=litellm_params,
    )
    # _transform_request returns a RequestObject TypedDict; surface it to
    # callers as a plain dict, matching the BaseConfig return contract.
    return cast(dict, request_object)
def _transform_request(
self,
model: str,
@ -464,6 +481,13 @@ class AmazonConverseConfig:
litellm_params: dict,
) -> RequestObject:
messages, system_content_blocks = self._transform_system_message(messages)
_data: CommonRequestObject = self._transform_request_helper(
system_content_blocks=system_content_blocks,
optional_params=optional_params,
messages=messages,
)
## TRANSFORMATION ##
bedrock_messages: List[MessageBlock] = _bedrock_converse_messages_pt(
messages=messages,
@ -472,15 +496,38 @@ class AmazonConverseConfig:
user_continue_message=litellm_params.pop("user_continue_message", None),
)
_data: CommonRequestObject = self._transform_request_helper(
system_content_blocks=system_content_blocks,
optional_params=optional_params,
)
data: RequestObject = {"messages": bedrock_messages, **_data}
return data
def transform_response(
    self,
    model: str,
    raw_response: httpx.Response,
    model_response: ModelResponse,
    logging_obj: Logging,
    request_data: dict,
    messages: List[AllMessageValues],
    optional_params: dict,
    litellm_params: dict,
    encoding: Any,
    api_key: Optional[str] = None,
    json_mode: Optional[bool] = None,
) -> ModelResponse:
    """Adapt the BaseConfig response interface to the internal `_transform_response`.

    Delegates all parsing of the raw Bedrock Converse HTTP response; the
    `stream` flag is derived from `optional_params` rather than passed by
    the caller.
    """
    stream_flag = optional_params.get("stream", False)
    return self._transform_response(
        model=model,
        response=raw_response,
        model_response=model_response,
        stream=stream_flag,
        logging_obj=logging_obj,
        optional_params=optional_params,
        api_key=api_key,
        data=request_data,
        messages=messages,
        # no verbose printer in the BaseConfig flow
        print_verbose=None,
        encoding=encoding,
    )
def _transform_response(
self,
model: str,
@ -489,12 +536,12 @@ class AmazonConverseConfig:
stream: bool,
logging_obj: Optional[Logging],
optional_params: dict,
api_key: str,
api_key: Optional[str],
data: Union[dict, str],
messages: List,
print_verbose,
print_verbose: Optional[Callable],
encoding,
) -> Union[ModelResponse, CustomStreamWrapper]:
) -> ModelResponse:
## LOGGING
if logging_obj is not None:
logging_obj.post_call(
@ -503,7 +550,7 @@ class AmazonConverseConfig:
original_response=response.text,
additional_args={"complete_input_dict": data},
)
print_verbose(f"raw model_response: {response.text}")
json_mode: Optional[bool] = optional_params.pop("json_mode", None)
## RESPONSE OBJECT
try:
@ -652,3 +699,25 @@ class AmazonConverseConfig:
return model.split("/", 1)[1]
return model
def get_error_class(
    self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BaseLLMException:
    """Build the provider-specific exception for a failed Bedrock call.

    NOTE(review): presumably `BedrockError` subclasses `BaseLLMException`
    (implied by the return annotation) — confirm against its definition.
    """
    error = BedrockError(
        status_code=status_code,
        message=error_message,
        headers=headers,
    )
    return error
def validate_environment(
    self,
    headers: dict,
    model: str,
    messages: List[AllMessageValues],
    optional_params: dict,
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
) -> dict:
    """Attach bearer-token auth to the outgoing headers when an API key is set.

    Mutates and returns `headers`. All other parameters are accepted for
    BaseConfig interface compatibility and are not inspected here.
    """
    if not api_key:
        # No key supplied (None or empty string): leave headers untouched.
        return headers
    headers["Authorization"] = f"Bearer {api_key}"
    return headers

View file

@ -206,6 +206,7 @@ class BaseLLMHTTPHandler:
headers: Optional[dict] = {},
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
):
provider_config = ProviderConfigManager.get_provider_chat_config(
model=model, provider=litellm.LlmProviders(custom_llm_provider)
)

View file

@ -2589,6 +2589,25 @@ def completion( # type: ignore # noqa: PLR0915
client=client,
api_base=api_base,
)
elif "converse_like" in model:
model = model.replace("converse_like/", "")
response = base_llm_http_handler.completion(
model=model,
stream=stream,
messages=messages,
acompletion=acompletion,
api_base=api_base,
model_response=model_response,
optional_params=optional_params,
litellm_params=litellm_params,
custom_llm_provider="bedrock",
timeout=timeout,
headers=headers,
encoding=encoding,
api_key=api_key,
logging_obj=logging, # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements
client=client,
)
else:
model = model.replace("invoke/", "")
response = bedrock_chat_completion.completion(

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -6021,8 +6021,11 @@ class ProviderConfigManager:
return litellm.PetalsConfig()
elif litellm.LlmProviders.BEDROCK == provider:
base_model = litellm.AmazonConverseConfig()._get_base_model(model)
if base_model in litellm.bedrock_converse_models:
pass
if (
base_model in litellm.bedrock_converse_models
or "converse_like" in model
):
return litellm.AmazonConverseConfig()
elif "amazon" in model: # amazon titan llms
return litellm.AmazonTitanConfig()
elif "meta" in model: # amazon / meta llms

View file

@ -2506,3 +2506,26 @@ async def test_bedrock_document_understanding(image_url):
)
assert response is not None
assert response.choices[0].message.content != ""
def test_bedrock_custom_proxy():
    """Route `bedrock/converse_like/<model>` to a custom api_base with bearer auth.

    Patches the HTTP client's `post` so no network call is made, then checks
    that the outgoing request targets the custom URL and carries the
    `Authorization: Bearer <api_key>` header.
    """
    from litellm.llms.custom_httpx.http_handler import HTTPHandler

    client = HTTPHandler()

    with patch.object(client, "post") as mock_post:
        try:
            response = completion(
                model="bedrock/converse_like/us.amazon.nova-pro-v1:0",
                messages=[{"content": "Tell me a joke", "role": "user"}],
                api_key="Token",
                client=client,
                api_base="https://some-api-url/models",
            )
        except Exception as e:
            # The mocked `post` returns a MagicMock, so response parsing is
            # expected to fail — only the outgoing request matters here.
            print(e)

        # Assert the call happened BEFORE touching call_args: if `post` was
        # never invoked, call_args is None and `.kwargs` would raise an
        # AttributeError that masks the real failure.
        mock_post.assert_called_once()

        print(mock_post.call_args.kwargs)

        assert mock_post.call_args.kwargs["url"] == "https://some-api-url/models"
        assert mock_post.call_args.kwargs["headers"]["Authorization"] == "Bearer Token"

View file

@ -369,6 +369,17 @@ def _check_provider_config(config: BaseConfig, provider: LlmProviders):
assert "_abc_impl" not in config.get_config(), f"Provider {provider} has _abc_impl"
def test_provider_config_manager_bedrock_converse_like():
    """`converse_like/` Bedrock models must resolve to the Converse chat config."""
    from litellm.llms.bedrock.chat.converse_transformation import AmazonConverseConfig

    resolved_config = ProviderConfigManager.get_provider_chat_config(
        model="bedrock/converse_like/us.amazon.nova-pro-v1:0",
        provider=LlmProviders.BEDROCK,
    )
    print(f"config: {resolved_config}")
    assert isinstance(resolved_config, AmazonConverseConfig)
# def test_provider_config_manager():
# from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig