mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
Bedrock Embeddings refactor + model support (#5462)
* refactor(bedrock): initial commit to refactor bedrock to a folder Improve code readability + maintainability * refactor: more refactor work * fix: fix imports * feat(bedrock/embeddings.py): support translating embedding into amazon embedding formats * fix: fix linting errors * test: skip test on end of life model * fix(cohere/embed.py): fix linting error * fix(cohere/embed.py): fix typing * fix(cohere/embed.py): fix post-call logging for cohere embedding call * test(test_embeddings.py): fix error message assertion in test
This commit is contained in:
parent
6fb82aaf75
commit
37f9705d6e
21 changed files with 1946 additions and 1659 deletions
|
@ -76,7 +76,7 @@ async def async_embedding(
|
|||
data: dict,
|
||||
input: list,
|
||||
model_response: litellm.utils.EmbeddingResponse,
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
timeout: Optional[Union[float, httpx.Timeout]],
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
optional_params: dict,
|
||||
api_base: str,
|
||||
|
@ -98,16 +98,35 @@ async def async_embedding(
|
|||
)
|
||||
## COMPLETION CALL
|
||||
if client is None:
|
||||
client = AsyncHTTPHandler(concurrent_limit=1)
|
||||
client = AsyncHTTPHandler(concurrent_limit=1, timeout=timeout)
|
||||
|
||||
response = await client.post(api_base, headers=headers, data=json.dumps(data))
|
||||
try:
|
||||
response = await client.post(api_base, headers=headers, data=json.dumps(data))
|
||||
except httpx.HTTPStatusError as e:
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=input,
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": data},
|
||||
original_response=e.response.text,
|
||||
)
|
||||
raise e
|
||||
except Exception as e:
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=input,
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": data},
|
||||
original_response=str(e),
|
||||
)
|
||||
raise e
|
||||
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=input,
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": data},
|
||||
original_response=response,
|
||||
original_response=response.text,
|
||||
)
|
||||
|
||||
embeddings = response.json()["embeddings"]
|
||||
|
@ -130,27 +149,22 @@ def embedding(
|
|||
optional_params: dict,
|
||||
headers: dict,
|
||||
encoding: Any,
|
||||
data: Optional[dict] = None,
|
||||
complete_api_base: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
aembedding: Optional[bool] = None,
|
||||
timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
|
||||
timeout: Optional[Union[float, httpx.Timeout]] = httpx.Timeout(None),
|
||||
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
|
||||
):
|
||||
headers = validate_environment(api_key, headers=headers)
|
||||
embed_url = "https://api.cohere.ai/v1/embed"
|
||||
embed_url = complete_api_base or "https://api.cohere.ai/v1/embed"
|
||||
model = model
|
||||
data = {"model": model, "texts": input, **optional_params}
|
||||
data = data or {"model": model, "texts": input, **optional_params}
|
||||
|
||||
if "3" in model and "input_type" not in data:
|
||||
# cohere v3 embedding models require input_type, if no input_type is provided, default to "search_document"
|
||||
data["input_type"] = "search_document"
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=input,
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
|
||||
## ROUTING
|
||||
if aembedding is True:
|
||||
return async_embedding(
|
||||
|
@ -166,9 +180,18 @@ def embedding(
|
|||
headers=headers,
|
||||
encoding=encoding,
|
||||
)
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=input,
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
|
||||
## COMPLETION CALL
|
||||
if client is None or not isinstance(client, HTTPHandler):
|
||||
client = HTTPHandler(concurrent_limit=1)
|
||||
|
||||
response = client.post(embed_url, headers=headers, data=json.dumps(data))
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue