From c649bd9bdf84400a5a215584ddb6aea85d80e50b Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Wed, 6 Nov 2024 14:16:37 -0800 Subject: [PATCH] address raghu's feedback --- .../adapters/inference/bedrock/bedrock.py | 53 +------------- .../adapters/inference/bedrock/config.py | 49 +------------ .../adapters/safety/bedrock/bedrock.py | 34 +-------- .../adapters/safety/bedrock/config.py | 32 ++------ llama_stack/providers/utils/bedrock/client.py | 73 +++++++++++++++++++ llama_stack/providers/utils/bedrock/config.py | 55 ++++++++++++++ .../{ => bedrock}/refreshable_boto_session.py | 2 +- 7 files changed, 144 insertions(+), 154 deletions(-) create mode 100644 llama_stack/providers/utils/bedrock/client.py create mode 100644 llama_stack/providers/utils/bedrock/config.py rename llama_stack/providers/utils/{ => bedrock}/refreshable_boto_session.py (99%) diff --git a/llama_stack/providers/adapters/inference/bedrock/bedrock.py b/llama_stack/providers/adapters/inference/bedrock/bedrock.py index cc149084a..87b374de1 100644 --- a/llama_stack/providers/adapters/inference/bedrock/bedrock.py +++ b/llama_stack/providers/adapters/inference/bedrock/bedrock.py @@ -6,9 +6,7 @@ from typing import * # noqa: F403 -import boto3 from botocore.client import BaseClient -from botocore.config import Config from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.tokenizer import Tokenizer @@ -18,7 +16,7 @@ from llama_stack.providers.utils.inference.model_registry import ModelRegistryHe from llama_stack.apis.inference import * # noqa: F403 from llama_stack.providers.adapters.inference.bedrock.config import BedrockConfig -from llama_stack.providers.utils.refreshable_boto_session import RefreshableBotoSession +from llama_stack.providers.utils.bedrock.client import create_bedrock_client BEDROCK_SUPPORTED_MODELS = { @@ -36,7 +34,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference): ) self._config = config - self._client = _create_bedrock_client(config) + self._client = create_bedrock_client(config) self.formatter = ChatFormat(Tokenizer.get_instance()) @property @@ -439,50 +437,3 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference): contents: List[InterleavedTextMedia], ) -> EmbeddingsResponse: raise NotImplementedError() - - -def _create_bedrock_client(config: BedrockConfig) -> BaseClient: - if config.aws_access_key_id and config.aws_secret_access_key: - retries_config = { - k: v - for k, v in dict( - total_max_attempts=config.total_max_attempts, - mode=config.retry_mode, - ).items() - if v is not None - } - - config_args = { - k: v - for k, v in dict( - region_name=config.region_name, - retries=retries_config if retries_config else None, - connect_timeout=config.connect_timeout, - read_timeout=config.read_timeout, - ).items() - if v is not None - } - - boto3_config = Config(**config_args) - - session_args = { - "aws_access_key_id": config.aws_access_key_id, - "aws_secret_access_key": config.aws_secret_access_key, - "aws_session_token": config.aws_session_token, - "region_name": config.region_name, - "profile_name": config.profile_name, - } - - # Remove None values - session_args = {k: v for k, v in session_args.items() if v is not None} - - boto3_session = boto3.session.Session(**session_args) - return boto3_session.client("bedrock-runtime", config=boto3_config) - else: - return ( - RefreshableBotoSession( - region_name=config.region_name, profile_name=config.profile_name - ) - .refreshable_session() - .client("bedrock-runtime") - ) diff --git a/llama_stack/providers/adapters/inference/bedrock/config.py b/llama_stack/providers/adapters/inference/bedrock/config.py index 72d2079b9..8e194700c 100644 --- a/llama_stack/providers/adapters/inference/bedrock/config.py +++ b/llama_stack/providers/adapters/inference/bedrock/config.py @@ -3,53 +3,12 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import * # noqa: F403 from llama_models.schema_utils import json_schema_type -from pydantic import BaseModel, Field + +from llama_stack.providers.utils.bedrock.config import BedrockBaseConfig @json_schema_type -class BedrockConfig(BaseModel): - aws_access_key_id: Optional[str] = Field( - default=None, - description="The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID", - ) - aws_secret_access_key: Optional[str] = Field( - default=None, - description="The AWS secret access key to use. Default use environment variable: AWS_SECRET_ACCESS_KEY", - ) - aws_session_token: Optional[str] = Field( - default=None, - description="The AWS session token to use. Default use environment variable: AWS_SESSION_TOKEN", - ) - region_name: Optional[str] = Field( - default=None, - description="The default AWS Region to use, for example, us-west-1 or us-west-2." - "Default use environment variable: AWS_DEFAULT_REGION", - ) - profile_name: Optional[str] = Field( - default=None, - description="The profile name that contains credentials to use." - "Default use environment variable: AWS_PROFILE", - ) - total_max_attempts: Optional[int] = Field( - default=None, - description="An integer representing the maximum number of attempts that will be made for a single request, " - "including the initial attempt. Default use environment variable: AWS_MAX_ATTEMPTS", - ) - retry_mode: Optional[str] = Field( - default=None, - description="A string representing the type of retries Boto3 will perform." - "Default use environment variable: AWS_RETRY_MODE", - ) - connect_timeout: Optional[float] = Field( - default=60, - description="The time in seconds till a timeout exception is thrown when attempting to make a connection. " - "The default is 60 seconds.", - ) - read_timeout: Optional[float] = Field( - default=60, - description="The time in seconds till a timeout exception is thrown when attempting to read from a connection." - "The default is 60 seconds.", - ) +class BedrockConfig(BedrockBaseConfig): + pass diff --git a/llama_stack/providers/adapters/safety/bedrock/bedrock.py b/llama_stack/providers/adapters/safety/bedrock/bedrock.py index bafc27003..e14dbd2a4 100644 --- a/llama_stack/providers/adapters/safety/bedrock/bedrock.py +++ b/llama_stack/providers/adapters/safety/bedrock/bedrock.py @@ -9,12 +9,10 @@ import logging from typing import Any, Dict, List -import boto3 - from llama_stack.apis.safety import * # noqa from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_stack.providers.datatypes import ShieldsProtocolPrivate -from llama_stack.providers.utils.refreshable_boto_session import RefreshableBotoSession +from llama_stack.providers.utils.bedrock.client import create_bedrock_client from .config import BedrockSafetyConfig @@ -27,30 +25,6 @@ BEDROCK_SUPPORTED_SHIELDS = [ ] -def _create_bedrock_client(config: BedrockSafetyConfig, name: str): - if config.aws_access_key_id and config.aws_secret_access_key: - session_args = { - "aws_access_key_id": config.aws_access_key_id, - "aws_secret_access_key": config.aws_secret_access_key, - "aws_session_token": config.aws_session_token, - "region_name": config.region_name, - "profile_name": config.profile_name, - } - # Remove None values - session_args = {k: v for k, v in session_args.items() if v is not None} - - boto3_session = boto3.session.Session(**session_args) - return boto3_session.client(name) - else: - return ( - RefreshableBotoSession( - region_name=config.region_name, profile_name=config.profile_name - ) - .refreshable_session() - .client(name) - ) - - class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate): def __init__(self, config: BedrockSafetyConfig) -> None: self.config = config @@ -58,10 +32,8 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate): async def initialize(self) -> None: try: - self.bedrock_runtime_client = _create_bedrock_client( - self.config, "bedrock-runtime" - ) - self.bedrock_client = _create_bedrock_client(self.config, "bedrock") + self.bedrock_runtime_client = create_bedrock_client(self.config) + self.bedrock_client = create_bedrock_client(self.config, "bedrock") except Exception as e: raise RuntimeError("Error initializing BedrockSafetyAdapter") from e diff --git a/llama_stack/providers/adapters/safety/bedrock/config.py b/llama_stack/providers/adapters/safety/bedrock/config.py index afa83f366..8c61decf3 100644 --- a/llama_stack/providers/adapters/safety/bedrock/config.py +++ b/llama_stack/providers/adapters/safety/bedrock/config.py @@ -4,32 +4,12 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Optional -from pydantic import BaseModel, Field +from llama_models.schema_utils import json_schema_type + +from llama_stack.providers.utils.bedrock.config import BedrockBaseConfig -class BedrockSafetyConfig(BaseModel): - """Configuration information for a guardrail that you want to use in the request.""" - - aws_access_key_id: Optional[str] = Field( - default=None, - description="The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID", - ) - aws_secret_access_key: Optional[str] = Field( - default=None, - description="The AWS secret access key to use. Default use environment variable: AWS_SECRET_ACCESS_KEY", - ) - aws_session_token: Optional[str] = Field( - default=None, - description="The AWS session token to use. Default use environment variable: AWS_SESSION_TOKEN", - ) - region_name: Optional[str] = Field( - default=None, - description="The default AWS Region to use, for example, us-west-1 or us-west-2." - "Default use environment variable: AWS_DEFAULT_REGION", - ) - profile_name: str = Field( - default="default", - description="The profile on the machine having valid aws credentials. This will ensure separation of creation to invocation", - ) +@json_schema_type +class BedrockSafetyConfig(BedrockBaseConfig): + pass diff --git a/llama_stack/providers/utils/bedrock/client.py b/llama_stack/providers/utils/bedrock/client.py new file mode 100644 index 000000000..725c248c1 --- /dev/null +++ b/llama_stack/providers/utils/bedrock/client.py @@ -0,0 +1,73 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + + +import boto3 +from botocore.client import BaseClient +from botocore.config import Config + +from llama_stack.providers.utils.bedrock.config import BedrockBaseConfig +from llama_stack.providers.utils.bedrock.refreshable_boto_session import ( + RefreshableBotoSession, +) + + +def create_bedrock_client( + config: BedrockBaseConfig, service_name: str = "bedrock-runtime" +) -> BaseClient: + """Creates a boto3 client for Bedrock services with the given configuration. + + Args: + config: The Bedrock configuration containing AWS credentials and settings + service_name: The AWS service name to create client for (default: "bedrock-runtime") + + Returns: + A configured boto3 client + """ + if config.aws_access_key_id and config.aws_secret_access_key: + retries_config = { + k: v + for k, v in dict( + total_max_attempts=config.total_max_attempts, + mode=config.retry_mode, + ).items() + if v is not None + } + + config_args = { + k: v + for k, v in dict( + region_name=config.region_name, + retries=retries_config if retries_config else None, + connect_timeout=config.connect_timeout, + read_timeout=config.read_timeout, + ).items() + if v is not None + } + + boto3_config = Config(**config_args) + + session_args = { + "aws_access_key_id": config.aws_access_key_id, + "aws_secret_access_key": config.aws_secret_access_key, + "aws_session_token": config.aws_session_token, + "region_name": config.region_name, + "profile_name": config.profile_name, + } + + # Remove None values + session_args = {k: v for k, v in session_args.items() if v is not None} + + boto3_session = boto3.session.Session(**session_args) + return boto3_session.client(service_name, config=boto3_config) + else: + return ( + RefreshableBotoSession( + region_name=config.region_name, profile_name=config.profile_name + ) + .refreshable_session() + .client(service_name) + ) diff --git a/llama_stack/providers/utils/bedrock/config.py b/llama_stack/providers/utils/bedrock/config.py new file mode 100644 index 000000000..38f1bd756 --- /dev/null +++ b/llama_stack/providers/utils/bedrock/config.py @@ -0,0 +1,55 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. +from typing import Optional + +from llama_models.schema_utils import json_schema_type +from pydantic import BaseModel, Field + + +@json_schema_type +class BedrockBaseConfig(BaseModel): + aws_access_key_id: Optional[str] = Field( + default=None, + description="The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID", + ) + aws_secret_access_key: Optional[str] = Field( + default=None, + description="The AWS secret access key to use. Default use environment variable: AWS_SECRET_ACCESS_KEY", + ) + aws_session_token: Optional[str] = Field( + default=None, + description="The AWS session token to use. Default use environment variable: AWS_SESSION_TOKEN", + ) + region_name: Optional[str] = Field( + default=None, + description="The default AWS Region to use, for example, us-west-1 or us-west-2." + "Default use environment variable: AWS_DEFAULT_REGION", + ) + profile_name: Optional[str] = Field( + default=None, + description="The profile name that contains credentials to use." + "Default use environment variable: AWS_PROFILE", + ) + total_max_attempts: Optional[int] = Field( + default=None, + description="An integer representing the maximum number of attempts that will be made for a single request, " + "including the initial attempt. Default use environment variable: AWS_MAX_ATTEMPTS", + ) + retry_mode: Optional[str] = Field( + default=None, + description="A string representing the type of retries Boto3 will perform." + "Default use environment variable: AWS_RETRY_MODE", + ) + connect_timeout: Optional[float] = Field( + default=60, + description="The time in seconds till a timeout exception is thrown when attempting to make a connection. " + "The default is 60 seconds.", + ) + read_timeout: Optional[float] = Field( + default=60, + description="The time in seconds till a timeout exception is thrown when attempting to read from a connection." + "The default is 60 seconds.", + ) diff --git a/llama_stack/providers/utils/refreshable_boto_session.py b/llama_stack/providers/utils/bedrock/refreshable_boto_session.py similarity index 99% rename from llama_stack/providers/utils/refreshable_boto_session.py rename to llama_stack/providers/utils/bedrock/refreshable_boto_session.py index a3a19dea6..fedffb618 100644 --- a/llama_stack/providers/utils/refreshable_boto_session.py +++ b/llama_stack/providers/utils/bedrock/refreshable_boto_session.py @@ -31,7 +31,7 @@ class RefreshableBotoSession: profile_name: str = None, sts_arn: str = None, session_name: str = None, - session_ttl: int = 3000, + session_ttl: int = 30000, ): """ Initialize `RefreshableBotoSession`