From 7176338ca6bfe6d2f6185e9c5a962b1c4c71bca6 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Tue, 5 Nov 2024 08:02:28 -0800 Subject: [PATCH] add bedrock shields support --- distributions/bedrock/run.yaml | 10 +++- .../remote_hosted_distro/bedrock.md | 58 +++++++++++++++++++ .../adapters/safety/bedrock/bedrock.py | 45 +++++++++----- .../adapters/safety/bedrock/config.py | 21 ++++++- 4 files changed, 116 insertions(+), 18 deletions(-) create mode 100644 docs/source/getting_started/distributions/remote_hosted_distro/bedrock.md diff --git a/distributions/bedrock/run.yaml b/distributions/bedrock/run.yaml index 0c6d22474..bd9a89566 100644 --- a/distributions/bedrock/run.yaml +++ b/distributions/bedrock/run.yaml @@ -26,9 +26,13 @@ providers: provider_type: meta-reference config: {} safety: - - provider_id: meta0 - provider_type: meta-reference - config: {} + - provider_id: bedrock0 + provider_type: remote::bedrock + config: + aws_access_key_id: + aws_secret_access_key: + aws_session_token: + region_name: agents: - provider_id: meta0 provider_type: meta-reference diff --git a/docs/source/getting_started/distributions/remote_hosted_distro/bedrock.md b/docs/source/getting_started/distributions/remote_hosted_distro/bedrock.md new file mode 100644 index 000000000..28691d4e3 --- /dev/null +++ b/docs/source/getting_started/distributions/remote_hosted_distro/bedrock.md @@ -0,0 +1,58 @@ +# Bedrock Distribution + +### Connect to a Llama Stack Bedrock Endpoint +- You may connect to Amazon Bedrock APIs for running LLM inference + +The `llamastack/distribution-bedrock` distribution consists of the following provider configurations. + + +| **API** | **Inference** | **Agents** | **Memory** | **Safety** | **Telemetry** | +|----------------- |--------------- |---------------- |---------------- |---------------- |---------------- | +| **Provider(s)** | remote::bedrock | meta-reference | meta-reference | remote::bedrock | meta-reference | + + +### Docker: Start the Distribution (Single Node CPU) + +> [!NOTE] +> This assumes you have valid AWS credentials configured with access to Amazon Bedrock. + +``` +$ cd distributions/bedrock && docker compose up +``` + +Make sure in your `run.yaml` file, your inference provider is pointing to the correct AWS configuration. E.g. +``` +inference: + - provider_id: bedrock0 + provider_type: remote::bedrock + config: + aws_access_key_id: + aws_secret_access_key: + aws_session_token: + region_name: +``` + +### Conda llama stack run (Single Node CPU) + +```bash +llama stack build --template bedrock --image-type conda +# -- modify run.yaml with valid AWS credentials +llama stack run ./run.yaml +``` + +### (Optional) Update Model Serving Configuration + +Use `llama-stack-client models list` to check the available models served by Amazon Bedrock. + +``` +$ llama-stack-client models list ++------------------------------+------------------------------+---------------+------------+ +| identifier | llama_model | provider_id | metadata | ++==============================+==============================+===============+============+ +| Llama3.1-8B-Instruct | meta.llama3-1-8b-instruct-v1:0 | bedrock0 | {} | ++------------------------------+------------------------------+---------------+------------+ +| Llama3.1-70B-Instruct | meta.llama3-1-70b-instruct-v1:0 | bedrock0 | {} | ++------------------------------+------------------------------+---------------+------------+ +| Llama3.1-405B-Instruct | meta.llama3-1-405b-instruct-v1:0 | bedrock0 | {} | ++------------------------------+------------------------------+---------------+------------+ +``` diff --git a/llama_stack/providers/adapters/safety/bedrock/bedrock.py b/llama_stack/providers/adapters/safety/bedrock/bedrock.py index 3203e36f4..f36515471 100644 --- a/llama_stack/providers/adapters/safety/bedrock/bedrock.py +++ b/llama_stack/providers/adapters/safety/bedrock/bedrock.py @@ -25,20 +25,33 @@ BEDROCK_SUPPORTED_SHIELDS = [ ShieldType.generic_content_shield.value, ] +def _create_bedrock_client(config: BedrockSafetyConfig, name: str) : + session_args = { + k: v + for k, v in dict( + aws_access_key_id=config.aws_access_key_id, + aws_secret_access_key=config.aws_secret_access_key, + aws_session_token=config.aws_session_token, + region_name=config.region_name, + profile_name=config.profile_name, + ).items() + if v is not None + } + + boto3_session = boto3.session.Session(**session_args) + + return boto3_session.client(name) + class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate): def __init__(self, config: BedrockSafetyConfig) -> None: - if not config.aws_profile: - raise ValueError(f"Missing boto_client aws_profile in model info::{config}") self.config = config self.registered_shields = [] async def initialize(self) -> None: try: - print(f"initializing with profile --- > {self.config}") - self.boto_client = boto3.Session( - profile_name=self.config.aws_profile - ).client("bedrock-runtime") + self.bedrock_runtime_client = _create_bedrock_client(self.config, "bedrock-runtime") + self.bedrock_client = _create_bedrock_client(self.config, "bedrock") except Exception as e: raise RuntimeError("Error initializing BedrockSafetyAdapter") from e @@ -49,12 +62,18 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate): raise ValueError("Registering dynamic shields is not supported") async def list_shields(self) -> List[ShieldDef]: - raise NotImplementedError( - """ - `list_shields` not implemented; this should read all guardrails from - bedrock and populate guardrailId and guardrailVersion in the ShieldDef. - """ - ) + response = self.bedrock_client.list_guardrails() + shields = [] + for guardrail in response["guardrails"]: + # populate the shield def with the guardrail id and version + shield_def = ShieldDef( + identifier=guardrail["id"], + shield_type=ShieldType.generic_content_shield.value, + params={"guardrailIdentifier": guardrail["id"], "guardrailVersion": guardrail["version"]}, + ) + shields.append(shield_def) + return shields + async def run_shield( self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None @@ -88,7 +107,7 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate): f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:" ) - response = self.boto_client.apply_guardrail( + response = self.bedrock_runtime_client.apply_guardrail( guardrailIdentifier=shield_params["guardrailIdentifier"], guardrailVersion=shield_params["guardrailVersion"], source="OUTPUT", # or 'INPUT' depending on your use case diff --git a/llama_stack/providers/adapters/safety/bedrock/config.py b/llama_stack/providers/adapters/safety/bedrock/config.py index 2a8585262..7a01d08fb 100644 --- a/llama_stack/providers/adapters/safety/bedrock/config.py +++ b/llama_stack/providers/adapters/safety/bedrock/config.py @@ -5,12 +5,29 @@ # the root directory of this source tree. from pydantic import BaseModel, Field - +from typing import Optional class BedrockSafetyConfig(BaseModel): """Configuration information for a guardrail that you want to use in the request.""" - aws_profile: str = Field( + aws_access_key_id: Optional[str] = Field( + default=None, + description="The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID", + ) + aws_secret_access_key: Optional[str] = Field( + default=None, + description="The AWS secret access key to use. Default use environment variable: AWS_SECRET_ACCESS_KEY", + ) + aws_session_token: Optional[str] = Field( + default=None, + description="The AWS session token to use. Default use environment variable: AWS_SESSION_TOKEN", + ) + region_name: Optional[str] = Field( + default=None, + description="The default AWS Region to use, for example, us-west-1 or us-west-2." + "Default use environment variable: AWS_DEFAULT_REGION", + ) + profile_name: str = Field( default="default", description="The profile on the machine having valid aws credentials. This will ensure separation of creation to invocation", )