Merge pull request #9123 from omrishiv/8911-fix-model-encoding

Fixes bedrock modelId encoding for Inference Profiles
This commit is contained in:
Krish Dholakia 2025-03-13 10:42:32 -07:00 committed by GitHub
commit cb7cbdff8f
3 changed files with 23 additions and 3 deletions

View file

@ -274,7 +274,7 @@ class BedrockConverseLLM(BaseAWSLLM):
if modelId is not None:
modelId = self.encode_model_id(model_id=modelId)
else:
modelId = model
modelId = self.encode_model_id(model_id=model)
if stream is True and "ai21" in modelId:
fake_stream = True

View file

@ -0,0 +1,20 @@
import os
import sys
from litellm.llms.bedrock.chat import BedrockConverseLLM
sys.path.insert(
0, os.path.abspath("../../../../..")
) # Adds the parent directory to the system path
import litellm
def test_encode_model_id_with_inference_profile():
"""
Test instance profile is properly encoded when used as a model
"""
test_model = "arn:aws:bedrock:us-east-1:12345678910:application-inference-profile/ujdtmcirjhevpi"
expected_model = "arn%3Aaws%3Abedrock%3Aus-east-1%3A12345678910%3Aapplication-inference-profile%2Fujdtmcirjhevpi"
bedrock_converse_llm = BedrockConverseLLM()
returned_model = bedrock_converse_llm.encode_model_id(test_model)
assert expected_model == returned_model

View file

@ -983,7 +983,7 @@ async def test_bedrock_custom_api_base():
print(f"mock_client_post.call_args.kwargs: {mock_client_post.call_args.kwargs}")
assert (
mock_client_post.call_args.kwargs["url"]
== "https://gateway.ai.cloudflare.com/v1/fa4cdcab1f32b95ca3b53fd36043d691/test/aws-bedrock/bedrock-runtime/us-east-1/model/anthropic.claude-3-sonnet-20240229-v1:0/converse"
== "https://gateway.ai.cloudflare.com/v1/fa4cdcab1f32b95ca3b53fd36043d691/test/aws-bedrock/bedrock-runtime/us-east-1/model/anthropic.claude-3-sonnet-20240229-v1%3A0/converse"
)
assert "test" in mock_client_post.call_args.kwargs["headers"]
assert mock_client_post.call_args.kwargs["headers"]["test"] == "hello world"
@ -2382,7 +2382,7 @@ def test_bedrock_cross_region_inference(monkeypatch):
assert (
mock_post.call_args.kwargs["url"]
== "https://bedrock-runtime.us-west-2.amazonaws.com/model/us.meta.llama3-3-70b-instruct-v1:0/converse"
== "https://bedrock-runtime.us-west-2.amazonaws.com/model/us.meta.llama3-3-70b-instruct-v1%3A0/converse"
)