add bedrock image gen async support

Ishaan Jaff 2024-11-08 13:17:43 -08:00
parent 3d1c305401
commit 64c3c4906c
4 changed files with 245 additions and 130 deletions

litellm/llms/base_aws_llm.py

@@ -1,16 +1,28 @@
import hashlib
import json
import os
from typing import Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

import httpx
from pydantic import BaseModel

from litellm._logging import verbose_logger
from litellm.caching.caching import DualCache, InMemoryCache
from litellm.secret_managers.main import get_secret
from litellm.secret_managers.main import get_secret, get_secret_str

from .base import BaseLLM

if TYPE_CHECKING:
    from botocore.credentials import Credentials
else:
    Credentials = Any


class Boto3CredentialsInfo(BaseModel):
    credentials: Credentials
    aws_region_name: str
    aws_bedrock_runtime_endpoint: Optional[str]


class AwsAuthError(Exception):
    def __init__(self, status_code, message):
@@ -311,3 +323,74 @@ class BaseAWSLLM(BaseLLM):
            proxy_endpoint_url = endpoint_url

        return endpoint_url, proxy_endpoint_url

    def _get_boto_credentials_from_optional_params(
        self, optional_params: dict
    ) -> Boto3CredentialsInfo:
        """
        Get boto3 credentials from optional params

        Args:
            optional_params (dict): Optional parameters for the model call

        Returns:
            Boto3CredentialsInfo: the resolved botocore credentials, AWS region,
            and optional Bedrock runtime endpoint
        """
        try:
            import boto3
            from botocore.auth import SigV4Auth
            from botocore.awsrequest import AWSRequest
            from botocore.credentials import Credentials
        except ImportError:
            raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")

        ## CREDENTIALS ##
        # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
        aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
        aws_access_key_id = optional_params.pop("aws_access_key_id", None)
        aws_session_token = optional_params.pop("aws_session_token", None)
        aws_region_name = optional_params.pop("aws_region_name", None)
        aws_role_name = optional_params.pop("aws_role_name", None)
        aws_session_name = optional_params.pop("aws_session_name", None)
        aws_profile_name = optional_params.pop("aws_profile_name", None)
        aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
        aws_sts_endpoint = optional_params.pop("aws_sts_endpoint", None)
        aws_bedrock_runtime_endpoint = optional_params.pop(
            "aws_bedrock_runtime_endpoint", None
        )  # https://bedrock-runtime.{region_name}.amazonaws.com

        ### SET REGION NAME ###
        if aws_region_name is None:
            # check env #
            litellm_aws_region_name = get_secret_str("AWS_REGION_NAME", None)

            if litellm_aws_region_name is not None and isinstance(
                litellm_aws_region_name, str
            ):
                aws_region_name = litellm_aws_region_name

            standard_aws_region_name = get_secret_str("AWS_REGION", None)
            if standard_aws_region_name is not None and isinstance(
                standard_aws_region_name, str
            ):
                aws_region_name = standard_aws_region_name

            if aws_region_name is None:
                aws_region_name = "us-west-2"
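        # Net resolution order: explicit aws_region_name param, then the AWS_REGION
        # env var, then AWS_REGION_NAME (AWS_REGION wins when both are set),
        # falling back to "us-west-2".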
        credentials: Credentials = self.get_credentials(
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            aws_session_token=aws_session_token,
            aws_region_name=aws_region_name,
            aws_session_name=aws_session_name,
            aws_profile_name=aws_profile_name,
            aws_role_name=aws_role_name,
            aws_web_identity_token=aws_web_identity_token,
            aws_sts_endpoint=aws_sts_endpoint,
        )

        return Boto3CredentialsInfo(
            credentials=credentials,
            aws_region_name=aws_region_name,
            aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
        )
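
For context, a minimal sketch of how a BaseAWSLLM subclass might consume this helper; the subclass name and method are hypothetical, not part of the diff:

from litellm.llms.base_aws_llm import BaseAWSLLM

class MyBedrockHandler(BaseAWSLLM):  # hypothetical subclass for illustration
    def resolve_region(self, optional_params: dict) -> str:
        creds_info = self._get_boto_credentials_from_optional_params(optional_params)
        # creds_info.credentials is a botocore Credentials object, ready for SigV4 signing
        return creds_info.aws_region_name

print(MyBedrockHandler().resolve_region({"aws_region_name": "us-east-1"}))  # us-east-1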

litellm/llms/bedrock/image/image_handler.py

@@ -0,0 +1,158 @@
import copy
import json
import os
from typing import Any, List, Optional

import httpx
from openai.types.image import Image

import litellm
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, _get_httpx_client
from litellm.types.utils import ImageResponse
from litellm.utils import print_verbose

from ...base_aws_llm import BaseAWSLLM
from ..common_utils import BedrockError


class BedrockImageGeneration(BaseAWSLLM):
    """
    Bedrock Image Generation handler
    """
    def image_generation(  # noqa: PLR0915
        self,
        model: str,
        prompt: str,
        model_response: ImageResponse,
        optional_params: dict,
        logging_obj: Any,
        timeout=None,
        aimg_generation: bool = False,
        api_base: Optional[str] = None,
        extra_headers: Optional[dict] = None,
        client: Optional[Any] = None,
    ):
        try:
            import boto3
            from botocore.auth import SigV4Auth
            from botocore.awsrequest import AWSRequest
            from botocore.credentials import Credentials
        except ImportError:
            raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")

        boto3_credentials_info = self._get_boto_credentials_from_optional_params(
            optional_params
        )

        ### SET RUNTIME ENDPOINT ###
        modelId = model
        endpoint_url, proxy_endpoint_url = self.get_runtime_endpoint(
            api_base=api_base,
            aws_bedrock_runtime_endpoint=boto3_credentials_info.aws_bedrock_runtime_endpoint,
            aws_region_name=boto3_credentials_info.aws_region_name,
        )
        proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/invoke"
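        # e.g. https://bedrock-runtime.us-west-2.amazonaws.com/model/<model-id>/invoke
        # (region shown is illustrative; it comes from the resolved credentials above)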
        sigv4 = SigV4Auth(
            boto3_credentials_info.credentials,
            "bedrock",
            boto3_credentials_info.aws_region_name,
        )

        # transform request
        ### FORMAT IMAGE GENERATION INPUT ###
        provider = model.split(".")[0]
        inference_params = copy.deepcopy(optional_params)
        inference_params.pop(
            "user", None
        )  # make sure user is not passed in for bedrock call
        data = {}
        if provider == "stability":
            prompt = prompt.replace(os.linesep, " ")
            ## LOAD CONFIG
            config = litellm.AmazonStabilityConfig.get_config()
            for k, v in config.items():
                if (
                    k not in inference_params
                ):  # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
                    inference_params[k] = v
            data = {"text_prompts": [{"text": prompt, "weight": 1}], **inference_params}
        else:
            raise BedrockError(
                status_code=422, message=f"Unsupported model={model}, passed in"
            )
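        # Illustrative request body after the config merge; keys besides text_prompts
        # depend on AmazonStabilityConfig defaults and caller-supplied params:
        # {"text_prompts": [{"text": "a photo of an otter", "weight": 1}], "cfg_scale": 7}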
        # Make POST Request
        body = json.dumps(data).encode("utf-8")

        headers = {"Content-Type": "application/json"}
        if extra_headers is not None:
            headers = {"Content-Type": "application/json", **extra_headers}
        request = AWSRequest(
            method="POST", url=proxy_endpoint_url, data=body, headers=headers
        )
        sigv4.add_auth(request)
        if (
            extra_headers is not None and "Authorization" in extra_headers
        ):  # prevent sigv4 from overwriting the auth header
            request.headers["Authorization"] = extra_headers["Authorization"]
        prepped = request.prepare()
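        # prepped.headers now carries the SigV4 Authorization header (or the
        # caller-supplied one preserved above) and is what gets sent to Bedrock.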
        ## LOGGING
        logging_obj.pre_call(
            input=prompt,
            api_key="",
            additional_args={
                "complete_input_dict": data,
                "api_base": proxy_endpoint_url,
                "headers": prepped.headers,
            },
        )

        if client is None or isinstance(client, AsyncHTTPHandler):
            _params = {}
            if timeout is not None:
                if isinstance(timeout, float) or isinstance(timeout, int):
                    timeout = httpx.Timeout(timeout)
                _params["timeout"] = timeout
            client = _get_httpx_client(_params)  # type: ignore
        else:
            client = client

        try:
            response = client.post(url=proxy_endpoint_url, headers=prepped.headers, data=body)  # type: ignore
            response.raise_for_status()
        except httpx.HTTPStatusError as err:
            error_code = err.response.status_code
            raise BedrockError(status_code=error_code, message=err.response.text)
        except httpx.TimeoutException:
            raise BedrockError(status_code=408, message="Timeout error occurred.")
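        # Non-2xx responses surface as BedrockError with the upstream status code
        # and body; httpx timeouts are mapped to a 408.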
        response_body = response.json()

        ## LOGGING
        if logging_obj is not None:
            logging_obj.post_call(
                input=prompt,
                api_key="",
                original_response=response.text,
                additional_args={"complete_input_dict": data},
            )
        print_verbose(f"raw model_response: {response.text}")

        ### FORMAT RESPONSE TO OPENAI FORMAT ###
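        # A stability invoke response typically carries base64 images under "artifacts",
        # e.g. {"artifacts": [{"base64": "...", "seed": 12345, "finishReason": "SUCCESS"}]}
        # (illustrative shape)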
        if response_body is None:
            raise Exception("Error in response object format")

        if model_response is None:
            model_response = ImageResponse()

        image_list: List[Image] = []
        for artifact in response_body["artifacts"]:
            _image = Image(b64_json=artifact["base64"])
            image_list.append(_image)

        model_response.data = image_list
        return model_response

    async def async_image_generation(self):
        pass
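
For illustration, a minimal sketch of driving the new handler directly; the no-op logging class is a hypothetical stand-in for litellm's logging object, and in normal use this path is reached through litellm.image_generation instead:

from litellm.types.utils import ImageResponse
from litellm.llms.bedrock.image.image_handler import BedrockImageGeneration

class _NoopLogging:  # hypothetical stand-in for litellm's logging object
    def pre_call(self, *args, **kwargs): ...
    def post_call(self, *args, **kwargs): ...

response = BedrockImageGeneration().image_generation(
    model="stability.stable-diffusion-xl-v1",  # any stability.* Bedrock model id
    prompt="a photo of an otter",
    model_response=ImageResponse(),
    optional_params={"aws_region_name": "us-west-2"},
    logging_obj=_NoopLogging(),
)
print(len(response.data))  # number of generated images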

litellm/llms/bedrock/image_generation.py

@@ -1,127 +0,0 @@
"""
Handles image gen calls to Bedrock's `/invoke` endpoint
"""

import copy
import json
import os
from typing import Any, List

from openai.types.image import Image

import litellm
from litellm.types.utils import ImageResponse

from .common_utils import BedrockError, init_bedrock_client


def image_generation(
    model: str,
    prompt: str,
    model_response: ImageResponse,
    optional_params: dict,
    logging_obj: Any,
    timeout=None,
    aimg_generation=False,
):
    """
    Bedrock Image Gen endpoint support
    """
    ### BOTO3 INIT ###
    # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
    aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
    aws_access_key_id = optional_params.pop("aws_access_key_id", None)
    aws_region_name = optional_params.pop("aws_region_name", None)
    aws_role_name = optional_params.pop("aws_role_name", None)
    aws_session_name = optional_params.pop("aws_session_name", None)
    aws_bedrock_runtime_endpoint = optional_params.pop(
        "aws_bedrock_runtime_endpoint", None
    )
    aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)

    # use passed in BedrockRuntime.Client if provided, otherwise create a new one
    client = init_bedrock_client(
        aws_access_key_id=aws_access_key_id,
        aws_secret_access_key=aws_secret_access_key,
        aws_region_name=aws_region_name,
        aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
        aws_web_identity_token=aws_web_identity_token,
        aws_role_name=aws_role_name,
        aws_session_name=aws_session_name,
        timeout=timeout,
    )

    ### FORMAT IMAGE GENERATION INPUT ###
    modelId = model
    provider = model.split(".")[0]
    inference_params = copy.deepcopy(optional_params)
    inference_params.pop(
        "user", None
    )  # make sure user is not passed in for bedrock call
    data = {}
    if provider == "stability":
        prompt = prompt.replace(os.linesep, " ")
        ## LOAD CONFIG
        config = litellm.AmazonStabilityConfig.get_config()
        for k, v in config.items():
            if (
                k not in inference_params
            ):  # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
                inference_params[k] = v
        data = {"text_prompts": [{"text": prompt, "weight": 1}], **inference_params}
    else:
        raise BedrockError(
            status_code=422, message=f"Unsupported model={model}, passed in"
        )

    body = json.dumps(data).encode("utf-8")
    ## LOGGING
    request_str = f"""
    response = client.invoke_model(
        body={body},  # type: ignore
        modelId={modelId},
        accept="application/json",
        contentType="application/json",
    )"""  # type: ignore
    logging_obj.pre_call(
        input=prompt,
        api_key="",  # boto3 is used for init.
        additional_args={
            "complete_input_dict": {"model": modelId, "texts": prompt},
            "request_str": request_str,
        },
    )
    try:
        response = client.invoke_model(
            body=body,
            modelId=modelId,
            accept="application/json",
            contentType="application/json",
        )
        response_body = json.loads(response.get("body").read())
        ## LOGGING
        logging_obj.post_call(
            input=prompt,
            api_key="",
            additional_args={"complete_input_dict": data},
            original_response=json.dumps(response_body),
        )
    except Exception as e:
        raise BedrockError(
            message=f"Embedding Error with model {model}: {e}", status_code=500
        )

    ### FORMAT RESPONSE TO OPENAI FORMAT ###
    if response_body is None:
        raise Exception("Error in response object format")

    if model_response is None:
        model_response = ImageResponse()

    image_list: List[Image] = []
    for artifact in response_body["artifacts"]:
        _image = Image(b64_json=artifact["base64"])
        image_list.append(_image)

    model_response.data = image_list
    return model_response

litellm/main.py

@@ -108,9 +108,9 @@ from .llms.azure_text import AzureTextCompletion
from .llms.AzureOpenAI.audio_transcriptions import AzureAudioTranscription
from .llms.AzureOpenAI.azure import AzureChatCompletion, _check_dynamic_azure_params
from .llms.AzureOpenAI.chat.o1_handler import AzureOpenAIO1ChatCompletion
from .llms.bedrock import image_generation as bedrock_image_generation # type: ignore
from .llms.bedrock.chat import BedrockConverseLLM, BedrockLLM
from .llms.bedrock.embed.embedding import BedrockEmbedding
from .llms.bedrock.image.image_handler import BedrockImageGeneration
from .llms.cohere import chat as cohere_chat
from .llms.cohere import completion as cohere_completion # type: ignore
from .llms.cohere.embed import handler as cohere_embed
@@ -214,6 +214,7 @@ triton_chat_completions = TritonChatCompletion()
bedrock_chat_completion = BedrockLLM()
bedrock_converse_chat_completion = BedrockConverseLLM()
bedrock_embedding = BedrockEmbedding()
bedrock_image_generation = BedrockImageGeneration()
vertex_chat_completion = VertexLLM()
vertex_embedding = VertexEmbedding()
vertex_multimodal_embedding = VertexMultimodalEmbedding()
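
End to end, the new handler keeps the existing public entrypoint. A minimal sketch, assuming AWS credentials are available in the environment and that this stability model id is enabled for the account:

import litellm

response = litellm.image_generation(
    prompt="a photo of an otter",
    model="bedrock/stability.stable-diffusion-xl-v1",
    aws_region_name="us-west-2",
)
print(response.data[0].b64_json[:32])  # start of the base64-encoded image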