mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
Bedrock Embeddings refactor + model support (#5462)
* refactor(bedrock): initial commit to refactor bedrock to a folder Improve code readability + maintainability * refactor: more refactor work * fix: fix imports * feat(bedrock/embeddings.py): support translating embedding into amazon embedding formats * fix: fix linting errors * test: skip test on end of life model * fix(cohere/embed.py): fix linting error * fix(cohere/embed.py): fix typing * fix(cohere/embed.py): fix post-call logging for cohere embedding call * test(test_embeddings.py): fix error message assertion in test
This commit is contained in:
parent
6fb82aaf75
commit
37f9705d6e
21 changed files with 1946 additions and 1659 deletions
|
@ -311,7 +311,17 @@ async def test_cohere_embedding3(custom_llm_provider):
|
|||
# test_cohere_embedding3()
|
||||
|
||||
|
||||
def test_bedrock_embedding_titan():
|
||||
@pytest.mark.parametrize(
|
||||
"model",
|
||||
[
|
||||
"bedrock/amazon.titan-embed-text-v1",
|
||||
"bedrock/amazon.titan-embed-image-v1",
|
||||
"bedrock/amazon.titan-embed-text-v2:0",
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("sync_mode", [True])
|
||||
@pytest.mark.asyncio
|
||||
async def test_bedrock_embedding_titan(model, sync_mode):
|
||||
try:
|
||||
# this tests if we support str input for bedrock embedding
|
||||
litellm.set_verbose = True
|
||||
|
@ -320,16 +330,23 @@ def test_bedrock_embedding_titan():
|
|||
|
||||
current_time = str(time.time())
|
||||
# DO NOT MAKE THE INPUT A LIST in this test
|
||||
response = embedding(
|
||||
model="bedrock/amazon.titan-embed-text-v1",
|
||||
input=f"good morning from litellm, attempting to embed data {current_time}", # input should always be a string in this test
|
||||
aws_region_name="us-west-2",
|
||||
)
|
||||
print(f"response:", response)
|
||||
if sync_mode:
|
||||
response = embedding(
|
||||
model=model,
|
||||
input=f"good morning from litellm, attempting to embed data {current_time}", # input should always be a string in this test
|
||||
aws_region_name="us-west-2",
|
||||
)
|
||||
else:
|
||||
response = await litellm.aembedding(
|
||||
model=model,
|
||||
input=f"good morning from litellm, attempting to embed data {current_time}", # input should always be a string in this test
|
||||
aws_region_name="us-west-2",
|
||||
)
|
||||
print("response:", response)
|
||||
assert isinstance(
|
||||
response["data"][0]["embedding"], list
|
||||
), "Expected response to be a list"
|
||||
print(f"type of first embedding:", type(response["data"][0]["embedding"][0]))
|
||||
print("type of first embedding:", type(response["data"][0]["embedding"][0]))
|
||||
assert all(
|
||||
isinstance(x, float) for x in response["data"][0]["embedding"]
|
||||
), "Expected response to be a list of floats"
|
||||
|
@ -339,13 +356,20 @@ def test_bedrock_embedding_titan():
|
|||
|
||||
start_time = time.time()
|
||||
|
||||
response = embedding(
|
||||
model="bedrock/amazon.titan-embed-text-v1",
|
||||
input=f"good morning from litellm, attempting to embed data {current_time}", # input should always be a string in this test
|
||||
)
|
||||
if sync_mode:
|
||||
response = embedding(
|
||||
model=model,
|
||||
input=f"good morning from litellm, attempting to embed data {current_time}", # input should always be a string in this test
|
||||
)
|
||||
else:
|
||||
response = await litellm.aembedding(
|
||||
model=model,
|
||||
input=f"good morning from litellm, attempting to embed data {current_time}", # input should always be a string in this test
|
||||
)
|
||||
print(response)
|
||||
|
||||
end_time = time.time()
|
||||
print(response._hidden_params)
|
||||
print(f"Embedding 2 response time: {end_time - start_time} seconds")
|
||||
|
||||
assert end_time - start_time < 0.1
|
||||
|
@ -392,13 +416,13 @@ def test_demo_tokens_as_input_to_embeddings_fails_for_titan():
|
|||
|
||||
with pytest.raises(
|
||||
litellm.BadRequestError,
|
||||
match="BedrockException - Bedrock Embedding API input must be type str | List[str]",
|
||||
match='litellm.BadRequestError: BedrockException - {"message":"Malformed input request: expected type: String, found: JSONArray, please reformat your input and try again."}',
|
||||
):
|
||||
litellm.embedding(model="amazon.titan-embed-text-v1", input=[[1]])
|
||||
|
||||
with pytest.raises(
|
||||
litellm.BadRequestError,
|
||||
match="BedrockException - Bedrock Embedding API input must be type str | List[str]",
|
||||
match='litellm.BadRequestError: BedrockException - {"message":"Malformed input request: expected type: String, found: Integer, please reformat your input and try again."}',
|
||||
):
|
||||
litellm.embedding(
|
||||
model="amazon.titan-embed-text-v1",
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue