mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 03:34:10 +00:00
refactor: make bedrock image transformation requests async (#7840)
* refactor: initial commit for using separate sync vs. async transformation routes for bedrock ensures no blocking calls e.g. when converting image url to b64 * perf(converse_transformation.py): make bedrock converse transformation async asyncify's the bedrock message transformation - useful for handling image urls for bedrock * fix(converse_handler.py): fix logging for async streaming * style: cleanup unused imports
This commit is contained in:
parent
32c8933935
commit
2b58f16fda
8 changed files with 266 additions and 93 deletions
|
@ -1056,6 +1056,7 @@ ALL_LITELLM_RESPONSE_TYPES = [
|
||||||
]
|
]
|
||||||
|
|
||||||
from .llms.custom_llm import CustomLLM
|
from .llms.custom_llm import CustomLLM
|
||||||
|
from .llms.bedrock.chat.converse_transformation import AmazonConverseConfig
|
||||||
from .llms.openai_like.chat.handler import OpenAILikeChatConfig
|
from .llms.openai_like.chat.handler import OpenAILikeChatConfig
|
||||||
from .llms.aiohttp_openai.chat.transformation import AiohttpOpenAIChatConfig
|
from .llms.aiohttp_openai.chat.transformation import AiohttpOpenAIChatConfig
|
||||||
from .llms.galadriel.chat.transformation import GaladrielChatConfig
|
from .llms.galadriel.chat.transformation import GaladrielChatConfig
|
||||||
|
@ -1130,7 +1131,7 @@ from .llms.bedrock.chat.invoke_handler import (
|
||||||
AmazonCohereChatConfig,
|
AmazonCohereChatConfig,
|
||||||
bedrock_tool_name_mappings,
|
bedrock_tool_name_mappings,
|
||||||
)
|
)
|
||||||
from .llms.bedrock.chat.converse_transformation import AmazonConverseConfig
|
|
||||||
from .llms.bedrock.common_utils import (
|
from .llms.bedrock.common_utils import (
|
||||||
AmazonTitanConfig,
|
AmazonTitanConfig,
|
||||||
AmazonAI21Config,
|
AmazonAI21Config,
|
||||||
|
|
|
@ -12,9 +12,11 @@ from litellm.caching.caching import DualCache
|
||||||
from litellm.secret_managers.main import get_secret, get_secret_str
|
from litellm.secret_managers.main import get_secret, get_secret_str
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from botocore.awsrequest import AWSPreparedRequest
|
||||||
from botocore.credentials import Credentials
|
from botocore.credentials import Credentials
|
||||||
else:
|
else:
|
||||||
Credentials = Any
|
Credentials = Any
|
||||||
|
AWSPreparedRequest = Any
|
||||||
|
|
||||||
|
|
||||||
class Boto3CredentialsInfo(BaseModel):
|
class Boto3CredentialsInfo(BaseModel):
|
||||||
|
@ -471,3 +473,32 @@ class BaseAWSLLM:
|
||||||
aws_region_name=aws_region_name,
|
aws_region_name=aws_region_name,
|
||||||
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
|
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_request_headers(
|
||||||
|
self,
|
||||||
|
credentials: Credentials,
|
||||||
|
aws_region_name: str,
|
||||||
|
extra_headers: Optional[dict],
|
||||||
|
endpoint_url: str,
|
||||||
|
data: str,
|
||||||
|
headers: dict,
|
||||||
|
) -> AWSPreparedRequest:
|
||||||
|
try:
|
||||||
|
from botocore.auth import SigV4Auth
|
||||||
|
from botocore.awsrequest import AWSRequest
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
|
||||||
|
|
||||||
|
sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
|
||||||
|
|
||||||
|
request = AWSRequest(
|
||||||
|
method="POST", url=endpoint_url, data=data, headers=headers
|
||||||
|
)
|
||||||
|
sigv4.add_auth(request)
|
||||||
|
if (
|
||||||
|
extra_headers is not None and "Authorization" in extra_headers
|
||||||
|
): # prevent sigv4 from overwriting the auth header
|
||||||
|
request.headers["Authorization"] = extra_headers["Authorization"]
|
||||||
|
prepped = request.prepare()
|
||||||
|
|
||||||
|
return prepped
|
||||||
|
|
|
@ -14,7 +14,7 @@ from litellm.llms.custom_httpx.http_handler import (
|
||||||
from litellm.types.utils import ModelResponse
|
from litellm.types.utils import ModelResponse
|
||||||
from litellm.utils import CustomStreamWrapper, get_secret
|
from litellm.utils import CustomStreamWrapper, get_secret
|
||||||
|
|
||||||
from ..base_aws_llm import BaseAWSLLM
|
from ..base_aws_llm import BaseAWSLLM, Credentials
|
||||||
from ..common_utils import BedrockError
|
from ..common_utils import BedrockError
|
||||||
from .invoke_handler import AWSEventStreamDecoder, MockResponseIterator, make_call
|
from .invoke_handler import AWSEventStreamDecoder, MockResponseIterator, make_call
|
||||||
|
|
||||||
|
@ -41,7 +41,9 @@ def make_sync_call(
|
||||||
)
|
)
|
||||||
|
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
raise BedrockError(status_code=response.status_code, message=response.read())
|
raise BedrockError(
|
||||||
|
status_code=response.status_code, message=str(response.read())
|
||||||
|
)
|
||||||
|
|
||||||
if fake_stream:
|
if fake_stream:
|
||||||
model_response: (
|
model_response: (
|
||||||
|
@ -78,6 +80,7 @@ def make_sync_call(
|
||||||
|
|
||||||
|
|
||||||
class BedrockConverseLLM(BaseAWSLLM):
|
class BedrockConverseLLM(BaseAWSLLM):
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
@ -98,13 +101,13 @@ class BedrockConverseLLM(BaseAWSLLM):
|
||||||
api_base: str,
|
api_base: str,
|
||||||
model_response: ModelResponse,
|
model_response: ModelResponse,
|
||||||
print_verbose: Callable,
|
print_verbose: Callable,
|
||||||
data: str,
|
|
||||||
timeout: Optional[Union[float, httpx.Timeout]],
|
timeout: Optional[Union[float, httpx.Timeout]],
|
||||||
encoding,
|
encoding,
|
||||||
logging_obj,
|
logging_obj,
|
||||||
stream,
|
stream,
|
||||||
optional_params: dict,
|
optional_params: dict,
|
||||||
litellm_params=None,
|
litellm_params: dict,
|
||||||
|
credentials: Credentials,
|
||||||
logger_fn=None,
|
logger_fn=None,
|
||||||
headers={},
|
headers={},
|
||||||
client: Optional[AsyncHTTPHandler] = None,
|
client: Optional[AsyncHTTPHandler] = None,
|
||||||
|
@ -112,10 +115,38 @@ class BedrockConverseLLM(BaseAWSLLM):
|
||||||
json_mode: Optional[bool] = False,
|
json_mode: Optional[bool] = False,
|
||||||
) -> CustomStreamWrapper:
|
) -> CustomStreamWrapper:
|
||||||
|
|
||||||
|
request_data = await litellm.AmazonConverseConfig()._async_transform_request(
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
optional_params=optional_params,
|
||||||
|
litellm_params=litellm_params,
|
||||||
|
)
|
||||||
|
data = json.dumps(request_data)
|
||||||
|
|
||||||
|
prepped = self.get_request_headers(
|
||||||
|
credentials=credentials,
|
||||||
|
aws_region_name=litellm_params.get("aws_region_name") or "us-west-2",
|
||||||
|
extra_headers=headers,
|
||||||
|
endpoint_url=api_base,
|
||||||
|
data=data,
|
||||||
|
headers=headers,
|
||||||
|
)
|
||||||
|
|
||||||
|
## LOGGING
|
||||||
|
logging_obj.pre_call(
|
||||||
|
input=messages,
|
||||||
|
api_key="",
|
||||||
|
additional_args={
|
||||||
|
"complete_input_dict": data,
|
||||||
|
"api_base": api_base,
|
||||||
|
"headers": dict(prepped.headers),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
completion_stream = await make_call(
|
completion_stream = await make_call(
|
||||||
client=client,
|
client=client,
|
||||||
api_base=api_base,
|
api_base=api_base,
|
||||||
headers=headers,
|
headers=dict(prepped.headers),
|
||||||
data=data,
|
data=data,
|
||||||
model=model,
|
model=model,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
|
@ -138,17 +169,47 @@ class BedrockConverseLLM(BaseAWSLLM):
|
||||||
api_base: str,
|
api_base: str,
|
||||||
model_response: ModelResponse,
|
model_response: ModelResponse,
|
||||||
print_verbose: Callable,
|
print_verbose: Callable,
|
||||||
data: str,
|
|
||||||
timeout: Optional[Union[float, httpx.Timeout]],
|
timeout: Optional[Union[float, httpx.Timeout]],
|
||||||
encoding,
|
encoding,
|
||||||
logging_obj,
|
logging_obj,
|
||||||
stream,
|
stream,
|
||||||
optional_params: dict,
|
optional_params: dict,
|
||||||
litellm_params=None,
|
litellm_params: dict,
|
||||||
|
credentials: Credentials,
|
||||||
logger_fn=None,
|
logger_fn=None,
|
||||||
headers={},
|
headers: dict = {},
|
||||||
client: Optional[AsyncHTTPHandler] = None,
|
client: Optional[AsyncHTTPHandler] = None,
|
||||||
) -> Union[ModelResponse, CustomStreamWrapper]:
|
) -> Union[ModelResponse, CustomStreamWrapper]:
|
||||||
|
|
||||||
|
request_data = await litellm.AmazonConverseConfig()._async_transform_request(
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
optional_params=optional_params,
|
||||||
|
litellm_params=litellm_params,
|
||||||
|
)
|
||||||
|
data = json.dumps(request_data)
|
||||||
|
|
||||||
|
prepped = self.get_request_headers(
|
||||||
|
credentials=credentials,
|
||||||
|
aws_region_name=litellm_params.get("aws_region_name") or "us-west-2",
|
||||||
|
extra_headers=headers,
|
||||||
|
endpoint_url=api_base,
|
||||||
|
data=data,
|
||||||
|
headers=headers,
|
||||||
|
)
|
||||||
|
|
||||||
|
## LOGGING
|
||||||
|
logging_obj.pre_call(
|
||||||
|
input=messages,
|
||||||
|
api_key="",
|
||||||
|
additional_args={
|
||||||
|
"complete_input_dict": data,
|
||||||
|
"api_base": api_base,
|
||||||
|
"headers": headers,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
headers = dict(prepped.headers)
|
||||||
if client is None or not isinstance(client, AsyncHTTPHandler):
|
if client is None or not isinstance(client, AsyncHTTPHandler):
|
||||||
_params = {}
|
_params = {}
|
||||||
if timeout is not None:
|
if timeout is not None:
|
||||||
|
@ -203,8 +264,6 @@ class BedrockConverseLLM(BaseAWSLLM):
|
||||||
client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
|
client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
from botocore.auth import SigV4Auth
|
|
||||||
from botocore.awsrequest import AWSRequest
|
|
||||||
from botocore.credentials import Credentials
|
from botocore.credentials import Credentials
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
|
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
|
||||||
|
@ -237,6 +296,8 @@ class BedrockConverseLLM(BaseAWSLLM):
|
||||||
aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
|
aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
|
||||||
aws_sts_endpoint = optional_params.pop("aws_sts_endpoint", None)
|
aws_sts_endpoint = optional_params.pop("aws_sts_endpoint", None)
|
||||||
|
|
||||||
|
litellm_params["aws_region_name"] = aws_region_name
|
||||||
|
|
||||||
### SET REGION NAME ###
|
### SET REGION NAME ###
|
||||||
if aws_region_name is None:
|
if aws_region_name is None:
|
||||||
# check env #
|
# check env #
|
||||||
|
@ -281,7 +342,54 @@ class BedrockConverseLLM(BaseAWSLLM):
|
||||||
endpoint_url = f"{endpoint_url}/model/{modelId}/converse"
|
endpoint_url = f"{endpoint_url}/model/{modelId}/converse"
|
||||||
proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/converse"
|
proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/converse"
|
||||||
|
|
||||||
sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
|
## COMPLETION CALL
|
||||||
|
|
||||||
|
headers = {"Content-Type": "application/json"}
|
||||||
|
if extra_headers is not None:
|
||||||
|
headers = {"Content-Type": "application/json", **extra_headers}
|
||||||
|
|
||||||
|
### ROUTING (ASYNC, STREAMING, SYNC)
|
||||||
|
if acompletion:
|
||||||
|
if isinstance(client, HTTPHandler):
|
||||||
|
client = None
|
||||||
|
if stream is True:
|
||||||
|
return self.async_streaming(
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
api_base=proxy_endpoint_url,
|
||||||
|
model_response=model_response,
|
||||||
|
print_verbose=print_verbose,
|
||||||
|
encoding=encoding,
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
optional_params=optional_params,
|
||||||
|
stream=True,
|
||||||
|
litellm_params=litellm_params,
|
||||||
|
logger_fn=logger_fn,
|
||||||
|
headers=headers,
|
||||||
|
timeout=timeout,
|
||||||
|
client=client,
|
||||||
|
json_mode=json_mode,
|
||||||
|
fake_stream=fake_stream,
|
||||||
|
credentials=credentials,
|
||||||
|
) # type: ignore
|
||||||
|
### ASYNC COMPLETION
|
||||||
|
return self.async_completion(
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
api_base=proxy_endpoint_url,
|
||||||
|
model_response=model_response,
|
||||||
|
print_verbose=print_verbose,
|
||||||
|
encoding=encoding,
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
optional_params=optional_params,
|
||||||
|
stream=stream, # type: ignore
|
||||||
|
litellm_params=litellm_params,
|
||||||
|
logger_fn=logger_fn,
|
||||||
|
headers=headers,
|
||||||
|
timeout=timeout,
|
||||||
|
client=client,
|
||||||
|
credentials=credentials,
|
||||||
|
) # type: ignore
|
||||||
|
|
||||||
## TRANSFORMATION ##
|
## TRANSFORMATION ##
|
||||||
|
|
||||||
|
@ -292,20 +400,15 @@ class BedrockConverseLLM(BaseAWSLLM):
|
||||||
litellm_params=litellm_params,
|
litellm_params=litellm_params,
|
||||||
)
|
)
|
||||||
data = json.dumps(_data)
|
data = json.dumps(_data)
|
||||||
## COMPLETION CALL
|
|
||||||
|
|
||||||
headers = {"Content-Type": "application/json"}
|
prepped = self.get_request_headers(
|
||||||
if extra_headers is not None:
|
credentials=credentials,
|
||||||
headers = {"Content-Type": "application/json", **extra_headers}
|
aws_region_name=aws_region_name,
|
||||||
request = AWSRequest(
|
extra_headers=extra_headers,
|
||||||
method="POST", url=endpoint_url, data=data, headers=headers
|
endpoint_url=proxy_endpoint_url,
|
||||||
|
data=data,
|
||||||
|
headers=headers,
|
||||||
)
|
)
|
||||||
sigv4.add_auth(request)
|
|
||||||
if (
|
|
||||||
extra_headers is not None and "Authorization" in extra_headers
|
|
||||||
): # prevent sigv4 from overwriting the auth header
|
|
||||||
request.headers["Authorization"] = extra_headers["Authorization"]
|
|
||||||
prepped = request.prepare()
|
|
||||||
|
|
||||||
## LOGGING
|
## LOGGING
|
||||||
logging_obj.pre_call(
|
logging_obj.pre_call(
|
||||||
|
@ -317,50 +420,6 @@ class BedrockConverseLLM(BaseAWSLLM):
|
||||||
"headers": prepped.headers,
|
"headers": prepped.headers,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
### ROUTING (ASYNC, STREAMING, SYNC)
|
|
||||||
if acompletion:
|
|
||||||
if isinstance(client, HTTPHandler):
|
|
||||||
client = None
|
|
||||||
if stream is True:
|
|
||||||
return self.async_streaming(
|
|
||||||
model=model,
|
|
||||||
messages=messages,
|
|
||||||
data=data,
|
|
||||||
api_base=proxy_endpoint_url,
|
|
||||||
model_response=model_response,
|
|
||||||
print_verbose=print_verbose,
|
|
||||||
encoding=encoding,
|
|
||||||
logging_obj=logging_obj,
|
|
||||||
optional_params=optional_params,
|
|
||||||
stream=True,
|
|
||||||
litellm_params=litellm_params,
|
|
||||||
logger_fn=logger_fn,
|
|
||||||
headers=prepped.headers,
|
|
||||||
timeout=timeout,
|
|
||||||
client=client,
|
|
||||||
json_mode=json_mode,
|
|
||||||
fake_stream=fake_stream,
|
|
||||||
) # type: ignore
|
|
||||||
### ASYNC COMPLETION
|
|
||||||
return self.async_completion(
|
|
||||||
model=model,
|
|
||||||
messages=messages,
|
|
||||||
data=data,
|
|
||||||
api_base=proxy_endpoint_url,
|
|
||||||
model_response=model_response,
|
|
||||||
print_verbose=print_verbose,
|
|
||||||
encoding=encoding,
|
|
||||||
logging_obj=logging_obj,
|
|
||||||
optional_params=optional_params,
|
|
||||||
stream=stream, # type: ignore
|
|
||||||
litellm_params=litellm_params,
|
|
||||||
logger_fn=logger_fn,
|
|
||||||
headers=prepped.headers,
|
|
||||||
timeout=timeout,
|
|
||||||
client=client,
|
|
||||||
) # type: ignore
|
|
||||||
|
|
||||||
if client is None or isinstance(client, AsyncHTTPHandler):
|
if client is None or isinstance(client, AsyncHTTPHandler):
|
||||||
_params = {}
|
_params = {}
|
||||||
if timeout is not None:
|
if timeout is not None:
|
||||||
|
|
|
@ -10,6 +10,7 @@ from typing import List, Literal, Optional, Tuple, Union, overload
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
|
from litellm.litellm_core_utils.asyncify import asyncify
|
||||||
from litellm.litellm_core_utils.core_helpers import map_finish_reason
|
from litellm.litellm_core_utils.core_helpers import map_finish_reason
|
||||||
from litellm.litellm_core_utils.litellm_logging import Logging
|
from litellm.litellm_core_utils.litellm_logging import Logging
|
||||||
from litellm.litellm_core_utils.prompt_templates.factory import (
|
from litellm.litellm_core_utils.prompt_templates.factory import (
|
||||||
|
@ -347,14 +348,9 @@ class AmazonConverseConfig:
|
||||||
inference_params["topK"] = inference_params.pop("top_k")
|
inference_params["topK"] = inference_params.pop("top_k")
|
||||||
return InferenceConfig(**inference_params)
|
return InferenceConfig(**inference_params)
|
||||||
|
|
||||||
def _transform_request(
|
def _transform_request_helper(
|
||||||
self,
|
self, system_content_blocks: List[SystemContentBlock], optional_params: dict
|
||||||
model: str,
|
) -> CommonRequestObject:
|
||||||
messages: List[AllMessageValues],
|
|
||||||
optional_params: dict,
|
|
||||||
litellm_params: dict,
|
|
||||||
) -> RequestObject:
|
|
||||||
messages, system_content_blocks = self._transform_system_message(messages)
|
|
||||||
inference_params = copy.deepcopy(optional_params)
|
inference_params = copy.deepcopy(optional_params)
|
||||||
additional_request_keys = []
|
additional_request_keys = []
|
||||||
additional_request_params = {}
|
additional_request_params = {}
|
||||||
|
@ -364,14 +360,6 @@ class AmazonConverseConfig:
|
||||||
supported_tool_call_params = ["tools", "tool_choice"]
|
supported_tool_call_params = ["tools", "tool_choice"]
|
||||||
supported_guardrail_params = ["guardrailConfig"]
|
supported_guardrail_params = ["guardrailConfig"]
|
||||||
inference_params.pop("json_mode", None) # used for handling json_schema
|
inference_params.pop("json_mode", None) # used for handling json_schema
|
||||||
## TRANSFORMATION ##
|
|
||||||
|
|
||||||
bedrock_messages: List[MessageBlock] = _bedrock_converse_messages_pt(
|
|
||||||
messages=messages,
|
|
||||||
model=model,
|
|
||||||
llm_provider="bedrock_converse",
|
|
||||||
user_continue_message=litellm_params.pop("user_continue_message", None),
|
|
||||||
)
|
|
||||||
|
|
||||||
# send all model-specific params in 'additional_request_params'
|
# send all model-specific params in 'additional_request_params'
|
||||||
for k, v in inference_params.items():
|
for k, v in inference_params.items():
|
||||||
|
@ -408,8 +396,7 @@ class AmazonConverseConfig:
|
||||||
if tool_choice_values is not None:
|
if tool_choice_values is not None:
|
||||||
bedrock_tool_config["toolChoice"] = tool_choice_values
|
bedrock_tool_config["toolChoice"] = tool_choice_values
|
||||||
|
|
||||||
_data: RequestObject = {
|
data: CommonRequestObject = {
|
||||||
"messages": bedrock_messages,
|
|
||||||
"additionalModelRequestFields": additional_request_params,
|
"additionalModelRequestFields": additional_request_params,
|
||||||
"system": system_content_blocks,
|
"system": system_content_blocks,
|
||||||
"inferenceConfig": self._transform_inference_params(
|
"inferenceConfig": self._transform_inference_params(
|
||||||
|
@ -422,13 +409,65 @@ class AmazonConverseConfig:
|
||||||
request_guardrails_config = inference_params.pop("guardrailConfig", None)
|
request_guardrails_config = inference_params.pop("guardrailConfig", None)
|
||||||
if request_guardrails_config is not None:
|
if request_guardrails_config is not None:
|
||||||
guardrail_config = GuardrailConfigBlock(**request_guardrails_config)
|
guardrail_config = GuardrailConfigBlock(**request_guardrails_config)
|
||||||
_data["guardrailConfig"] = guardrail_config
|
data["guardrailConfig"] = guardrail_config
|
||||||
|
|
||||||
# Tool Config
|
# Tool Config
|
||||||
if bedrock_tool_config is not None:
|
if bedrock_tool_config is not None:
|
||||||
_data["toolConfig"] = bedrock_tool_config
|
data["toolConfig"] = bedrock_tool_config
|
||||||
|
|
||||||
return _data
|
return data
|
||||||
|
|
||||||
|
async def _async_transform_request(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
messages: List[AllMessageValues],
|
||||||
|
optional_params: dict,
|
||||||
|
litellm_params: dict,
|
||||||
|
) -> RequestObject:
|
||||||
|
messages, system_content_blocks = self._transform_system_message(messages)
|
||||||
|
## TRANSFORMATION ##
|
||||||
|
bedrock_messages: List[MessageBlock] = await asyncify(
|
||||||
|
_bedrock_converse_messages_pt
|
||||||
|
)(
|
||||||
|
messages=messages,
|
||||||
|
model=model,
|
||||||
|
llm_provider="bedrock_converse",
|
||||||
|
user_continue_message=litellm_params.pop("user_continue_message", None),
|
||||||
|
)
|
||||||
|
|
||||||
|
_data: CommonRequestObject = self._transform_request_helper(
|
||||||
|
system_content_blocks=system_content_blocks,
|
||||||
|
optional_params=optional_params,
|
||||||
|
)
|
||||||
|
|
||||||
|
data: RequestObject = {"messages": bedrock_messages, **_data}
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
def _transform_request(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
messages: List[AllMessageValues],
|
||||||
|
optional_params: dict,
|
||||||
|
litellm_params: dict,
|
||||||
|
) -> RequestObject:
|
||||||
|
messages, system_content_blocks = self._transform_system_message(messages)
|
||||||
|
## TRANSFORMATION ##
|
||||||
|
bedrock_messages: List[MessageBlock] = _bedrock_converse_messages_pt(
|
||||||
|
messages=messages,
|
||||||
|
model=model,
|
||||||
|
llm_provider="bedrock_converse",
|
||||||
|
user_continue_message=litellm_params.pop("user_continue_message", None),
|
||||||
|
)
|
||||||
|
|
||||||
|
_data: CommonRequestObject = self._transform_request_helper(
|
||||||
|
system_content_blocks=system_content_blocks,
|
||||||
|
optional_params=optional_params,
|
||||||
|
)
|
||||||
|
|
||||||
|
data: RequestObject = {"messages": bedrock_messages, **_data}
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
def _transform_response(
|
def _transform_response(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -2,4 +2,4 @@ model_list:
|
||||||
- model_name: bedrock/us.anthropic.claude-3-haiku-20240307-v1:0
|
- model_name: bedrock/us.anthropic.claude-3-haiku-20240307-v1:0
|
||||||
litellm_params:
|
litellm_params:
|
||||||
model: bedrock/us.anthropic.claude-3-haiku-20240307-v1:0
|
model: bedrock/us.anthropic.claude-3-haiku-20240307-v1:0
|
||||||
|
|
||||||
|
|
|
@ -161,16 +161,21 @@ class ContentBlockDeltaEvent(TypedDict, total=False):
|
||||||
toolUse: ToolBlockDeltaEvent
|
toolUse: ToolBlockDeltaEvent
|
||||||
|
|
||||||
|
|
||||||
class RequestObject(TypedDict, total=False):
|
class CommonRequestObject(
|
||||||
|
TypedDict, total=False
|
||||||
|
): # common request object across sync + async flows
|
||||||
additionalModelRequestFields: dict
|
additionalModelRequestFields: dict
|
||||||
additionalModelResponseFieldPaths: List[str]
|
additionalModelResponseFieldPaths: List[str]
|
||||||
inferenceConfig: InferenceConfig
|
inferenceConfig: InferenceConfig
|
||||||
messages: Required[List[MessageBlock]]
|
|
||||||
system: List[SystemContentBlock]
|
system: List[SystemContentBlock]
|
||||||
toolConfig: ToolConfigBlock
|
toolConfig: ToolConfigBlock
|
||||||
guardrailConfig: Optional[GuardrailConfigBlock]
|
guardrailConfig: Optional[GuardrailConfigBlock]
|
||||||
|
|
||||||
|
|
||||||
|
class RequestObject(CommonRequestObject, total=False):
|
||||||
|
messages: Required[List[MessageBlock]]
|
||||||
|
|
||||||
|
|
||||||
class GenericStreamingChunk(TypedDict):
|
class GenericStreamingChunk(TypedDict):
|
||||||
text: Required[str]
|
text: Required[str]
|
||||||
tool_use: Optional[ChatCompletionToolCallChunk]
|
tool_use: Optional[ChatCompletionToolCallChunk]
|
||||||
|
|
|
@ -2391,3 +2391,41 @@ def test_process_bedrock_converse_image_block():
|
||||||
)
|
)
|
||||||
|
|
||||||
assert block["document"] is not None
|
assert block["document"] is not None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_bedrock_image_url_sync_client():
|
||||||
|
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
|
||||||
|
import logging
|
||||||
|
from litellm import verbose_logger
|
||||||
|
|
||||||
|
verbose_logger.setLevel(level=logging.DEBUG)
|
||||||
|
|
||||||
|
litellm._turn_on_debug()
|
||||||
|
client = AsyncHTTPHandler()
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "What's in this image?"},
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {
|
||||||
|
"url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
with patch.object(client, "post") as mock_post:
|
||||||
|
try:
|
||||||
|
await litellm.acompletion(
|
||||||
|
model="bedrock/us.amazon.nova-pro-v1:0",
|
||||||
|
messages=messages,
|
||||||
|
client=client,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
mock_post.assert_called_once()
|
||||||
|
|
|
@ -1387,7 +1387,7 @@ async def test_completion_replicate_llama3_streaming(sync_mode):
|
||||||
[
|
[
|
||||||
# ["bedrock/ai21.jamba-instruct-v1:0", "us-east-1"],
|
# ["bedrock/ai21.jamba-instruct-v1:0", "us-east-1"],
|
||||||
# ["bedrock/cohere.command-r-plus-v1:0", None],
|
# ["bedrock/cohere.command-r-plus-v1:0", None],
|
||||||
# ["anthropic.claude-3-sonnet-20240229-v1:0", None],
|
["anthropic.claude-3-sonnet-20240229-v1:0", None],
|
||||||
# ["anthropic.claude-instant-v1", None],
|
# ["anthropic.claude-instant-v1", None],
|
||||||
# ["mistral.mistral-7b-instruct-v0:2", None],
|
# ["mistral.mistral-7b-instruct-v0:2", None],
|
||||||
["bedrock/amazon.titan-tg1-large", None],
|
["bedrock/amazon.titan-tg1-large", None],
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue