Bedrock Embeddings refactor + model support (#5462)

* refactor(bedrock): initial commit to refactor bedrock to a folder

Improve code readability + maintainability

* refactor: more refactor work

* fix: fix imports

* feat(bedrock/embeddings.py): support translating embedding into amazon embedding formats

* fix: fix linting errors

* test: skip test on end of life model

* fix(cohere/embed.py): fix linting error

* fix(cohere/embed.py): fix typing

* fix(cohere/embed.py): fix post-call logging for cohere embedding call

* test(test_embeddings.py): fix error message assertion in test
This commit is contained in:
Krish Dholakia 2024-09-01 13:29:58 -07:00 committed by GitHub
parent 6fb82aaf75
commit 37f9705d6e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
21 changed files with 1946 additions and 1659 deletions

View file

@@ -76,7 +76,7 @@ async def async_embedding(
data: dict,
input: list,
model_response: litellm.utils.EmbeddingResponse,
timeout: Union[float, httpx.Timeout],
timeout: Optional[Union[float, httpx.Timeout]],
logging_obj: LiteLLMLoggingObj,
optional_params: dict,
api_base: str,
@@ -98,16 +98,35 @@ async def async_embedding(
)
## COMPLETION CALL
if client is None:
client = AsyncHTTPHandler(concurrent_limit=1)
client = AsyncHTTPHandler(concurrent_limit=1, timeout=timeout)
response = await client.post(api_base, headers=headers, data=json.dumps(data))
try:
response = await client.post(api_base, headers=headers, data=json.dumps(data))
except httpx.HTTPStatusError as e:
## LOGGING
logging_obj.post_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=e.response.text,
)
raise e
except Exception as e:
## LOGGING
logging_obj.post_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=str(e),
)
raise e
## LOGGING
logging_obj.post_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=response,
original_response=response.text,
)
embeddings = response.json()["embeddings"]
@@ -130,27 +149,22 @@ def embedding(
optional_params: dict,
headers: dict,
encoding: Any,
data: Optional[dict] = None,
complete_api_base: Optional[str] = None,
api_key: Optional[str] = None,
aembedding: Optional[bool] = None,
timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
timeout: Optional[Union[float, httpx.Timeout]] = httpx.Timeout(None),
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
):
headers = validate_environment(api_key, headers=headers)
embed_url = "https://api.cohere.ai/v1/embed"
embed_url = complete_api_base or "https://api.cohere.ai/v1/embed"
model = model
data = {"model": model, "texts": input, **optional_params}
data = data or {"model": model, "texts": input, **optional_params}
if "3" in model and "input_type" not in data:
# cohere v3 embedding models require input_type, if no input_type is provided, default to "search_document"
data["input_type"] = "search_document"
## LOGGING
logging_obj.pre_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
)
## ROUTING
if aembedding is True:
return async_embedding(
@@ -166,9 +180,18 @@ def embedding(
headers=headers,
encoding=encoding,
)
## LOGGING
logging_obj.pre_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
)
## COMPLETION CALL
if client is None or not isinstance(client, HTTPHandler):
client = HTTPHandler(concurrent_limit=1)
response = client.post(embed_url, headers=headers, data=json.dumps(data))
## LOGGING
logging_obj.post_call(