litellm-mirror/litellm/llms/bedrock/image_generation.py

"""
Handles image gen calls to Bedrock's `/invoke` endpoint
"""
import copy
import json
import os
from typing import List
from openai.types.image import Image
import litellm
from litellm.types.utils import ImageResponse
from .common_utils import BedrockError, init_bedrock_client
def image_generation(
model: str,
prompt: str,
model_response: ImageResponse,
optional_params: dict,
timeout=None,
logging_obj=None,
aimg_generation=False,
):
"""
Bedrock Image Gen endpoint support
"""
    ### BOTO3 INIT ###
    # pop the AWS auth params from optional_params, since completion calls fail
    # if they are forwarded to the model
    aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
    aws_access_key_id = optional_params.pop("aws_access_key_id", None)
    aws_region_name = optional_params.pop("aws_region_name", None)
    aws_role_name = optional_params.pop("aws_role_name", None)
    aws_session_name = optional_params.pop("aws_session_name", None)
    aws_bedrock_runtime_endpoint = optional_params.pop(
        "aws_bedrock_runtime_endpoint", None
    )
    aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)

    # use passed in BedrockRuntime.Client if provided, otherwise create a new one
    client = init_bedrock_client(
        aws_access_key_id=aws_access_key_id,
        aws_secret_access_key=aws_secret_access_key,
        aws_region_name=aws_region_name,
        aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
        aws_web_identity_token=aws_web_identity_token,
        aws_role_name=aws_role_name,
        aws_session_name=aws_session_name,
        timeout=timeout,
    )
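    # NOTE (assumed behavior, not verified here): when none of the explicit AWS
    # params above are supplied, init_bedrock_client is expected to fall back to
    # boto3's standard credential resolution (environment variables, shared
    # config, instance role).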
    ### FORMAT IMAGE GENERATION INPUT ###
    modelId = model
    provider = model.split(".")[0]
    inference_params = copy.deepcopy(optional_params)
    inference_params.pop(
        "user", None
    )  # make sure user is not passed in for bedrock call
    data = {}
    if provider == "stability":
        prompt = prompt.replace(os.linesep, " ")
        ## LOAD CONFIG
        config = litellm.AmazonStabilityConfig.get_config()
        for k, v in config.items():
            if (
                k not in inference_params
            ):  # user-passed params take precedence over config defaults
                inference_params[k] = v
        data = {"text_prompts": [{"text": prompt, "weight": 1}], **inference_params}
    else:
        raise BedrockError(
            status_code=422, message=f"Unsupported model={model} passed in"
        )
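    # Illustrative serialized payload for a Stability model (field values are
    # examples only; the actual defaults come from AmazonStabilityConfig):
    # {
    #     "text_prompts": [{"text": "a red fox standing in snow", "weight": 1}],
    #     "cfg_scale": 7
    # }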
    body = json.dumps(data).encode("utf-8")

    ## LOGGING
    request_str = f"""
response = client.invoke_model(
    body={body},  # type: ignore
    modelId={modelId},
    accept="application/json",
    contentType="application/json",
)"""  # type: ignore
    logging_obj.pre_call(
        input=prompt,
        api_key="",  # boto3 is used for init.
        additional_args={
            "complete_input_dict": {"model": modelId, "texts": prompt},
            "request_str": request_str,
        },
    )
    try:
        response = client.invoke_model(
            body=body,
            modelId=modelId,
            accept="application/json",
            contentType="application/json",
        )
        response_body = json.loads(response.get("body").read())
        ## LOGGING
        logging_obj.post_call(
            input=prompt,
            api_key="",
            additional_args={"complete_input_dict": data},
            original_response=json.dumps(response_body),
        )
    except Exception as e:
        raise BedrockError(
            message=f"Image generation error with model {model}: {e}", status_code=500
        )
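    # A successful Stability response body contains an "artifacts" list; each
    # entry carries the generated image as base64 (other fields, e.g. a seed or
    # finish reason, may also be present depending on the model):
    # {"artifacts": [{"base64": "<b64-encoded image bytes>", ...}]}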
    ### FORMAT RESPONSE TO OPENAI FORMAT ###
    if response_body is None:
        raise Exception("Error in response object format")
    if model_response is None:
        model_response = ImageResponse()
    image_list: List[Image] = []
    for artifact in response_body["artifacts"]:
        _image = Image(b64_json=artifact["base64"])
        image_list.append(_image)
    model_response.data = image_list
    return model_response
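

# --- Usage sketch (hypothetical, for illustration only) ---
# In practice this helper is reached through litellm's public API rather than
# called directly, since `logging_obj` is supplied by litellm's internal
# Logging object. The model id below is an example; use any Stability model
# enabled in your Bedrock account:
#
#   import litellm
#   response = litellm.image_generation(
#       model="bedrock/stability.stable-diffusion-xl-v0",
#       prompt="A scenic mountain lake at sunrise",
#   )
#   print(response.data[0].b64_json is not None)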