mirror of https://github.com/BerriAI/litellm.git
synced 2025-04-25 10:44:24 +00:00
Litellm dev 01 29 2025 p2 (#8102)
* docs: cleanup doc
* feat(bedrock/): initial commit adding bedrock/converse_like/<model> route support. Allows routing to a converse-like endpoint. Resolves https://github.com/BerriAI/litellm/issues/8085
* feat(bedrock/chat/converse_transformation.py): make converse config BaseConfig-compatible. Enables the new 'converse_like' route.
* feat(converse_transformation.py): enables using the proxy with a converse-like API endpoint. Resolves https://github.com/BerriAI/litellm/issues/8085
parent a57fad1e29
commit dad24f2b52
12 changed files with 182 additions and 51 deletions
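For orientation, here is how the new route is exercised end to end. This is a minimal sketch lifted from the `test_bedrock_custom_proxy` test added in this commit; the model name, token, and api_base are placeholder values:

import litellm

# Send an OpenAI-format request through the Bedrock Converse transformation
# to an arbitrary converse-compatible endpoint.
response = litellm.completion(
    model="bedrock/converse_like/us.amazon.nova-pro-v1:0",
    messages=[{"content": "Tell me a joke", "role": "user"}],
    api_key="Token",  # sent as `Authorization: Bearer Token` (see validate_environment below)
    api_base="https://some-api-url/models",
)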
@@ -2,7 +2,7 @@ import Image from '@theme/IdealImage';
 import Tabs from '@theme/Tabs';
 import TabItem from '@theme/TabItem';
 
-# Control Model Access with SSO (Azure AD/Keycloak/etc.)
+# Control Model Access with OIDC (Azure AD/Keycloak/etc.)
 
 :::info
5 litellm/llms/bedrock/chat/converse_like/handler.py Normal file
@@ -0,0 +1,5 @@
+"""
+Uses base_llm_http_handler to call the 'converse like' endpoint.
+
+Relevant issue: https://github.com/BerriAI/litellm/issues/8085
+"""
@@ -0,0 +1,3 @@
+"""
+Uses `converse_transformation.py` to transform the messages to the format required by Bedrock Converse.
+"""
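Together the two new stubs follow litellm's usual provider split: `handler.py` delegates the HTTP call to `base_llm_http_handler`, while the transformation module leans on `converse_transformation.py` (made `BaseConfig`-compatible in the diff below) for request/response mapping.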
@@ -5,7 +5,7 @@ Translating between OpenAI's `/chat/completion` format and Amazon's `/converse`
 import copy
 import time
 import types
-from typing import List, Literal, Optional, Tuple, Union, overload
+from typing import Callable, List, Literal, Optional, Tuple, Union, cast, overload
 
 import httpx
@@ -17,6 +17,7 @@ from litellm.litellm_core_utils.prompt_templates.factory import (
     _bedrock_converse_messages_pt,
     _bedrock_tools_pt,
 )
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
 from litellm.types.llms.bedrock import *
 from litellm.types.llms.openai import (
     AllMessageValues,
@@ -42,7 +43,7 @@ global_config = AmazonBedrockGlobalConfig()
 all_global_regions = global_config.get_all_regions()
 
 
-class AmazonConverseConfig:
+class AmazonConverseConfig(BaseConfig):
     """
     Reference - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html
     #2 - https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html#conversation-inference-supported-models-features
@@ -193,9 +194,9 @@ class AmazonConverseConfig:
 
     def map_openai_params(
         self,
-        model: str,
         non_default_params: dict,
         optional_params: dict,
+        model: str,
         drop_params: bool,
         messages: Optional[List[AllMessageValues]] = None,
     ) -> dict:
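(This reorder presumably brings the signature in line with the positional order that `BaseConfig.map_openai_params` expects, i.e. `non_default_params, optional_params, model, drop_params`, now that `AmazonConverseConfig` subclasses `BaseConfig`.)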
@@ -254,25 +255,6 @@ class AmazonConverseConfig:
         if _tool_choice_value is not None:
             optional_params["tool_choice"] = _tool_choice_value
-
-        ## VALIDATE REQUEST
-        """
-        Bedrock doesn't support tool calling without `tools=` param specified.
-        """
-        if (
-            "tools" not in non_default_params
-            and messages is not None
-            and has_tool_call_blocks(messages)
-        ):
-            if litellm.modify_params:
-                optional_params["tools"] = add_dummy_tool(
-                    custom_llm_provider="bedrock_converse"
-                )
-            else:
-                raise litellm.UnsupportedParamsError(
-                    message="Bedrock doesn't support tool calling without `tools=` param specified. Pass `tools=` param OR set `litellm.modify_params = True` // `litellm_settings::modify_params: True` to add dummy tool to the request.",
-                    model="",
-                    llm_provider="bedrock",
-                )
         return optional_params
 
     @overload
@@ -352,8 +334,32 @@ class AmazonConverseConfig:
         return InferenceConfig(**inference_params)
 
     def _transform_request_helper(
-        self, system_content_blocks: List[SystemContentBlock], optional_params: dict
+        self,
+        system_content_blocks: List[SystemContentBlock],
+        optional_params: dict,
+        messages: Optional[List[AllMessageValues]] = None,
     ) -> CommonRequestObject:
+
+        ## VALIDATE REQUEST
+        """
+        Bedrock doesn't support tool calling without `tools=` param specified.
+        """
+        if (
+            "tools" not in optional_params
+            and messages is not None
+            and has_tool_call_blocks(messages)
+        ):
+            if litellm.modify_params:
+                optional_params["tools"] = add_dummy_tool(
+                    custom_llm_provider="bedrock_converse"
+                )
+            else:
+                raise litellm.UnsupportedParamsError(
+                    message="Bedrock doesn't support tool calling without `tools=` param specified. Pass `tools=` param OR set `litellm.modify_params = True` // `litellm_settings::modify_params: True` to add dummy tool to the request.",
+                    model="",
+                    llm_provider="bedrock",
+                )
+
         inference_params = copy.deepcopy(optional_params)
         additional_request_keys = []
         additional_request_params = {}
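For context, the validation moved into the helper above fires when the message history already contains tool-call blocks but the request carries no `tools=` param. A hedged sketch of the two outcomes; the message history below is hypothetical:

import litellm

# Hypothetical history containing an assistant tool call but no `tools=` param.
messages = [
    {"role": "user", "content": "What's the weather in SF?"},
    {"role": "assistant", "tool_calls": [{
        "id": "call_1",
        "type": "function",
        "function": {"name": "get_weather", "arguments": '{"city": "SF"}'},
    }]},
    {"role": "tool", "tool_call_id": "call_1", "content": "70F"},
]

# Without the flag, _transform_request_helper raises UnsupportedParamsError.
# With it, a dummy tool is injected so Bedrock accepts the request.
litellm.modify_params = True
response = litellm.completion(
    model="bedrock/us.amazon.nova-pro-v1:0", messages=messages
)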
@@ -429,14 +435,12 @@ class AmazonConverseConfig:
     ) -> RequestObject:
         messages, system_content_blocks = self._transform_system_message(messages)
         ## TRANSFORMATION ##
-        # bedrock_messages: List[MessageBlock] = await asyncify(
-        #     _bedrock_converse_messages_pt
-        # )(
-        #     messages=messages,
-        #     model=model,
-        #     llm_provider="bedrock_converse",
-        #     user_continue_message=litellm_params.pop("user_continue_message", None),
-        # )
+
+        _data: CommonRequestObject = self._transform_request_helper(
+            system_content_blocks=system_content_blocks,
+            optional_params=optional_params,
+            messages=messages,
+        )
 
         bedrock_messages = (
             await BedrockConverseMessagesProcessor._bedrock_converse_messages_pt_async(
@@ -447,15 +451,28 @@ class AmazonConverseConfig:
                 )
             )
 
-        _data: CommonRequestObject = self._transform_request_helper(
-            system_content_blocks=system_content_blocks,
-            optional_params=optional_params,
-        )
-
         data: RequestObject = {"messages": bedrock_messages, **_data}
 
         return data
 
+    def transform_request(
+        self,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        headers: dict,
+    ) -> dict:
+        return cast(
+            dict,
+            self._transform_request(
+                model=model,
+                messages=messages,
+                optional_params=optional_params,
+                litellm_params=litellm_params,
+            ),
+        )
+
     def _transform_request(
         self,
         model: str,
@@ -464,6 +481,13 @@ class AmazonConverseConfig:
         litellm_params: dict,
     ) -> RequestObject:
         messages, system_content_blocks = self._transform_system_message(messages)
+
+        _data: CommonRequestObject = self._transform_request_helper(
+            system_content_blocks=system_content_blocks,
+            optional_params=optional_params,
+            messages=messages,
+        )
+
         ## TRANSFORMATION ##
         bedrock_messages: List[MessageBlock] = _bedrock_converse_messages_pt(
             messages=messages,
@@ -472,15 +496,38 @@ class AmazonConverseConfig:
             user_continue_message=litellm_params.pop("user_continue_message", None),
         )
 
-        _data: CommonRequestObject = self._transform_request_helper(
-            system_content_blocks=system_content_blocks,
-            optional_params=optional_params,
-        )
-
         data: RequestObject = {"messages": bedrock_messages, **_data}
 
         return data
 
+    def transform_response(
+        self,
+        model: str,
+        raw_response: httpx.Response,
+        model_response: ModelResponse,
+        logging_obj: Logging,
+        request_data: dict,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        encoding: Any,
+        api_key: Optional[str] = None,
+        json_mode: Optional[bool] = None,
+    ) -> ModelResponse:
+        return self._transform_response(
+            model=model,
+            response=raw_response,
+            model_response=model_response,
+            stream=optional_params.get("stream", False),
+            logging_obj=logging_obj,
+            optional_params=optional_params,
+            api_key=api_key,
+            data=request_data,
+            messages=messages,
+            print_verbose=None,
+            encoding=encoding,
+        )
+
     def _transform_response(
         self,
         model: str,
@@ -489,12 +536,12 @@ class AmazonConverseConfig:
         stream: bool,
         logging_obj: Optional[Logging],
         optional_params: dict,
-        api_key: str,
+        api_key: Optional[str],
         data: Union[dict, str],
         messages: List,
-        print_verbose,
+        print_verbose: Optional[Callable],
         encoding,
-    ) -> Union[ModelResponse, CustomStreamWrapper]:
+    ) -> ModelResponse:
         ## LOGGING
-        logging_obj.post_call(
+        if logging_obj is not None:
+            logging_obj.post_call(
@@ -503,7 +550,7 @@ class AmazonConverseConfig:
                 original_response=response.text,
                 additional_args={"complete_input_dict": data},
             )
-        print_verbose(f"raw model_response: {response.text}")
 
+        json_mode: Optional[bool] = optional_params.pop("json_mode", None)
         ## RESPONSE OBJECT
         try:
@@ -652,3 +699,25 @@ class AmazonConverseConfig:
             return model.split("/", 1)[1]
 
         return model
+
+    def get_error_class(
+        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
+    ) -> BaseLLMException:
+        return BedrockError(
+            message=error_message,
+            status_code=status_code,
+            headers=headers,
+        )
+
+    def validate_environment(
+        self,
+        headers: dict,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+    ) -> dict:
+        if api_key:
+            headers["Authorization"] = f"Bearer {api_key}"
+        return headers
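A note on auth: unlike native Bedrock requests, which are SigV4-signed, the converse_like route authenticates with a plain bearer token via the new `validate_environment` above. A minimal sketch using only the signature shown in this hunk (`litellm.AmazonConverseConfig` is referenced the same way in the `utils.py` hunk below):

import litellm

headers = litellm.AmazonConverseConfig().validate_environment(
    headers={},
    model="us.amazon.nova-pro-v1:0",
    messages=[{"role": "user", "content": "hi"}],
    optional_params={},
    api_key="Token",
)
assert headers == {"Authorization": "Bearer Token"}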
@@ -206,6 +206,7 @@ class BaseLLMHTTPHandler:
         headers: Optional[dict] = {},
         client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
     ):
+
         provider_config = ProviderConfigManager.get_provider_chat_config(
             model=model, provider=litellm.LlmProviders(custom_llm_provider)
         )
@@ -2589,6 +2589,25 @@ def completion(  # type: ignore # noqa: PLR0915
                 client=client,
                 api_base=api_base,
             )
+        elif "converse_like" in model:
+            model = model.replace("converse_like/", "")
+            response = base_llm_http_handler.completion(
+                model=model,
+                stream=stream,
+                messages=messages,
+                acompletion=acompletion,
+                api_base=api_base,
+                model_response=model_response,
+                optional_params=optional_params,
+                litellm_params=litellm_params,
+                custom_llm_provider="bedrock",
+                timeout=timeout,
+                headers=headers,
+                encoding=encoding,
+                api_key=api_key,
+                logging_obj=logging,  # model call logging done inside the class as we may need to modify I/O to fit aleph alpha's requirements
+                client=client,
+            )
         else:
             model = model.replace("invoke/", "")
             response = bedrock_chat_completion.completion(
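Routing for the new path happens in two steps: `completion()` strips the `converse_like/` prefix and dispatches to the shared `base_llm_http_handler` with `custom_llm_provider="bedrock"`; the handler then resolves its chat config via `ProviderConfigManager` (see the `BaseLLMHTTPHandler` hunk above), which the `utils.py` change below maps to `AmazonConverseConfig`. A tiny illustrative sketch of the prefix handling (by this point the leading `bedrock/` has already been consumed by provider resolution):

model = "converse_like/us.amazon.nova-pro-v1:0"
if "converse_like" in model:
    model = model.replace("converse_like/", "")
assert model == "us.amazon.nova-pro-v1:0"  # forwarded with custom_llm_provider="bedrock"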
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -6021,8 +6021,11 @@ class ProviderConfigManager:
             return litellm.PetalsConfig()
         elif litellm.LlmProviders.BEDROCK == provider:
             base_model = litellm.AmazonConverseConfig()._get_base_model(model)
-            if base_model in litellm.bedrock_converse_models:
-                pass
+            if (
+                base_model in litellm.bedrock_converse_models
+                or "converse_like" in model
+            ):
+                return litellm.AmazonConverseConfig()
             elif "amazon" in model:  # amazon titan llms
                 return litellm.AmazonTitanConfig()
             elif "meta" in model:  # amazon / meta llms
@@ -2506,3 +2506,26 @@ async def test_bedrock_document_understanding(image_url):
     )
     assert response is not None
     assert response.choices[0].message.content != ""
+
+
+def test_bedrock_custom_proxy():
+    from litellm.llms.custom_httpx.http_handler import HTTPHandler
+
+    client = HTTPHandler()
+
+    with patch.object(client, "post") as mock_post:
+        try:
+            response = completion(
+                model="bedrock/converse_like/us.amazon.nova-pro-v1:0",
+                messages=[{"content": "Tell me a joke", "role": "user"}],
+                api_key="Token",
+                client=client,
+                api_base="https://some-api-url/models",
+            )
+        except Exception as e:
+            print(e)
+
+        print(mock_post.call_args.kwargs)
+        mock_post.assert_called_once()
+        assert mock_post.call_args.kwargs["url"] == "https://some-api-url/models"
+        assert mock_post.call_args.kwargs["headers"]["Authorization"] == "Bearer Token"
@ -369,6 +369,17 @@ def _check_provider_config(config: BaseConfig, provider: LlmProviders):
|
|||
assert "_abc_impl" not in config.get_config(), f"Provider {provider} has _abc_impl"
|
||||
|
||||
|
||||
def test_provider_config_manager_bedrock_converse_like():
|
||||
from litellm.llms.bedrock.chat.converse_transformation import AmazonConverseConfig
|
||||
|
||||
config = ProviderConfigManager.get_provider_chat_config(
|
||||
model="bedrock/converse_like/us.amazon.nova-pro-v1:0",
|
||||
provider=LlmProviders.BEDROCK,
|
||||
)
|
||||
print(f"config: {config}")
|
||||
assert isinstance(config, AmazonConverseConfig)
|
||||
|
||||
|
||||
# def test_provider_config_manager():
|
||||
# from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
|
||||
|
||||
|
|