Bedrock Embeddings refactor + model support (#5462)

* refactor(bedrock): initial commit to refactor bedrock to a folder

Improve code readability + maintainability

* refactor: more refactor work

* fix: fix imports

* feat(bedrock/embeddings.py): support translating embedding into Amazon embedding formats

* fix: fix linting errors

* test: skip test on end-of-life model

* fix(cohere/embed.py): fix linting error

* fix(cohere/embed.py): fix typing

* fix(cohere/embed.py): fix post-call logging for cohere embedding call

* test(test_embeddings.py): fix error message assertion in test
This commit is contained in:
Krish Dholakia 2024-09-01 13:29:58 -07:00 committed by GitHub
parent 6fb82aaf75
commit 37f9705d6e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
21 changed files with 1946 additions and 1659 deletions

View file

@@ -78,7 +78,6 @@ from .llms import (
ai21,
aleph_alpha,
baseten,
bedrock,
clarifai,
cloudflare,
maritalk,
@@ -96,7 +95,9 @@ from .llms.anthropic.chat import AnthropicChatCompletion
from .llms.anthropic.completion import AnthropicTextCompletion
from .llms.azure import AzureChatCompletion, _check_dynamic_azure_params
from .llms.azure_text import AzureTextCompletion
from .llms.bedrock_httpx import BedrockConverseLLM, BedrockLLM
from .llms.bedrock import image_generation as bedrock_image_generation # type: ignore
from .llms.bedrock.chat import BedrockConverseLLM, BedrockLLM
from .llms.bedrock.embed.embedding import BedrockEmbedding
from .llms.cohere import chat as cohere_chat
from .llms.cohere import completion as cohere_completion # type: ignore
from .llms.cohere import embed as cohere_embed
@@ -176,6 +177,7 @@ codestral_text_completions = CodestralTextCompletion()
triton_chat_completions = TritonChatCompletion()
bedrock_chat_completion = BedrockLLM()
bedrock_converse_chat_completion = BedrockConverseLLM()
bedrock_embedding = BedrockEmbedding()
vertex_chat_completion = VertexLLM()
vertex_multimodal_embedding = VertexMultimodalEmbedding()
google_batch_embeddings = GoogleBatchEmbeddings()
@@ -3151,6 +3153,7 @@ async def aembedding(*args, **kwargs) -> EmbeddingResponse:
or custom_llm_provider == "watsonx"
or custom_llm_provider == "cohere"
or custom_llm_provider == "huggingface"
or custom_llm_provider == "bedrock"
): # currently implemented aiohttp calls for just azure and openai, soon all.
# Await normally
init_response = await loop.run_in_executor(None, func_with_context)
@@ -3519,13 +3522,24 @@ def embedding(
aembedding=aembedding,
)
elif custom_llm_provider == "bedrock":
response = bedrock.embedding(
if isinstance(input, str):
transformed_input = [input]
else:
transformed_input = input
response = bedrock_embedding.embeddings(
model=model,
input=input,
input=transformed_input,
encoding=encoding,
logging_obj=logging,
optional_params=optional_params,
model_response=EmbeddingResponse(),
client=client,
timeout=timeout,
aembedding=aembedding,
litellm_params=litellm_params,
api_base=api_base,
print_verbose=print_verbose,
extra_headers=extra_headers,
)
elif custom_llm_provider == "triton":
if api_base is None:
@@ -4493,7 +4507,7 @@ def image_generation(
elif custom_llm_provider == "bedrock":
if model is None:
raise Exception("Model needs to be set for bedrock")
model_response = bedrock.image_generation(
model_response = bedrock_image_generation.image_generation(
model=model,
prompt=prompt,
timeout=timeout,