recommit changes with correct email address.

This commit is contained in:
btemplep 2025-04-24 09:47:09 -04:00
parent b82af5b826
commit 43a3d9ac10
2 changed files with 93 additions and 3 deletions

View file

@ -5,6 +5,7 @@ Handles embedding calls to Bedrock's `/invoke` endpoint
import copy import copy
import json import json
from typing import Any, Callable, List, Optional, Tuple, Union from typing import Any, Callable, List, Optional, Tuple, Union
import urllib.parse
import httpx import httpx
@ -348,6 +349,16 @@ class BedrockEmbedding(BaseAWSLLM):
credentials, aws_region_name = self._load_credentials(optional_params) credentials, aws_region_name = self._load_credentials(optional_params)
### TRANSFORMATION ### ### TRANSFORMATION ###
unencoded_model_id = (
optional_params.pop("model_id", None) or model
) # default to model if not passed
modelId = urllib.parse.quote(unencoded_model_id, safe="")
aws_region_name = self._get_aws_region_name(
optional_params=optional_params,
model=model,
model_id=unencoded_model_id,
)
provider = model.split(".")[0] provider = model.split(".")[0]
inference_params = copy.deepcopy(optional_params) inference_params = copy.deepcopy(optional_params)
inference_params = { inference_params = {
@ -358,9 +369,6 @@ class BedrockEmbedding(BaseAWSLLM):
inference_params.pop( inference_params.pop(
"user", None "user", None
) # make sure user is not passed in for bedrock call ) # make sure user is not passed in for bedrock call
modelId = (
optional_params.pop("model_id", None) or model
) # default to model if not passed
data: Optional[CohereEmbeddingRequest] = None data: Optional[CohereEmbeddingRequest] = None
batch_data: Optional[List] = None batch_data: Optional[List] = None

View file

@ -0,0 +1,82 @@
import os
import sys
sys.path.insert(
0, os.path.abspath("../../../../..")
) # Adds the parent directory to the system path
from unittest.mock import patch
import pytest
from litellm.types.utils import Embedding
from litellm.main import bedrock_embedding, embedding, EmbeddingResponse, Usage
# ARN of a (fake) Bedrock application inference profile, passed to the
# embedding call as ``model_id``.
_mock_model_id = (
    "arn:aws:bedrock:us-east-1:123412341234"
    ":application-inference-profile/abc123123"
)

# Expected invoke URL: the ARN above appears as a single path segment with
# every ":" percent-encoded as %3A and the "/" as %2F.
_mock_app_ip_url = (
    "https://bedrock-runtime.us-east-1.amazonaws.com/model/"
    "arn%3Aaws%3Abedrock%3Aus-east-1%3A123412341234"
    "%3Aapplication-inference-profile%2Fabc123123"
    "/invoke"
)
def _get_mock_embedding_response(model: str) -> EmbeddingResponse:
    """Build a canned ``EmbeddingResponse`` for use as a mock return value.

    Args:
        model: Model name recorded on the response.

    Returns:
        An ``EmbeddingResponse`` holding one four-element embedding at
        index 0 and a usage record of one prompt token (one total token).
    """
    mock_usage = Usage(
        prompt_tokens=1,
        completion_tokens=0,
        total_tokens=1,
        completion_tokens_details=None,
        prompt_tokens_details=None,
    )
    mock_vector = Embedding(
        embedding=[-0.671875, 0.291015625, -0.1826171875, 0.8828125],
        index=0,
        object="embedding",
    )
    return EmbeddingResponse(model=model, usage=mock_usage, data=[mock_vector])
@pytest.mark.parametrize(
    "model",
    [
        "amazon.titan-embed-text-v1",
        "amazon.titan-embed-text-v2:0",
    ],
)
def test_bedrock_embedding_titan_app_profile(model: str):
    """Check the endpoint URL used for Titan models with an inference-profile ARN.

    ``bedrock_embedding._single_func_embeddings`` is patched so no request is
    sent; the test only inspects the ``endpoint_url`` keyword argument it
    receives when ``model_id`` is the application inference profile ARN.
    """
    with patch.object(bedrock_embedding, "_single_func_embeddings") as mock_method:
        mock_method.return_value = _get_mock_embedding_response(model=model)
        # The return value is not needed; only the recorded call is checked.
        embedding(
            custom_llm_provider="bedrock",
            model=model,
            model_id=_mock_model_id,
            input=["tester"],
            aws_region_name="us-east-1",
            aws_access_key_id="mockaws_access_key_id",
            aws_secret_access_key="mockaws_secret_access_key",
        )

    # ``call_args`` is None if the mock was never invoked, which would turn the
    # check below into an opaque AttributeError; fail with a clear message first.
    mock_method.assert_called()
    assert mock_method.call_args.kwargs["endpoint_url"] == _mock_app_ip_url
@pytest.mark.parametrize(
    "model",
    [
        "cohere.embed-english-v3",
        "cohere.embed-multilingual-v3",
    ],
)
def test_bedrock_embedding_cohere_app_profile(model: str):
    """Check the API base used for Cohere models with an inference-profile ARN.

    ``cohere_embedding`` in the Bedrock embedding module is patched so no
    request is sent; the test only inspects the ``complete_api_base`` keyword
    argument it receives when ``model_id`` is the application inference
    profile ARN.
    """
    with patch(
        "litellm.llms.bedrock.embed.embedding.cohere_embedding"
    ) as mock_cohere_embedding:
        mock_cohere_embedding.return_value = _get_mock_embedding_response(model=model)
        # The return value is not needed; only the recorded call is checked.
        embedding(
            custom_llm_provider="bedrock",
            model=model,
            model_id=_mock_model_id,
            input=["tester"],
            aws_region_name="us-east-1",
            aws_access_key_id="mockaws_access_key_id",
            aws_secret_access_key="mockaws_secret_access_key",
        )

    # ``call_args`` is None if the mock was never invoked, which would turn the
    # check below into an opaque AttributeError; fail with a clear message first.
    mock_cohere_embedding.assert_called()
    assert (
        mock_cohere_embedding.call_args.kwargs["complete_api_base"]
        == _mock_app_ip_url
    )