Bedrock Embeddings refactor + model support (#5462)

* refactor(bedrock): initial commit to refactor bedrock to a folder

Improve code readability + maintainability

* refactor: more refactor work

* fix: fix imports

* feat(bedrock/embeddings.py): support translating embedding requests into Amazon embedding formats (see the sketch after this list)

* fix: fix linting errors

* test: skip test on end-of-life model (see the pytest skip sketch after this list)

* fix(cohere/embed.py): fix linting error

* fix(cohere/embed.py): fix typing

* fix(cohere/embed.py): fix post-call logging for cohere embedding call

* test(test_embeddings.py): fix error message assertion in test
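A minimal sketch of what translating an OpenAI-style embedding call into the Amazon Titan request shape can look like. This is illustrative only, not the code added in bedrock/embeddings.py; the helper name `_translate_to_titan_request` and the `dimensions`/`normalize` handling for titan-embed-text-v2:0 are assumptions based on the publicly documented Bedrock request format.

```python
from typing import List, Optional, Union


def _translate_to_titan_request(
    model: str,
    input: Union[str, List[str]],
    dimensions: Optional[int] = None,
    normalize: Optional[bool] = None,
) -> List[dict]:
    """Illustrative sketch: map an OpenAI-style `input` to Titan request bodies.

    Titan text embedding models embed a single string per request
    ({"inputText": ...}), so a list input fans out into one body per item.
    (The image model additionally accepts an "inputImage" field, omitted here.)
    """
    texts = [input] if isinstance(input, str) else input
    requests = []
    for text in texts:
        body: dict = {"inputText": text}
        # Assumption: titan-embed-text-v2:0 also accepts dimensions/normalize,
        # per the public Bedrock docs (not taken from this commit's code).
        if "v2" in model:
            if dimensions is not None:
                body["dimensions"] = dimensions
            if normalize is not None:
                body["normalize"] = normalize
        requests.append(body)
    return requests


# Example: a two-item list becomes two Titan request bodies.
print(_translate_to_titan_request("amazon.titan-embed-text-v1", ["hello", "world"]))
```

The fan-out is the key point: an OpenAI-style list input has no direct Titan equivalent, so each string gets its own request body.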
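The "skip test on end-of-life model" item is just a test marker; a minimal sketch of the pattern (test name and reason string are illustrative, not the commit's actual test):

```python
import pytest


# Mark a test against a retired model so the suite stops calling a dead endpoint.
@pytest.mark.skip(reason="model has reached end of life; endpoint no longer served")
def test_eol_model_embedding():
    ...
```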
Krish Dholakia 2024-09-01 13:29:58 -07:00 committed by GitHub
parent 6fb82aaf75
commit 37f9705d6e
21 changed files with 1946 additions and 1659 deletions

test_embeddings.py
@@ -311,7 +311,17 @@ async def test_cohere_embedding3(custom_llm_provider):
 # test_cohere_embedding3()
-def test_bedrock_embedding_titan():
+@pytest.mark.parametrize(
+    "model",
+    [
+        "bedrock/amazon.titan-embed-text-v1",
+        "bedrock/amazon.titan-embed-image-v1",
+        "bedrock/amazon.titan-embed-text-v2:0",
+    ],
+)
+@pytest.mark.parametrize("sync_mode", [True])
+@pytest.mark.asyncio
+async def test_bedrock_embedding_titan(model, sync_mode):
     try:
         # this tests if we support str input for bedrock embedding
         litellm.set_verbose = True
@@ -320,16 +330,23 @@ def test_bedrock_embedding_titan():
         current_time = str(time.time())
         # DO NOT MAKE THE INPUT A LIST in this test
-        response = embedding(
-            model="bedrock/amazon.titan-embed-text-v1",
-            input=f"good morning from litellm, attempting to embed data {current_time}",  # input should always be a string in this test
-            aws_region_name="us-west-2",
-        )
-        print(f"response:", response)
+        if sync_mode:
+            response = embedding(
+                model=model,
+                input=f"good morning from litellm, attempting to embed data {current_time}",  # input should always be a string in this test
+                aws_region_name="us-west-2",
+            )
+        else:
+            response = await litellm.aembedding(
+                model=model,
+                input=f"good morning from litellm, attempting to embed data {current_time}",  # input should always be a string in this test
+                aws_region_name="us-west-2",
+            )
+        print("response:", response)
         assert isinstance(
             response["data"][0]["embedding"], list
         ), "Expected response to be a list"
-        print(f"type of first embedding:", type(response["data"][0]["embedding"][0]))
+        print("type of first embedding:", type(response["data"][0]["embedding"][0]))
         assert all(
             isinstance(x, float) for x in response["data"][0]["embedding"]
         ), "Expected response to be a list of floats"
@@ -339,13 +356,20 @@ def test_bedrock_embedding_titan():
         start_time = time.time()
-        response = embedding(
-            model="bedrock/amazon.titan-embed-text-v1",
-            input=f"good morning from litellm, attempting to embed data {current_time}",  # input should always be a string in this test
-        )
+        if sync_mode:
+            response = embedding(
+                model=model,
+                input=f"good morning from litellm, attempting to embed data {current_time}",  # input should always be a string in this test
+            )
+        else:
+            response = await litellm.aembedding(
+                model=model,
+                input=f"good morning from litellm, attempting to embed data {current_time}",  # input should always be a string in this test
+            )
         print(response)
         end_time = time.time()
         print(response._hidden_params)
         print(f"Embedding 2 response time: {end_time - start_time} seconds")
         assert end_time - start_time < 0.1
@@ -392,13 +416,13 @@ def test_demo_tokens_as_input_to_embeddings_fails_for_titan():
     with pytest.raises(
         litellm.BadRequestError,
-        match="BedrockException - Bedrock Embedding API input must be type str | List[str]",
+        match='litellm.BadRequestError: BedrockException - {"message":"Malformed input request: expected type: String, found: JSONArray, please reformat your input and try again."}',
     ):
         litellm.embedding(model="amazon.titan-embed-text-v1", input=[[1]])
     with pytest.raises(
         litellm.BadRequestError,
-        match="BedrockException - Bedrock Embedding API input must be type str | List[str]",
+        match='litellm.BadRequestError: BedrockException - {"message":"Malformed input request: expected type: String, found: Integer, please reformat your input and try again."}',
     ):
         litellm.embedding(
             model="amazon.titan-embed-text-v1",