fix(factory.py): handle bedrock claude image url's

This commit is contained in:
Krrish Dholakia 2024-06-07 10:04:03 -07:00
parent 21f4329b65
commit e66b3d264f
7 changed files with 58 additions and 32 deletions

1
.gitignore vendored
View file

@ -59,3 +59,4 @@ myenv/*
litellm/proxy/_experimental/out/404/index.html litellm/proxy/_experimental/out/404/index.html
litellm/proxy/_experimental/out/model_hub/index.html litellm/proxy/_experimental/out/model_hub/index.html
litellm/proxy/_experimental/out/onboarding/index.html litellm/proxy/_experimental/out/onboarding/index.html
litellm/tests/log.txt

View file

@ -16,11 +16,11 @@ repos:
name: Check if files match name: Check if files match
entry: python3 ci_cd/check_files_match.py entry: python3 ci_cd/check_files_match.py
language: system language: system
- repo: local # - repo: local
hooks: # hooks:
- id: mypy # - id: mypy
name: mypy # name: mypy
entry: python3 -m mypy --ignore-missing-imports # entry: python3 -m mypy --ignore-missing-imports
language: system # language: system
types: [python] # types: [python]
files: ^litellm/ # files: ^litellm/

View file

@ -1158,6 +1158,7 @@ class AmazonConverseConfig:
"stop", "stop",
"temperature", "temperature",
"top_p", "top_p",
"extra_headers",
] ]
if ( if (

View file

@ -1621,7 +1621,7 @@ from litellm.types.llms.bedrock import (
) )
def get_image_details(image_url) -> Tuple[bytes, str]: def get_image_details(image_url) -> Tuple[str, str]:
try: try:
import base64 import base64
@ -1637,7 +1637,7 @@ def get_image_details(image_url) -> Tuple[bytes, str]:
) )
# Convert the image content to base64 bytes # Convert the image content to base64 bytes
base64_bytes = base64.b64encode(response.content) base64_bytes = base64.b64encode(response.content).decode("utf-8")
# Get mime-type # Get mime-type
mime_type = content_type.split("/")[ mime_type = content_type.split("/")[
@ -1659,18 +1659,17 @@ def _process_bedrock_converse_image_block(image_url: str) -> BedrockImageBlock:
# base 64 is passed as data:image/jpeg;base64,<base-64-encoded-image> # base 64 is passed as data:image/jpeg;base64,<base-64-encoded-image>
image_metadata, img_without_base_64 = image_url.split(",") image_metadata, img_without_base_64 = image_url.split(",")
image_format = image_metadata.split("/")[1]
# read mime_type from img_without_base_64=data:image/jpeg;base64 # read mime_type from img_without_base_64=data:image/jpeg;base64
# Extract MIME type using regular expression # Extract MIME type using regular expression
mime_type_match = re.match(r"data:(.*?);base64", image_metadata) mime_type_match = re.match(r"data:(.*?);base64", image_metadata)
if mime_type_match: if mime_type_match:
mime_type = mime_type_match.group(1) mime_type = mime_type_match.group(1)
image_format = mime_type.split("/")[1]
else: else:
mime_type = "jpeg" mime_type = "image/jpeg"
decoded_img = base64.b64decode(img_without_base_64) image_format = "jpeg"
_blob = BedrockImageSourceBlock(bytes=decoded_img) _blob = BedrockImageSourceBlock(bytes=img_without_base_64)
supported_image_formats = ( supported_image_formats = (
litellm.AmazonConverseConfig().get_supported_image_types() litellm.AmazonConverseConfig().get_supported_image_types()
) )
@ -1701,7 +1700,8 @@ def _process_bedrock_converse_image_block(image_url: str) -> BedrockImageBlock:
) )
else: else:
raise ValueError( raise ValueError(
"Unsupported image type. Expected either image url or base64 encoded string" "Unsupported image type. Expected either image url or base64 encoded string - \
e.g. 'data:image/jpeg;base64,<base64-encoded-string>'"
) )

View file

@ -243,6 +243,7 @@ def test_completion_bedrock_claude_sts_oidc_auth():
except Exception as e: except Exception as e:
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
@pytest.mark.skipif( @pytest.mark.skipif(
os.environ.get("CIRCLE_OIDC_TOKEN_V2") is None, os.environ.get("CIRCLE_OIDC_TOKEN_V2") is None,
reason="Cannot run without being in CircleCI Runner", reason="Cannot run without being in CircleCI Runner",
@ -277,7 +278,15 @@ def test_completion_bedrock_httpx_command_r_sts_oidc_auth():
except Exception as e: except Exception as e:
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
def test_bedrock_claude_3():
@pytest.mark.parametrize(
"image_url",
[
"",
"https://avatars.githubusercontent.com/u/29436595?v=",
],
)
def test_bedrock_claude_3(image_url):
try: try:
litellm.set_verbose = True litellm.set_verbose = True
data = { data = {
@ -294,7 +303,7 @@ def test_bedrock_claude_3():
{ {
"image_url": { "image_url": {
"detail": "high", "detail": "high",
"url": "", "url": image_url,
}, },
"type": "image_url", "type": "image_url",
}, },
@ -313,7 +322,6 @@ def test_bedrock_claude_3():
# Add any assertions here to check the response # Add any assertions here to check the response
assert len(response.choices) > 0 assert len(response.choices) > 0
assert len(response.choices[0].message.content) > 0 assert len(response.choices[0].message.content) > 0
except RateLimitError: except RateLimitError:
pass pass
except Exception as e: except Exception as e:
@ -552,7 +560,7 @@ def test_bedrock_ptu():
assert "url" in mock_client_post.call_args.kwargs assert "url" in mock_client_post.call_args.kwargs
assert ( assert (
mock_client_post.call_args.kwargs["url"] mock_client_post.call_args.kwargs["url"]
== "https://bedrock-runtime.us-west-2.amazonaws.com/model/arn%3Aaws%3Abedrock%3Aus-west-2%3A888602223428%3Aprovisioned-model%2F8fxff74qyhs3/invoke" == "https://bedrock-runtime.us-west-2.amazonaws.com/model/arn%3Aaws%3Abedrock%3Aus-west-2%3A888602223428%3Aprovisioned-model%2F8fxff74qyhs3/converse"
) )
mock_client_post.assert_called_once() mock_client_post.assert_called_once()

View file

@ -16,7 +16,7 @@ class SystemContentBlock(TypedDict):
class ImageSourceBlock(TypedDict): class ImageSourceBlock(TypedDict):
bytes: Optional[bytes] bytes: Optional[str] # base 64 encoded string
class ImageBlock(TypedDict): class ImageBlock(TypedDict):

View file

@ -4066,7 +4066,9 @@ def openai_token_counter(
for c in value: for c in value:
if c["type"] == "text": if c["type"] == "text":
text += c["text"] text += c["text"]
num_tokens += len(encoding.encode(c["text"], disallowed_special=())) num_tokens += len(
encoding.encode(c["text"], disallowed_special=())
)
elif c["type"] == "image_url": elif c["type"] == "image_url":
if isinstance(c["image_url"], dict): if isinstance(c["image_url"], dict):
image_url_dict = c["image_url"] image_url_dict = c["image_url"]
@ -5639,16 +5641,30 @@ def get_optional_params(
optional_params["stream"] = stream optional_params["stream"] = stream
elif "anthropic" in model: elif "anthropic" in model:
_check_valid_arg(supported_params=supported_params) _check_valid_arg(supported_params=supported_params)
optional_params = litellm.AmazonConverseConfig().map_openai_params( if "aws_bedrock_client" in passed_params: # deprecated boto3.invoke route.
model=model, if model.startswith("anthropic.claude-3"):
non_default_params=non_default_params, optional_params = (
optional_params=optional_params, litellm.AmazonAnthropicClaude3Config().map_openai_params(
drop_params=( non_default_params=non_default_params,
drop_params optional_params=optional_params,
if drop_params is not None and isinstance(drop_params, bool) )
else False )
), else:
) optional_params = litellm.AmazonAnthropicConfig().map_openai_params(
non_default_params=non_default_params,
optional_params=optional_params,
)
else: # bedrock httpx route
optional_params = litellm.AmazonConverseConfig().map_openai_params(
model=model,
non_default_params=non_default_params,
optional_params=optional_params,
drop_params=(
drop_params
if drop_params is not None and isinstance(drop_params, bool)
else False
),
)
elif "amazon" in model: # amazon titan llms elif "amazon" in model: # amazon titan llms
_check_valid_arg(supported_params=supported_params) _check_valid_arg(supported_params=supported_params)
# see https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-large # see https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-large