fix(bedrock_httpx.py): fix bedrock ptu model id str encoding

Fixes https://github.com/BerriAI/litellm/issues/3805
Krrish Dholakia 2024-05-25 10:53:27 -07:00
parent 81ca145259
commit d2e14ca833
3 changed files with 42 additions and 8 deletions

litellm/llms/bedrock_httpx.py

@@ -44,6 +44,7 @@ from .base import BaseLLM
 import httpx  # type: ignore
 from .bedrock import BedrockError, convert_messages_to_prompt, ModelResponseIterator
 from litellm.types.llms.bedrock import *
+import urllib.parse


 class AmazonCohereChatConfig:
@@ -524,6 +525,16 @@ class BedrockLLM(BaseLLM):
         return model_response

+    def encode_model_id(self, model_id: str) -> str:
+        """
+        Double encode the model ID to ensure it matches the expected double-encoded format.
+        Args:
+            model_id (str): The model ID to encode.
+        Returns:
+            str: The double-encoded model ID.
+        """
+        return urllib.parse.quote(model_id, safe="")
+
     def completion(
         self,
         model: str,
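For reference, urllib.parse.quote with safe="" percent-encodes every reserved character, including ':' and '/'; the docstring's "double encode" wording presumably refers to AWS SigV4 signing, which URI-encodes each path segment a second time for services other than S3. A minimal sketch, using the provisioned-throughput ARN from the test below:

import urllib.parse

# ':' becomes %3A and '/' becomes %2F, so the ARN cannot add path segments.
arn = "arn:aws:bedrock:us-west-2:888602223428:provisioned-model/8fxff74qyhs3"
print(urllib.parse.quote(arn, safe=""))
# arn%3Aaws%3Abedrock%3Aus-west-2%3A888602223428%3Aprovisioned-model%2F8fxff74qyhs3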
@@ -552,7 +563,12 @@
         ## SETUP ##
         stream = optional_params.pop("stream", None)
-        modelId = optional_params.pop("model_id", None) or model
+        modelId = optional_params.pop("model_id", None)
+        if modelId is not None:
+            modelId = self.encode_model_id(model_id=modelId)
+        else:
+            modelId = model
+
         provider = model.split(".")[0]

         ## CREDENTIALS ##
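This is the bug from issue #3805: the old expression optional_params.pop("model_id", None) or model dropped a provisioned-throughput ARN into the request path unencoded. A sketch of the effect on the invoke URL (endpoint shape taken from the test assertion below):

import urllib.parse

model_id = "arn:aws:bedrock:us-west-2:888602223428:provisioned-model/8fxff74qyhs3"

# Old behavior: the raw ARN's '/' splits the route into a bogus extra segment.
bad_url = f"https://bedrock-runtime.us-west-2.amazonaws.com/model/{model_id}/invoke"

# New behavior: the percent-encoded ARN stays a single path segment.
good_url = (
    "https://bedrock-runtime.us-west-2.amazonaws.com/model/"
    f"{urllib.parse.quote(model_id, safe='')}/invoke"
)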

litellm/main.py

@@ -2099,6 +2099,7 @@ def completion(
                 extra_headers=extra_headers,
                 timeout=timeout,
                 acompletion=acompletion,
+                client=client,
             )
             if optional_params.get("stream", False):
                 ## LOGGING
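This one-line change forwards the caller's HTTP client into the Bedrock handler, which is what lets the updated test intercept the outgoing request. A usage sketch mirroring the test's call:

import litellm
from litellm.llms.custom_httpx.http_handler import HTTPHandler

# litellm.completion now passes this client through to the bedrock_httpx
# handler, so a caller (or a test) controls the transport for the invoke call.
client = HTTPHandler()
response = litellm.completion(
    model="bedrock/anthropic.claude-instant-v1",
    messages=[{"role": "user", "content": "What's AWS?"}],
    model_id="arn:aws:bedrock:us-west-2:888602223428:provisioned-model/8fxff74qyhs3",
    client=client,
)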

litellm/tests/test_completion.py

@@ -13,6 +13,8 @@ import pytest
 import litellm
 from litellm import embedding, completion, completion_cost, Timeout, ModelResponse
 from litellm import RateLimitError
+from litellm.llms.custom_httpx.http_handler import HTTPHandler
+from unittest.mock import patch, AsyncMock, Mock

 # litellm.num_retries = 3
 litellm.cache = None
@@ -509,13 +511,28 @@ def test_bedrock_ptu():
     Reference: https://github.com/BerriAI/litellm/issues/3805
     """
-    litellm.set_verbose = True
-    from openai.types.chat import ChatCompletion
-
-    response = litellm.completion(
-        model="bedrock/amazon.my-incorrect-model",
-        messages=[{"role": "user", "content": "What's AWS?"}],
-        model_id="amazon.titan-text-lite-v1",
-    )
-    ChatCompletion.model_validate(response.model_dump(), strict=True)
+    client = HTTPHandler()
+
+    with patch.object(client, "post", new=Mock()) as mock_client_post:
+        litellm.set_verbose = True
+        from openai.types.chat import ChatCompletion
+
+        model_id = (
+            "arn:aws:bedrock:us-west-2:888602223428:provisioned-model/8fxff74qyhs3"
+        )
+        try:
+            response = litellm.completion(
+                model="bedrock/anthropic.claude-instant-v1",
+                messages=[{"role": "user", "content": "What's AWS?"}],
+                model_id=model_id,
+                client=client,
+            )
+        except Exception as e:
+            pass
+
+        assert "url" in mock_client_post.call_args.kwargs
+        assert (
+            mock_client_post.call_args.kwargs["url"]
+            == "https://bedrock-runtime.us-west-2.amazonaws.com/model/arn%3Aaws%3Abedrock%3Aus-west-2%3A888602223428%3Aprovisioned-model%2F8fxff74qyhs3/invoke"
+        )
+        mock_client_post.assert_called_once()
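Note that the mocked post returns a bare Mock, so response handling raises and the try/except deliberately swallows it; the assertions on mock_client_post are the actual test. To run it in isolation (file path per the diff header above):

pytest litellm/tests/test_completion.py::test_bedrock_ptu -x -s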