mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-17 07:07:19 +00:00
working bedrock safety
This commit is contained in:
parent
a4fd91fe51
commit
df76c9b484
7 changed files with 63 additions and 44 deletions
|
@ -16,6 +16,8 @@ from llama_models.llama3.api.tokenizer import Tokenizer
|
|||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
import os
|
||||
|
||||
from llama_stack.providers.adapters.inference.bedrock.config import BedrockConfig
|
||||
|
||||
|
||||
|
@ -443,8 +445,10 @@ def _create_bedrock_client(config: BedrockConfig) -> BaseClient:
|
|||
retries_config = {
|
||||
k: v
|
||||
for k, v in dict(
|
||||
total_max_attempts=config.total_max_attempts,
|
||||
mode=config.retry_mode,
|
||||
total_max_attempts=os.environ.get(
|
||||
"AWS_MAX_ATTEMPTS", config.total_max_attempts
|
||||
),
|
||||
mode=os.environ.get("AWS_RETRY_MODE", config.retry_mode),
|
||||
).items()
|
||||
if v is not None
|
||||
}
|
||||
|
@ -452,7 +456,7 @@ def _create_bedrock_client(config: BedrockConfig) -> BaseClient:
|
|||
config_args = {
|
||||
k: v
|
||||
for k, v in dict(
|
||||
region_name=config.region_name,
|
||||
region_name=os.environ.get("AWS_DEFAULT_REGION", config.region_name),
|
||||
retries=retries_config if retries_config else None,
|
||||
connect_timeout=config.connect_timeout,
|
||||
read_timeout=config.read_timeout,
|
||||
|
@ -463,17 +467,21 @@ def _create_bedrock_client(config: BedrockConfig) -> BaseClient:
|
|||
boto3_config = Config(**config_args)
|
||||
|
||||
session_args = {
|
||||
k: v
|
||||
for k, v in dict(
|
||||
aws_access_key_id=config.aws_access_key_id,
|
||||
aws_secret_access_key=config.aws_secret_access_key,
|
||||
aws_session_token=config.aws_session_token,
|
||||
region_name=config.region_name,
|
||||
profile_name=config.profile_name,
|
||||
).items()
|
||||
if v is not None
|
||||
"aws_access_key_id": os.environ.get(
|
||||
"AWS_ACCESS_KEY_ID", config.aws_access_key_id
|
||||
),
|
||||
"aws_secret_access_key": os.environ.get(
|
||||
"AWS_SECRET_ACCESS_KEY", config.aws_secret_access_key
|
||||
),
|
||||
"aws_session_token": os.environ.get(
|
||||
"AWS_SESSION_TOKEN", config.aws_session_token
|
||||
),
|
||||
"region_name": os.environ.get("AWS_DEFAULT_REGION", config.region_name),
|
||||
"profile_name": os.environ.get("AWS_PROFILE", config.profile_name),
|
||||
}
|
||||
|
||||
boto3_session = boto3.session.Session(**session_args)
|
||||
# Remove None values
|
||||
session_args = {k: v for k, v in session_args.items() if v is not None}
|
||||
|
||||
boto3_session = boto3.session.Session(**session_args)
|
||||
return boto3_session.client("bedrock-runtime", config=boto3_config)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue