refereshable boto credentials

This commit is contained in:
Dinesh Yeduguru 2024-11-06 06:59:55 -08:00
parent 7d28dc380e
commit 6697ca3d3a
5 changed files with 171 additions and 44 deletions

View file

@ -18,6 +18,7 @@ from llama_stack.providers.utils.inference.model_registry import ModelRegistryHe
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.adapters.inference.bedrock.config import BedrockConfig
from llama_stack.providers.utils.refreshable_boto_session import RefreshableBotoSession
BEDROCK_SUPPORTED_MODELS = {
@ -441,38 +442,47 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
def _create_bedrock_client(config: BedrockConfig) -> BaseClient:
retries_config = {
k: v
for k, v in dict(
total_max_attempts=config.total_max_attempts,
mode=config.retry_mode,
).items()
if v is not None
}
if config.aws_access_key_id and config.aws_secret_access_key:
retries_config = {
k: v
for k, v in dict(
total_max_attempts=config.total_max_attempts,
mode=config.retry_mode,
).items()
if v is not None
}
config_args = {
k: v
for k, v in dict(
region_name=config.region_name,
retries=retries_config if retries_config else None,
connect_timeout=config.connect_timeout,
read_timeout=config.read_timeout,
).items()
if v is not None
}
config_args = {
k: v
for k, v in dict(
region_name=config.region_name,
retries=retries_config if retries_config else None,
connect_timeout=config.connect_timeout,
read_timeout=config.read_timeout,
).items()
if v is not None
}
boto3_config = Config(**config_args)
boto3_config = Config(**config_args)
session_args = {
"aws_access_key_id": config.aws_access_key_id,
"aws_secret_access_key": config.aws_secret_access_key,
"aws_session_token": config.aws_session_token,
"region_name": config.region_name,
"profile_name": config.profile_name,
}
session_args = {
"aws_access_key_id": config.aws_access_key_id,
"aws_secret_access_key": config.aws_secret_access_key,
"aws_session_token": config.aws_session_token,
"region_name": config.region_name,
"profile_name": config.profile_name,
}
# Remove None values
session_args = {k: v for k, v in session_args.items() if v is not None}
# Remove None values
session_args = {k: v for k, v in session_args.items() if v is not None}
boto3_session = boto3.session.Session(**session_args)
return boto3_session.client("bedrock-runtime", config=boto3_config)
boto3_session = boto3.session.Session(**session_args)
return boto3_session.client("bedrock-runtime", config=boto3_config)
else:
return (
RefreshableBotoSession(
region_name=config.region_name, profile_name=config.profile_name
)
.refreshable_session()
.client("bedrock-runtime")
)

View file

@ -14,6 +14,7 @@ import boto3
from llama_stack.apis.safety import * # noqa
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from llama_stack.providers.utils.refreshable_boto_session import RefreshableBotoSession
from .config import BedrockSafetyConfig
@ -27,19 +28,27 @@ BEDROCK_SUPPORTED_SHIELDS = [
def _create_bedrock_client(config: BedrockSafetyConfig, name: str):
session_args = {
"aws_access_key_id": config.aws_access_key_id,
"aws_secret_access_key": config.aws_secret_access_key,
"aws_session_token": config.aws_session_token,
"region_name": config.region_name,
"profile_name": config.profile_name,
}
if config.aws_access_key_id and config.aws_secret_access_key:
session_args = {
"aws_access_key_id": config.aws_access_key_id,
"aws_secret_access_key": config.aws_secret_access_key,
"aws_session_token": config.aws_session_token,
"region_name": config.region_name,
"profile_name": config.profile_name,
}
# Remove None values
session_args = {k: v for k, v in session_args.items() if v is not None}
# Remove None values
session_args = {k: v for k, v in session_args.items() if v is not None}
boto3_session = boto3.session.Session(**session_args)
return boto3_session.client(name)
boto3_session = boto3.session.Session(**session_args)
return boto3_session.client(name)
else:
return (
RefreshableBotoSession(
region_name=config.region_name, profile_name=config.profile_name
)
.refreshable_session()
.client(name)
)
class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):