Bedrock Embeddings refactor + model support (#5462)

* refactor(bedrock): initial commit to refactor bedrock to a folder

Improve code readability + maintainability

* refactor: more refactor work

* fix: fix imports

* feat(bedrock/embeddings.py): support translating embedding into amazon embedding formats

* fix: fix linting errors

* test: skip test on end of life model

* fix(cohere/embed.py): fix linting error

* fix(cohere/embed.py): fix typing

* fix(cohere/embed.py): fix post-call logging for cohere embedding call

* test(test_embeddings.py): fix error message assertion in test
This commit is contained in:
Krish Dholakia 2024-09-01 13:29:58 -07:00 committed by GitHub
parent 6fb82aaf75
commit 37f9705d6e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
21 changed files with 1946 additions and 1659 deletions

View file

@ -854,6 +854,7 @@ def client(original_function):
)
cached_result = litellm.cache.get_cache(*args, **kwargs)
if cached_result is not None:
print_verbose("Cache Hit!")
if "detail" in cached_result:
# implies an error occurred
pass
@ -935,7 +936,10 @@ def client(original_function):
args=(cached_result, start_time, end_time, cache_hit),
).start()
return cached_result
else:
print_verbose(
"Cache Miss! on key - {}".format(preset_cache_key)
)
# CHECK MAX TOKENS
if (
kwargs.get("max_tokens", None) is not None
@ -1005,7 +1009,7 @@ def client(original_function):
litellm.cache is not None
and str(original_function.__name__)
in litellm.cache.supported_call_types
) and (kwargs.get("cache", {}).get("no-store", False) != True):
) and (kwargs.get("cache", {}).get("no-store", False) is not True):
litellm.cache.add_cache(result, *args, **kwargs)
# LOG SUCCESS - handle streaming success logging in the _next_ object, remove `handle_success` once it's deprecated
@ -1404,10 +1408,10 @@ def client(original_function):
# MODEL CALL
result = await original_function(*args, **kwargs)
end_time = datetime.datetime.now()
if "stream" in kwargs and kwargs["stream"] == True:
if "stream" in kwargs and kwargs["stream"] is True:
if (
"complete_response" in kwargs
and kwargs["complete_response"] == True
and kwargs["complete_response"] is True
):
chunks = []
for idx, chunk in enumerate(result):
@ -11734,3 +11738,13 @@ def is_cached_message(message: AllMessageValues) -> bool:
return True
return False
def is_base64_encoded(s: str) -> bool:
    """Return True if ``s`` round-trips through strict base64 decoding.

    ``s`` is decoded with ``validate=True`` and the decoded bytes are
    encoded again; the result is True only when that re-encoded text
    equals ``s``. Any exception raised along the way yields False.
    """
    try:
        raw = base64.b64decode(s, validate=True)
        # Re-encode so that inputs which decode but are not the exact
        # encoding of their payload are reported as not base64.
        round_tripped = base64.b64encode(raw).decode("utf-8")
        return round_tripped == s
    except Exception:
        return False