Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 02:34:29 +00:00)

Commit 43a3d9ac10 (parent b82af5b826): recommit changes with correct email address.

2 changed files with 93 additions and 3 deletions
@@ -5,6 +5,7 @@ Handles embedding calls to Bedrock's `/invoke` endpoint
 import copy
 import json
 from typing import Any, Callable, List, Optional, Tuple, Union
+import urllib.parse
 
 import httpx
 
@@ -348,6 +349,16 @@ class BedrockEmbedding(BaseAWSLLM):
         credentials, aws_region_name = self._load_credentials(optional_params)
 
         ### TRANSFORMATION ###
+        unencoded_model_id = (
+            optional_params.pop("model_id", None) or model
+        )  # default to model if not passed
+        modelId = urllib.parse.quote(unencoded_model_id, safe="")
+        aws_region_name = self._get_aws_region_name(
+            optional_params=optional_params,
+            model=model,
+            model_id=unencoded_model_id,
+        )
+
         provider = model.split(".")[0]
         inference_params = copy.deepcopy(optional_params)
         inference_params = {
@@ -358,9 +369,6 @@ class BedrockEmbedding(BaseAWSLLM):
         inference_params.pop(
             "user", None
         )  # make sure user is not passed in for bedrock call
-        modelId = (
-            optional_params.pop("model_id", None) or model
-        )  # default to model if not passed
 
         data: Optional[CohereEmbeddingRequest] = None
         batch_data: Optional[List] = None
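Taken together, the hunks above pop `model_id` from `optional_params`, percent-encode it with `urllib.parse.quote(..., safe="")`, and resolve the AWS region from the unencoded id. A minimal sketch of what that encoding produces for an application-inference-profile ARN; the ARN, region, and URL shape are the placeholder values asserted in the new tests, not output generated by this diff:

import urllib.parse

# Placeholder ARN matching the one used in the new tests below.
unencoded_model_id = (
    "arn:aws:bedrock:us-east-1:123412341234:"
    "application-inference-profile/abc123123"
)

# safe="" forces ':' and '/' to be percent-encoded as well, so the entire ARN
# fits into a single path segment of the invoke URL.
model_id = urllib.parse.quote(unencoded_model_id, safe="")

# Assumed URL shape, mirroring the endpoint the tests below expect.
endpoint_url = (
    f"https://bedrock-runtime.us-east-1.amazonaws.com/model/{model_id}/invoke"
)
print(endpoint_url)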
tests/litellm/llms/bedrock/embed/test_embedding.py (new file, 82 lines added)

@@ -0,0 +1,82 @@
import os
import sys

sys.path.insert(
    0, os.path.abspath("../../../../..")
)  # Adds the parent directory to the system path
from unittest.mock import patch

import pytest

from litellm.types.utils import Embedding
from litellm.main import bedrock_embedding, embedding, EmbeddingResponse, Usage


_mock_model_id = "arn:aws:bedrock:us-east-1:123412341234:application-inference-profile/abc123123"
_mock_app_ip_url = "https://bedrock-runtime.us-east-1.amazonaws.com/model/arn%3Aaws%3Abedrock%3Aus-east-1%3A123412341234%3Aapplication-inference-profile%2Fabc123123/invoke"


def _get_mock_embedding_response(model: str) -> EmbeddingResponse:
    return EmbeddingResponse(
        model=model,
        usage=Usage(
            prompt_tokens=1,
            completion_tokens=0,
            total_tokens=1,
            completion_tokens_details=None,
            prompt_tokens_details=None
        ),
        data=[
            Embedding(
                embedding=[-0.671875, 0.291015625, -0.1826171875, 0.8828125],
                index=0,
                object="embedding"
            )
        ]
    )


@pytest.mark.parametrize(
    "model",
    [
        "amazon.titan-embed-text-v1",
        "amazon.titan-embed-text-v2:0"
    ]
)
def test_bedrock_embedding_titan_app_profile(model: str):
    with patch.object(bedrock_embedding, '_single_func_embeddings') as mock_method:
        mock_method.return_value = _get_mock_embedding_response(model=model)

        resp = embedding(
            custom_llm_provider="bedrock",
            model=model,
            model_id=_mock_model_id,
            input=["tester"],
            aws_region_name="us-east-1",
            aws_access_key_id="mockaws_access_key_id",
            aws_secret_access_key="mockaws_secret_access_key"
        )

        assert mock_method.call_args.kwargs['endpoint_url'] == _mock_app_ip_url


@pytest.mark.parametrize(
    "model",
    [
        "cohere.embed-english-v3",
        "cohere.embed-multilingual-v3"
    ]
)
def test_bedrock_embedding_cohere_app_profile(model: str):
    with patch("litellm.llms.bedrock.embed.embedding.cohere_embedding") as mock_cohere_embedding:
        mock_cohere_embedding.return_value = _get_mock_embedding_response(model=model)

        resp = embedding(
            custom_llm_provider="bedrock",
            model=model,
            model_id=_mock_model_id,
            input=["tester"],
            aws_region_name="us-east-1",
            aws_access_key_id="mockaws_access_key_id",
            aws_secret_access_key="mockaws_secret_access_key"
        )

        assert mock_cohere_embedding.call_args.kwargs['complete_api_base'] == _mock_app_ip_url
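For reference, the caller-facing pattern these tests exercise is passing an application inference profile ARN as `model_id` alongside the underlying foundation model. A minimal sketch of such a call, assuming real AWS credentials and an existing inference profile; every value below is a placeholder:

import litellm

# Placeholder credentials and ARN; substitute real values before running.
response = litellm.embedding(
    custom_llm_provider="bedrock",
    model="amazon.titan-embed-text-v1",
    model_id="arn:aws:bedrock:us-east-1:123412341234:application-inference-profile/abc123123",
    input=["tester"],
    aws_region_name="us-east-1",
    aws_access_key_id="<access-key-id>",
    aws_secret_access_key="<secret-access-key>",
)
print(response)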