diff --git a/litellm/llms/bedrock_httpx.py b/litellm/llms/bedrock_httpx.py
index 2fd714722..337055dc2 100644
--- a/litellm/llms/bedrock_httpx.py
+++ b/litellm/llms/bedrock_httpx.py
@@ -44,6 +44,7 @@ from .base import BaseLLM
 import httpx  # type: ignore
 from .bedrock import BedrockError, convert_messages_to_prompt, ModelResponseIterator
 from litellm.types.llms.bedrock import *
+import urllib.parse
 
 
 class AmazonCohereChatConfig:
@@ -524,6 +525,16 @@ class BedrockLLM(BaseLLM):
 
         return model_response
 
+    def encode_model_id(self, model_id: str) -> str:
+        """
+        Double encode the model ID to ensure it matches the expected double-encoded format.
+        Args:
+            model_id (str): The model ID to encode.
+        Returns:
+            str: The double-encoded model ID.
+        """
+        return urllib.parse.quote(model_id, safe="")
+
     def completion(
         self,
         model: str,
@@ -552,7 +563,12 @@ class BedrockLLM(BaseLLM):
 
         ## SETUP ##
         stream = optional_params.pop("stream", None)
-        modelId = optional_params.pop("model_id", None) or model
+        modelId = optional_params.pop("model_id", None)
+        if modelId is not None:
+            modelId = self.encode_model_id(model_id=modelId)
+        else:
+            modelId = model
+
         provider = model.split(".")[0]
 
         ## CREDENTIALS ##
diff --git a/litellm/main.py b/litellm/main.py
index b8d15942b..37fc1db8f 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -2099,6 +2099,7 @@ def completion(
                 extra_headers=extra_headers,
                 timeout=timeout,
                 acompletion=acompletion,
+                client=client,
             )
             if optional_params.get("stream", False):
                 ## LOGGING
diff --git a/litellm/tests/test_bedrock_completion.py b/litellm/tests/test_bedrock_completion.py
index f0a0084b8..6ffc1a4c4 100644
--- a/litellm/tests/test_bedrock_completion.py
+++ b/litellm/tests/test_bedrock_completion.py
@@ -13,6 +13,8 @@ import pytest
 import litellm
 from litellm import embedding, completion, completion_cost, Timeout, ModelResponse
 from litellm import RateLimitError
+from litellm.llms.custom_httpx.http_handler import HTTPHandler
+from unittest.mock import patch, AsyncMock, Mock
 
 # litellm.num_retries = 3
 litellm.cache = None
@@ -509,13 +511,28 @@ def test_bedrock_ptu():
 
     Reference: https://github.com/BerriAI/litellm/issues/3805
     """
+    client = HTTPHandler()
 
-    from openai.types.chat import ChatCompletion
+    with patch.object(client, "post", new=Mock()) as mock_client_post:
+        litellm.set_verbose = True
+        from openai.types.chat import ChatCompletion
 
-    response = litellm.completion(
-        model="bedrock/amazon.my-incorrect-model",
-        messages=[{"role": "user", "content": "What's AWS?"}],
-        model_id="amazon.titan-text-lite-v1",
-    )
+        model_id = (
+            "arn:aws:bedrock:us-west-2:888602223428:provisioned-model/8fxff74qyhs3"
+        )
+        try:
+            response = litellm.completion(
+                model="bedrock/anthropic.claude-instant-v1",
+                messages=[{"role": "user", "content": "What's AWS?"}],
+                model_id=model_id,
+                client=client,
+            )
+        except Exception as e:
+            pass
 
-    ChatCompletion.model_validate(response.model_dump(), strict=True)
+        assert "url" in mock_client_post.call_args.kwargs
+        assert (
+            mock_client_post.call_args.kwargs["url"]
+            == "https://bedrock-runtime.us-west-2.amazonaws.com/model/arn%3Aaws%3Abedrock%3Aus-west-2%3A888602223428%3Aprovisioned-model%2F8fxff74qyhs3/invoke"
+        )
+        mock_client_post.assert_called_once()
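
Reviewer note (not part of the patch): the sketch below is a minimal, standalone illustration of why the new `encode_model_id` pass is needed. Pre-patch, a provisioned-throughput ARN was interpolated raw into the invoke path, so the `/` inside the ARN split `/model/{id}/invoke` into extra segments. One `urllib.parse.quote(..., safe="")` pass collapses the ARN into a single opaque path segment, yielding exactly the URL asserted in `test_bedrock_ptu`. (The docstring's "double encode" wording presumably refers to SigV4 canonicalization, which percent-encodes path segments a second time during signing for services other than S3; the `quote` call itself runs once.)

```python
import urllib.parse

BASE = "https://bedrock-runtime.us-west-2.amazonaws.com"
ARN = "arn:aws:bedrock:us-west-2:888602223428:provisioned-model/8fxff74qyhs3"

# Pre-patch behavior: the raw ARN lands in the path, so its '/' creates
# an extra path segment and the ':' characters are sent unescaped.
broken_url = f"{BASE}/model/{ARN}/invoke"

# Post-patch behavior, mirroring encode_model_id: quote with safe="" so
# every reserved character is escaped (':' -> %3A, '/' -> %2F).
encoded = urllib.parse.quote(ARN, safe="")
fixed_url = f"{BASE}/model/{encoded}/invoke"

# Matches the URL asserted in test_bedrock_ptu above.
assert fixed_url == (
    "https://bedrock-runtime.us-west-2.amazonaws.com/model/"
    "arn%3Aaws%3Abedrock%3Aus-west-2%3A888602223428%3A"
    "provisioned-model%2F8fxff74qyhs3/invoke"
)
print(fixed_url)
```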
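
The `safe=""` argument is the load-bearing detail: `urllib.parse.quote` defaults to `safe="/"`, which would leave the ARN's slash intact and reintroduce the broken path. A small demonstration of the difference:

```python
import urllib.parse

arn = "arn:aws:bedrock:us-west-2:888602223428:provisioned-model/8fxff74qyhs3"

# Default safe="/" escapes ':' but leaves '/' alone -- still two path segments:
print(urllib.parse.quote(arn))
# -> arn%3Aaws%3Abedrock%3Aus-west-2%3A888602223428%3Aprovisioned-model/8fxff74qyhs3

# safe="" (as used in the patch) escapes the slash too -- one opaque segment:
print(urllib.parse.quote(arn, safe=""))
# -> arn%3Aaws%3Abedrock%3Aus-west-2%3A888602223428%3Aprovisioned-model%2F8fxff74qyhs3
```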