From 43a3d9ac10811b690e44d9399f2149e2e326dd1c Mon Sep 17 00:00:00 2001 From: btemplep Date: Thu, 24 Apr 2025 09:47:09 -0400 Subject: [PATCH] recommit changes with correct email address. --- litellm/llms/bedrock/embed/embedding.py | 14 +++- .../llms/bedrock/embed/test_embedding.py | 82 +++++++++++++++++++ 2 files changed, 93 insertions(+), 3 deletions(-) create mode 100644 tests/litellm/llms/bedrock/embed/test_embedding.py diff --git a/litellm/llms/bedrock/embed/embedding.py b/litellm/llms/bedrock/embed/embedding.py index 9e4e4e22d0..c930620270 100644 --- a/litellm/llms/bedrock/embed/embedding.py +++ b/litellm/llms/bedrock/embed/embedding.py @@ -5,6 +5,7 @@ Handles embedding calls to Bedrock's `/invoke` endpoint import copy import json from typing import Any, Callable, List, Optional, Tuple, Union +import urllib.parse import httpx @@ -348,6 +349,16 @@ class BedrockEmbedding(BaseAWSLLM): credentials, aws_region_name = self._load_credentials(optional_params) ### TRANSFORMATION ### + unencoded_model_id = ( + optional_params.pop("model_id", None) or model + ) # default to model if not passed + modelId = urllib.parse.quote(unencoded_model_id, safe="") + aws_region_name = self._get_aws_region_name( + optional_params=optional_params, + model=model, + model_id=unencoded_model_id, + ) + provider = model.split(".")[0] inference_params = copy.deepcopy(optional_params) inference_params = { @@ -358,9 +369,6 @@ class BedrockEmbedding(BaseAWSLLM): inference_params.pop( "user", None ) # make sure user is not passed in for bedrock call - modelId = ( - optional_params.pop("model_id", None) or model - ) # default to model if not passed data: Optional[CohereEmbeddingRequest] = None batch_data: Optional[List] = None diff --git a/tests/litellm/llms/bedrock/embed/test_embedding.py b/tests/litellm/llms/bedrock/embed/test_embedding.py new file mode 100644 index 0000000000..516f19b98a --- /dev/null +++ b/tests/litellm/llms/bedrock/embed/test_embedding.py @@ -0,0 +1,82 @@ + +import os +import sys + +sys.path.insert( + 0, os.path.abspath("../../../../..") +) # Adds the parent directory to the system path +from unittest.mock import patch + +import pytest + +from litellm.types.utils import Embedding +from litellm.main import bedrock_embedding, embedding, EmbeddingResponse, Usage + + +_mock_model_id = "arn:aws:bedrock:us-east-1:123412341234:application-inference-profile/abc123123" +_mock_app_ip_url = "https://bedrock-runtime.us-east-1.amazonaws.com/model/arn%3Aaws%3Abedrock%3Aus-east-1%3A123412341234%3Aapplication-inference-profile%2Fabc123123/invoke" + + +def _get_mock_embedding_response(model: str) -> EmbeddingResponse: + return EmbeddingResponse( + model=model, + usage=Usage( + prompt_tokens=1, + completion_tokens=0, + total_tokens=1, + completion_tokens_details=None, + prompt_tokens_details=None + ), + data=[ + Embedding( + embedding=[-0.671875, 0.291015625, -0.1826171875, 0.8828125], + index=0, + object="embedding" + ) + ] + ) + + +@pytest.mark.parametrize( + "model", + [ + "amazon.titan-embed-text-v1", + "amazon.titan-embed-text-v2:0" + ] +) +def test_bedrock_embedding_titan_app_profile(model: str): + with patch.object(bedrock_embedding, '_single_func_embeddings') as mock_method: + mock_method.return_value = _get_mock_embedding_response(model=model) + resp = embedding( + custom_llm_provider="bedrock", + model=model, + model_id=_mock_model_id, + input=["tester"], + aws_region_name="us-east-1", + aws_access_key_id="mockaws_access_key_id", + aws_secret_access_key="mockaws_secret_access_key" + ) + assert mock_method.call_args.kwargs['endpoint_url'] == _mock_app_ip_url + + +@pytest.mark.parametrize( + "model", + [ + "cohere.embed-english-v3", + "cohere.embed-multilingual-v3" + ] +) +def test_bedrock_embedding_cohere_app_profile(model: str): + with patch("litellm.llms.bedrock.embed.embedding.cohere_embedding") as mock_cohere_embedding: + mock_cohere_embedding.return_value = _get_mock_embedding_response(model=model) + resp = embedding( + custom_llm_provider="bedrock", + model=model, + model_id=_mock_model_id, + input=["tester"], + aws_region_name="us-east-1", + aws_access_key_id="mockaws_access_key_id", + aws_secret_access_key="mockaws_secret_access_key" + ) + assert mock_cohere_embedding.call_args.kwargs['complete_api_base'] == _mock_app_ip_url +