Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-25 18:54:30 +00:00.
* refactor(bedrock): initial commit to refactor bedrock to a folder Improve code readability + maintainability * refactor: more refactor work * fix: fix imports * feat(bedrock/embeddings.py): support translating embedding into amazon embedding formats * fix: fix linting errors * test: skip test on end of life model * fix(cohere/embed.py): fix linting error * fix(cohere/embed.py): fix typing * fix(cohere/embed.py): fix post-call logging for cohere embedding call * test(test_embeddings.py): fix error message assertion in test
127 lines
4.1 KiB
Python
127 lines
4.1 KiB
Python
"""
|
|
Handles image gen calls to Bedrock's `/invoke` endpoint
|
|
"""
|
|
|
|
import copy
|
|
import json
|
|
import os
|
|
from typing import List
|
|
|
|
from openai.types.image import Image
|
|
|
|
import litellm
|
|
from litellm.types.utils import ImageResponse
|
|
|
|
from .common_utils import BedrockError, init_bedrock_client
|
|
|
|
|
|
def image_generation(
    model: str,
    prompt: str,
    model_response: ImageResponse,
    optional_params: dict,
    timeout=None,
    logging_obj=None,
    aimg_generation=False,
):
    """
    Bedrock Image Gen endpoint support.

    Invokes a Bedrock image-generation model via the boto3 `invoke_model`
    API and maps the response into the OpenAI-style ``ImageResponse``.

    Args:
        model: Bedrock model id (e.g. ``stability.stable-diffusion-xl-v1``).
            The provider is derived from the prefix before the first ``.``.
        prompt: Text prompt for the image. Newlines are flattened to spaces
            for stability models.
        model_response: ``ImageResponse`` to populate; a fresh one is
            created if ``None`` is passed.
        optional_params: Provider params. AWS credential/config keys are
            popped from this dict before the remainder is forwarded as
            inference params.
        timeout: Request timeout forwarded to the boto3 client.
        logging_obj: LiteLLM logging object (``pre_call``/``post_call`` hooks).
        aimg_generation: Async flag kept for interface parity; not used in
            this synchronous path.

    Returns:
        The populated ``ImageResponse`` — ``.data`` holds one ``Image``
        (b64-encoded) per artifact returned by the model.

    Raises:
        BedrockError: 422 for unsupported providers, 500 when the
            ``invoke_model`` call or response parsing fails.
    """
    ### BOTO3 INIT ###
    # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs,
    # since completion calls fail with them
    aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
    aws_access_key_id = optional_params.pop("aws_access_key_id", None)
    aws_region_name = optional_params.pop("aws_region_name", None)
    aws_role_name = optional_params.pop("aws_role_name", None)
    aws_session_name = optional_params.pop("aws_session_name", None)
    aws_bedrock_runtime_endpoint = optional_params.pop(
        "aws_bedrock_runtime_endpoint", None
    )
    aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)

    # use passed in BedrockRuntime.Client if provided, otherwise create a new one
    client = init_bedrock_client(
        aws_access_key_id=aws_access_key_id,
        aws_secret_access_key=aws_secret_access_key,
        aws_region_name=aws_region_name,
        aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
        aws_web_identity_token=aws_web_identity_token,
        aws_role_name=aws_role_name,
        aws_session_name=aws_session_name,
        timeout=timeout,
    )

    ### FORMAT IMAGE GENERATION INPUT ###
    modelId = model
    provider = model.split(".")[0]
    # deepcopy so config defaults below never mutate the caller's dict
    inference_params = copy.deepcopy(optional_params)
    inference_params.pop(
        "user", None
    )  # make sure user is not passed in for bedrock call
    data = {}
    if provider == "stability":
        prompt = prompt.replace(os.linesep, " ")
        ## LOAD CONFIG
        config = litellm.AmazonStabilityConfig.get_config()
        for k, v in config.items():
            if (
                k not in inference_params
            ):  # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
                inference_params[k] = v
        data = {"text_prompts": [{"text": prompt, "weight": 1}], **inference_params}
    else:
        raise BedrockError(
            status_code=422, message=f"Unsupported model={model}, passed in"
        )

    body = json.dumps(data).encode("utf-8")
    ## LOGGING
    request_str = f"""
    response = client.invoke_model(
        body={body}, # type: ignore
        modelId={modelId},
        accept="application/json",
        contentType="application/json",
    )"""  # type: ignore
    logging_obj.pre_call(
        input=prompt,
        api_key="",  # boto3 is used for init.
        additional_args={
            "complete_input_dict": {"model": modelId, "texts": prompt},
            "request_str": request_str,
        },
    )
    try:
        response = client.invoke_model(
            body=body,
            modelId=modelId,
            accept="application/json",
            contentType="application/json",
        )
        response_body = json.loads(response.get("body").read())
        ## LOGGING
        logging_obj.post_call(
            input=prompt,
            api_key="",
            additional_args={"complete_input_dict": data},
            original_response=json.dumps(response_body),
        )
    except Exception as e:
        # NOTE: previous message said "Embedding Error" — copy-paste from the
        # embeddings module; this is an image-generation call. Chain the
        # original exception so the boto3 traceback is preserved.
        raise BedrockError(
            message=f"Image Generation Error with model {model}: {e}",
            status_code=500,
        ) from e

    ### FORMAT RESPONSE TO OPENAI FORMAT ###
    if response_body is None:
        raise Exception("Error in response object format")

    if model_response is None:
        model_response = ImageResponse()

    image_list: List[Image] = []
    for artifact in response_body["artifacts"]:
        _image = Image(b64_json=artifact["base64"])
        image_list.append(_image)

    model_response.data = image_list
    return model_response