address raghu's feedback

This commit is contained in:
Dinesh Yeduguru 2024-11-06 14:16:37 -08:00
parent 2101cb08c7
commit c649bd9bdf
7 changed files with 144 additions and 154 deletions

View file

@ -0,0 +1,73 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import boto3
from botocore.client import BaseClient
from botocore.config import Config
from llama_stack.providers.utils.bedrock.config import BedrockBaseConfig
from llama_stack.providers.utils.bedrock.refreshable_boto_session import (
RefreshableBotoSession,
)
def create_bedrock_client(
config: BedrockBaseConfig, service_name: str = "bedrock-runtime"
) -> BaseClient:
"""Creates a boto3 client for Bedrock services with the given configuration.
Args:
config: The Bedrock configuration containing AWS credentials and settings
service_name: The AWS service name to create client for (default: "bedrock-runtime")
Returns:
A configured boto3 client
"""
if config.aws_access_key_id and config.aws_secret_access_key:
retries_config = {
k: v
for k, v in dict(
total_max_attempts=config.total_max_attempts,
mode=config.retry_mode,
).items()
if v is not None
}
config_args = {
k: v
for k, v in dict(
region_name=config.region_name,
retries=retries_config if retries_config else None,
connect_timeout=config.connect_timeout,
read_timeout=config.read_timeout,
).items()
if v is not None
}
boto3_config = Config(**config_args)
session_args = {
"aws_access_key_id": config.aws_access_key_id,
"aws_secret_access_key": config.aws_secret_access_key,
"aws_session_token": config.aws_session_token,
"region_name": config.region_name,
"profile_name": config.profile_name,
}
# Remove None values
session_args = {k: v for k, v in session_args.items() if v is not None}
boto3_session = boto3.session.Session(**session_args)
return boto3_session.client(service_name, config=boto3_config)
else:
return (
RefreshableBotoSession(
region_name=config.region_name, profile_name=config.profile_name
)
.refreshable_session()
.client(service_name)
)

View file

@ -0,0 +1,55 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Optional
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
@json_schema_type
class BedrockBaseConfig(BaseModel):
aws_access_key_id: Optional[str] = Field(
default=None,
description="The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID",
)
aws_secret_access_key: Optional[str] = Field(
default=None,
description="The AWS secret access key to use. Default use environment variable: AWS_SECRET_ACCESS_KEY",
)
aws_session_token: Optional[str] = Field(
default=None,
description="The AWS session token to use. Default use environment variable: AWS_SESSION_TOKEN",
)
region_name: Optional[str] = Field(
default=None,
description="The default AWS Region to use, for example, us-west-1 or us-west-2."
"Default use environment variable: AWS_DEFAULT_REGION",
)
profile_name: Optional[str] = Field(
default=None,
description="The profile name that contains credentials to use."
"Default use environment variable: AWS_PROFILE",
)
total_max_attempts: Optional[int] = Field(
default=None,
description="An integer representing the maximum number of attempts that will be made for a single request, "
"including the initial attempt. Default use environment variable: AWS_MAX_ATTEMPTS",
)
retry_mode: Optional[str] = Field(
default=None,
description="A string representing the type of retries Boto3 will perform."
"Default use environment variable: AWS_RETRY_MODE",
)
connect_timeout: Optional[float] = Field(
default=60,
description="The time in seconds till a timeout exception is thrown when attempting to make a connection. "
"The default is 60 seconds.",
)
read_timeout: Optional[float] = Field(
default=60,
description="The time in seconds till a timeout exception is thrown when attempting to read from a connection."
"The default is 60 seconds.",
)

View file

@ -31,7 +31,7 @@ class RefreshableBotoSession:
profile_name: str = None,
sts_arn: str = None,
session_name: str = None,
session_ttl: int = 3000,
session_ttl: int = 30000,
):
"""
Initialize `RefreshableBotoSession`