mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
fix: AWS Bedrock inference profile ID conversion for region-specific endpoints (#3386)
Fixes #3370 AWS switched to requiring region-prefixed inference profile IDs instead of foundation model IDs for on-demand throughput. This was causing ValidationException errors. Added auto-detection based on boto3 client region to convert model IDs like meta.llama3-1-70b-instruct-v1:0 to us.meta.llama3-1-70b-instruct-v1:0 depending on the detected region. Also handles edge cases like ARNs, case insensitive regions, and None regions. Tested with this request. ```json { "model_id": "meta.llama3-1-8b-instruct-v1:0", "messages": [ { "role": "system", "content": "You are a helpful assistant." }, { "role": "user", "content": "tell me a riddle" } ], "sampling_params": { "strategy": { "type": "top_p", "temperature": 0.7, "top_p": 0.9 }, "max_tokens": 512 } } ``` <img width="1488" height="878" alt="image" src="https://github.com/user-attachments/assets/0d61beec-3869-4a31-8f37-9f554c280b88" />
This commit is contained in:
parent
8e05c68d15
commit
2838d5a20f
2 changed files with 102 additions and 2 deletions
|
@ -53,6 +53,43 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||||
|
|
||||||
from .models import MODEL_ENTRIES
|
from .models import MODEL_ENTRIES
|
||||||
|
|
||||||
|
REGION_PREFIX_MAP = {
|
||||||
|
"us": "us.",
|
||||||
|
"eu": "eu.",
|
||||||
|
"ap": "ap.",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _get_region_prefix(region: str | None) -> str:
|
||||||
|
# AWS requires region prefixes for inference profiles
|
||||||
|
if region is None:
|
||||||
|
return "us." # default to US when we don't know
|
||||||
|
|
||||||
|
# Handle case insensitive region matching
|
||||||
|
region_lower = region.lower()
|
||||||
|
for prefix in REGION_PREFIX_MAP:
|
||||||
|
if region_lower.startswith(f"{prefix}-"):
|
||||||
|
return REGION_PREFIX_MAP[prefix]
|
||||||
|
|
||||||
|
# Fallback to US for anything we don't recognize
|
||||||
|
return "us."
|
||||||
|
|
||||||
|
|
||||||
|
def _to_inference_profile_id(model_id: str, region: str = None) -> str:
|
||||||
|
# Return ARNs unchanged
|
||||||
|
if model_id.startswith("arn:"):
|
||||||
|
return model_id
|
||||||
|
|
||||||
|
# Return inference profile IDs that already have regional prefixes
|
||||||
|
if any(model_id.startswith(p) for p in REGION_PREFIX_MAP.values()):
|
||||||
|
return model_id
|
||||||
|
|
||||||
|
# Default to US East when no region is provided
|
||||||
|
if region is None:
|
||||||
|
region = "us-east-1"
|
||||||
|
|
||||||
|
return _get_region_prefix(region) + model_id
|
||||||
|
|
||||||
|
|
||||||
class BedrockInferenceAdapter(
|
class BedrockInferenceAdapter(
|
||||||
ModelRegistryHelper,
|
ModelRegistryHelper,
|
||||||
|
@ -166,8 +203,13 @@ class BedrockInferenceAdapter(
|
||||||
options["repetition_penalty"] = sampling_params.repetition_penalty
|
options["repetition_penalty"] = sampling_params.repetition_penalty
|
||||||
|
|
||||||
prompt = await chat_completion_request_to_prompt(request, self.get_llama_model(request.model))
|
prompt = await chat_completion_request_to_prompt(request, self.get_llama_model(request.model))
|
||||||
|
|
||||||
|
# Convert foundation model ID to inference profile ID
|
||||||
|
region_name = self.client.meta.region_name
|
||||||
|
inference_profile_id = _to_inference_profile_id(bedrock_model, region_name)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"modelId": bedrock_model,
|
"modelId": inference_profile_id,
|
||||||
"body": json.dumps(
|
"body": json.dumps(
|
||||||
{
|
{
|
||||||
"prompt": prompt,
|
"prompt": prompt,
|
||||||
|
@ -185,6 +227,11 @@ class BedrockInferenceAdapter(
|
||||||
task_type: EmbeddingTaskType | None = None,
|
task_type: EmbeddingTaskType | None = None,
|
||||||
) -> EmbeddingsResponse:
|
) -> EmbeddingsResponse:
|
||||||
model = await self.model_store.get_model(model_id)
|
model = await self.model_store.get_model(model_id)
|
||||||
|
|
||||||
|
# Convert foundation model ID to inference profile ID
|
||||||
|
region_name = self.client.meta.region_name
|
||||||
|
inference_profile_id = _to_inference_profile_id(model.provider_resource_id, region_name)
|
||||||
|
|
||||||
embeddings = []
|
embeddings = []
|
||||||
for content in contents:
|
for content in contents:
|
||||||
assert not content_has_media(content), "Bedrock does not support media for embeddings"
|
assert not content_has_media(content), "Bedrock does not support media for embeddings"
|
||||||
|
@ -193,7 +240,7 @@ class BedrockInferenceAdapter(
|
||||||
body = json.dumps(input_body)
|
body = json.dumps(input_body)
|
||||||
response = self.client.invoke_model(
|
response = self.client.invoke_model(
|
||||||
body=body,
|
body=body,
|
||||||
modelId=model.provider_resource_id,
|
modelId=inference_profile_id,
|
||||||
accept="application/json",
|
accept="application/json",
|
||||||
contentType="application/json",
|
contentType="application/json",
|
||||||
)
|
)
|
||||||
|
|
53
tests/unit/providers/test_bedrock.py
Normal file
53
tests/unit/providers/test_bedrock.py
Normal file
|
@ -0,0 +1,53 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from llama_stack.providers.remote.inference.bedrock.bedrock import (
|
||||||
|
_get_region_prefix,
|
||||||
|
_to_inference_profile_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_region_prefixes():
|
||||||
|
assert _get_region_prefix("us-east-1") == "us."
|
||||||
|
assert _get_region_prefix("eu-west-1") == "eu."
|
||||||
|
assert _get_region_prefix("ap-south-1") == "ap."
|
||||||
|
assert _get_region_prefix("ca-central-1") == "us."
|
||||||
|
|
||||||
|
# Test case insensitive
|
||||||
|
assert _get_region_prefix("US-EAST-1") == "us."
|
||||||
|
assert _get_region_prefix("EU-WEST-1") == "eu."
|
||||||
|
assert _get_region_prefix("Ap-South-1") == "ap."
|
||||||
|
|
||||||
|
# Test None region
|
||||||
|
assert _get_region_prefix(None) == "us."
|
||||||
|
|
||||||
|
|
||||||
|
def test_model_id_conversion():
|
||||||
|
# Basic conversion
|
||||||
|
assert (
|
||||||
|
_to_inference_profile_id("meta.llama3-1-70b-instruct-v1:0", "us-east-1") == "us.meta.llama3-1-70b-instruct-v1:0"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Already has prefix
|
||||||
|
assert (
|
||||||
|
_to_inference_profile_id("us.meta.llama3-1-70b-instruct-v1:0", "us-east-1")
|
||||||
|
== "us.meta.llama3-1-70b-instruct-v1:0"
|
||||||
|
)
|
||||||
|
|
||||||
|
# ARN should be returned unchanged
|
||||||
|
arn = "arn:aws:bedrock:us-east-1:123456789012:inference-profile/us.meta.llama3-1-70b-instruct-v1:0"
|
||||||
|
assert _to_inference_profile_id(arn, "us-east-1") == arn
|
||||||
|
|
||||||
|
# ARN should be returned unchanged even without region
|
||||||
|
assert _to_inference_profile_id(arn) == arn
|
||||||
|
|
||||||
|
# Optional region parameter defaults to us-east-1
|
||||||
|
assert _to_inference_profile_id("meta.llama3-1-70b-instruct-v1:0") == "us.meta.llama3-1-70b-instruct-v1:0"
|
||||||
|
|
||||||
|
# Different regions work with optional parameter
|
||||||
|
assert (
|
||||||
|
_to_inference_profile_id("meta.llama3-1-70b-instruct-v1:0", "eu-west-1") == "eu.meta.llama3-1-70b-instruct-v1:0"
|
||||||
|
)
|
Loading…
Add table
Add a link
Reference in a new issue