forked from phoenix/litellm-mirror
fix(bedrock_httpx.py): fix bedrock ptu model id str encoding
Fixes https://github.com/BerriAI/litellm/issues/3805
This commit is contained in:
parent 81ca145259
commit d2e14ca833
3 changed files with 42 additions and 8 deletions
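Background on the fix: Bedrock provisioned-throughput (PTU) deployments are addressed by a full provisioned-model ARN, and an ARN contains ':' and '/' characters. Per the linked issue, passing such an id straight into the /model/{modelId}/invoke path breaks the request URL, so this commit percent-encodes the id first. A minimal sketch of that encoding, reusing the ARN from the test added below (the value is only an example):

import urllib.parse

# Example PTU model id (taken from the new test below): a full provisioned-model ARN.
model_id = "arn:aws:bedrock:us-west-2:888602223428:provisioned-model/8fxff74qyhs3"

# safe="" encodes every reserved character, including '/' (which quote() would
# otherwise leave alone), so the ARN stays a single path segment of the URL.
print(urllib.parse.quote(model_id, safe=""))
# arn%3Aaws%3Abedrock%3Aus-west-2%3A888602223428%3Aprovisioned-model%2F8fxff74qyhs3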
@@ -44,6 +44,7 @@ from .base import BaseLLM
 import httpx  # type: ignore
 from .bedrock import BedrockError, convert_messages_to_prompt, ModelResponseIterator
 from litellm.types.llms.bedrock import *
+import urllib.parse
 
 
 class AmazonCohereChatConfig:
@@ -524,6 +525,16 @@ class BedrockLLM(BaseLLM):
 
         return model_response
 
+    def encode_model_id(self, model_id: str) -> str:
+        """
+        Double encode the model ID to ensure it matches the expected double-encoded format.
+        Args:
+            model_id (str): The model ID to encode.
+        Returns:
+            str: The double-encoded model ID.
+        """
+        return urllib.parse.quote(model_id, safe="")
+
     def completion(
         self,
         model: str,
@@ -552,7 +563,12 @@ class BedrockLLM(BaseLLM):
 
         ## SETUP ##
         stream = optional_params.pop("stream", None)
-        modelId = optional_params.pop("model_id", None) or model
+        modelId = optional_params.pop("model_id", None)
+        if modelId is not None:
+            modelId = self.encode_model_id(model_id=modelId)
+        else:
+            modelId = model
+
         provider = model.split(".")[0]
 
         ## CREDENTIALS ##
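From the caller's side, model and model_id now play distinct roles: the provider is still derived from model (provider = model.split(".")[0] above), while an optional model_id, typically a PTU ARN, is what gets percent-encoded into the request URL. A hedged usage sketch, with values taken from the test added in this commit (an actual call still needs valid AWS credentials):

import litellm

# "model" selects the bedrock provider route; "model_id" carries the PTU ARN that
# ends up percent-encoded in the invoke URL. Values are copied from the new test.
response = litellm.completion(
    model="bedrock/anthropic.claude-instant-v1",
    model_id="arn:aws:bedrock:us-west-2:888602223428:provisioned-model/8fxff74qyhs3",
    messages=[{"role": "user", "content": "What's AWS?"}],
)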
@@ -2099,6 +2099,7 @@ def completion(
                 extra_headers=extra_headers,
                 timeout=timeout,
                 acompletion=acompletion,
+                client=client,
             )
             if optional_params.get("stream", False):
                 ## LOGGING
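This hunk is in the top-level completion() entrypoint and simply forwards the caller-supplied client down to the Bedrock handler. The apparent motivation is testability: the new test injects an HTTPHandler whose post() is mocked, so the outgoing URL can be inspected without contacting AWS. A minimal sketch of that injection pattern (the completion call itself is elided):

from unittest.mock import Mock, patch

from litellm.llms.custom_httpx.http_handler import HTTPHandler

client = HTTPHandler()

# Replace post() so no real HTTP request leaves the process; call
# litellm.completion(..., client=client) inside the block, then inspect
# mock_post.call_args.kwargs["url"].
with patch.object(client, "post", new=Mock()) as mock_post:
    pass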
@@ -13,6 +13,8 @@ import pytest
 import litellm
 from litellm import embedding, completion, completion_cost, Timeout, ModelResponse
 from litellm import RateLimitError
+from litellm.llms.custom_httpx.http_handler import HTTPHandler
+from unittest.mock import patch, AsyncMock, Mock
 
 # litellm.num_retries = 3
 litellm.cache = None
@@ -509,13 +511,28 @@ def test_bedrock_ptu():
 
     Reference: https://github.com/BerriAI/litellm/issues/3805
     """
-    from openai.types.chat import ChatCompletion
+    client = HTTPHandler()
 
-    response = litellm.completion(
-        model="bedrock/amazon.my-incorrect-model",
-        messages=[{"role": "user", "content": "What's AWS?"}],
-        model_id="amazon.titan-text-lite-v1",
-    )
+    with patch.object(client, "post", new=Mock()) as mock_client_post:
+        litellm.set_verbose = True
+        from openai.types.chat import ChatCompletion
+        model_id = (
+            "arn:aws:bedrock:us-west-2:888602223428:provisioned-model/8fxff74qyhs3"
+        )
+        try:
+            response = litellm.completion(
+                model="bedrock/anthropic.claude-instant-v1",
+                messages=[{"role": "user", "content": "What's AWS?"}],
+                model_id=model_id,
+                client=client,
+            )
+        except Exception as e:
+            pass
 
-    ChatCompletion.model_validate(response.model_dump(), strict=True)
+        assert "url" in mock_client_post.call_args.kwargs
+        assert (
+            mock_client_post.call_args.kwargs["url"]
+            == "https://bedrock-runtime.us-west-2.amazonaws.com/model/arn%3Aaws%3Abedrock%3Aus-west-2%3A888602223428%3Aprovisioned-model%2F8fxff74qyhs3/invoke"
+        )
+        mock_client_post.assert_called_once()
 
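The asserted URL can be sanity-checked on its own: it is the regional bedrock-runtime invoke endpoint with the percent-encoded ARN as the model path segment (the endpoint shape here is read off the assertion above, not from the implementation):

import urllib.parse

arn = "arn:aws:bedrock:us-west-2:888602223428:provisioned-model/8fxff74qyhs3"
expected_url = (
    "https://bedrock-runtime.us-west-2.amazonaws.com/model/"
    + urllib.parse.quote(arn, safe="")
    + "/invoke"
)
# Prints the exact URL asserted in the test above.
print(expected_url)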