mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-30 23:51:00 +00:00
add bedrock shields support
This commit is contained in:
parent
ccd60dc29d
commit
7176338ca6
4 changed files with 116 additions and 18 deletions
|
@ -26,9 +26,13 @@ providers:
|
||||||
provider_type: meta-reference
|
provider_type: meta-reference
|
||||||
config: {}
|
config: {}
|
||||||
safety:
|
safety:
|
||||||
- provider_id: meta0
|
- provider_id: bedrock0
|
||||||
provider_type: meta-reference
|
provider_type: remote::bedrock
|
||||||
config: {}
|
config:
|
||||||
|
aws_access_key_id: <AWS_ACCESS_KEY_ID>
|
||||||
|
aws_secret_access_key: <AWS_SECRET_ACCESS_KEY>
|
||||||
|
aws_session_token: <AWS_SESSION_TOKEN>
|
||||||
|
region_name: <AWS_REGION>
|
||||||
agents:
|
agents:
|
||||||
- provider_id: meta0
|
- provider_id: meta0
|
||||||
provider_type: meta-reference
|
provider_type: meta-reference
|
||||||
|
|
|
@ -0,0 +1,58 @@
|
||||||
|
# Bedrock Distribution
|
||||||
|
|
||||||
|
### Connect to a Llama Stack Bedrock Endpoint
|
||||||
|
- You may connect to Amazon Bedrock APIs for running LLM inference
|
||||||
|
|
||||||
|
The `llamastack/distribution-bedrock` distribution consists of the following provider configurations.
|
||||||
|
|
||||||
|
|
||||||
|
| **API** | **Inference** | **Agents** | **Memory** | **Safety** | **Telemetry** |
|
||||||
|
|----------------- |--------------- |---------------- |---------------- |---------------- |---------------- |
|
||||||
|
| **Provider(s)** | remote::bedrock | meta-reference | meta-reference | remote::bedrock | meta-reference |
|
||||||
|
|
||||||
|
|
||||||
|
### Docker: Start the Distribution (Single Node CPU)
|
||||||
|
|
||||||
|
> [!NOTE]
|
||||||
|
> This assumes you have valid AWS credentials configured with access to Amazon Bedrock.
|
||||||
|
|
||||||
|
```
|
||||||
|
$ cd distributions/bedrock && docker compose up
|
||||||
|
```
|
||||||
|
|
||||||
|
Make sure in your `run.yaml` file, your inference provider is pointing to the correct AWS configuration. E.g.
|
||||||
|
```
|
||||||
|
inference:
|
||||||
|
- provider_id: bedrock0
|
||||||
|
provider_type: remote::bedrock
|
||||||
|
config:
|
||||||
|
aws_access_key_id: <AWS_ACCESS_KEY_ID>
|
||||||
|
aws_secret_access_key: <AWS_SECRET_ACCESS_KEY>
|
||||||
|
aws_session_token: <AWS_SESSION_TOKEN>
|
||||||
|
region_name: <AWS_REGION>
|
||||||
|
```
|
||||||
|
|
||||||
|
### Conda llama stack run (Single Node CPU)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
llama stack build --template bedrock --image-type conda
|
||||||
|
# -- modify run.yaml with valid AWS credentials
|
||||||
|
llama stack run ./run.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
### (Optional) Update Model Serving Configuration
|
||||||
|
|
||||||
|
Use `llama-stack-client models list` to check the available models served by Amazon Bedrock.
|
||||||
|
|
||||||
|
```
|
||||||
|
$ llama-stack-client models list
|
||||||
|
+------------------------------+------------------------------+---------------+------------+
|
||||||
|
| identifier | llama_model | provider_id | metadata |
|
||||||
|
+==============================+==============================+===============+============+
|
||||||
|
| Llama3.1-8B-Instruct | meta.llama3-1-8b-instruct-v1:0 | bedrock0 | {} |
|
||||||
|
+------------------------------+------------------------------+---------------+------------+
|
||||||
|
| Llama3.1-70B-Instruct | meta.llama3-1-70b-instruct-v1:0 | bedrock0 | {} |
|
||||||
|
+------------------------------+------------------------------+---------------+------------+
|
||||||
|
| Llama3.1-405B-Instruct | meta.llama3-1-405b-instruct-v1:0 | bedrock0 | {} |
|
||||||
|
+------------------------------+------------------------------+---------------+------------+
|
||||||
|
```
|
|
@ -25,20 +25,33 @@ BEDROCK_SUPPORTED_SHIELDS = [
|
||||||
ShieldType.generic_content_shield.value,
|
ShieldType.generic_content_shield.value,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
def _create_bedrock_client(config: BedrockSafetyConfig, name: str) :
|
||||||
|
session_args = {
|
||||||
|
k: v
|
||||||
|
for k, v in dict(
|
||||||
|
aws_access_key_id=config.aws_access_key_id,
|
||||||
|
aws_secret_access_key=config.aws_secret_access_key,
|
||||||
|
aws_session_token=config.aws_session_token,
|
||||||
|
region_name=config.region_name,
|
||||||
|
profile_name=config.profile_name,
|
||||||
|
).items()
|
||||||
|
if v is not None
|
||||||
|
}
|
||||||
|
|
||||||
|
boto3_session = boto3.session.Session(**session_args)
|
||||||
|
|
||||||
|
return boto3_session.client(name)
|
||||||
|
|
||||||
|
|
||||||
class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
|
class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
|
||||||
def __init__(self, config: BedrockSafetyConfig) -> None:
|
def __init__(self, config: BedrockSafetyConfig) -> None:
|
||||||
if not config.aws_profile:
|
|
||||||
raise ValueError(f"Missing boto_client aws_profile in model info::{config}")
|
|
||||||
self.config = config
|
self.config = config
|
||||||
self.registered_shields = []
|
self.registered_shields = []
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
try:
|
try:
|
||||||
print(f"initializing with profile --- > {self.config}")
|
self.bedrock_runtime_client = _create_bedrock_client(self.config, "bedrock-runtime")
|
||||||
self.boto_client = boto3.Session(
|
self.bedrock_client = _create_bedrock_client(self.config, "bedrock")
|
||||||
profile_name=self.config.aws_profile
|
|
||||||
).client("bedrock-runtime")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise RuntimeError("Error initializing BedrockSafetyAdapter") from e
|
raise RuntimeError("Error initializing BedrockSafetyAdapter") from e
|
||||||
|
|
||||||
|
@ -49,12 +62,18 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
|
||||||
raise ValueError("Registering dynamic shields is not supported")
|
raise ValueError("Registering dynamic shields is not supported")
|
||||||
|
|
||||||
async def list_shields(self) -> List[ShieldDef]:
|
async def list_shields(self) -> List[ShieldDef]:
|
||||||
raise NotImplementedError(
|
response = self.bedrock_client.list_guardrails()
|
||||||
"""
|
shields = []
|
||||||
`list_shields` not implemented; this should read all guardrails from
|
for guardrail in response["guardrails"]:
|
||||||
bedrock and populate guardrailId and guardrailVersion in the ShieldDef.
|
# populate the shield def with the guardrail id and version
|
||||||
"""
|
shield_def = ShieldDef(
|
||||||
)
|
identifier=guardrail["id"],
|
||||||
|
shield_type=ShieldType.generic_content_shield.value,
|
||||||
|
params={"guardrailIdentifier": guardrail["id"], "guardrailVersion": guardrail["version"]},
|
||||||
|
)
|
||||||
|
shields.append(shield_def)
|
||||||
|
return shields
|
||||||
|
|
||||||
|
|
||||||
async def run_shield(
|
async def run_shield(
|
||||||
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
|
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
|
||||||
|
@ -88,7 +107,7 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
|
||||||
f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:"
|
f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:"
|
||||||
)
|
)
|
||||||
|
|
||||||
response = self.boto_client.apply_guardrail(
|
response = self.bedrock_runtime_client.apply_guardrail(
|
||||||
guardrailIdentifier=shield_params["guardrailIdentifier"],
|
guardrailIdentifier=shield_params["guardrailIdentifier"],
|
||||||
guardrailVersion=shield_params["guardrailVersion"],
|
guardrailVersion=shield_params["guardrailVersion"],
|
||||||
source="OUTPUT", # or 'INPUT' depending on your use case
|
source="OUTPUT", # or 'INPUT' depending on your use case
|
||||||
|
|
|
@ -5,12 +5,29 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
class BedrockSafetyConfig(BaseModel):
|
class BedrockSafetyConfig(BaseModel):
|
||||||
"""Configuration information for a guardrail that you want to use in the request."""
|
"""Configuration information for a guardrail that you want to use in the request."""
|
||||||
|
|
||||||
aws_profile: str = Field(
|
aws_access_key_id: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description="The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID",
|
||||||
|
)
|
||||||
|
aws_secret_access_key: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description="The AWS secret access key to use. Default use environment variable: AWS_SECRET_ACCESS_KEY",
|
||||||
|
)
|
||||||
|
aws_session_token: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description="The AWS session token to use. Default use environment variable: AWS_SESSION_TOKEN",
|
||||||
|
)
|
||||||
|
region_name: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description="The default AWS Region to use, for example, us-west-1 or us-west-2."
|
||||||
|
"Default use environment variable: AWS_DEFAULT_REGION",
|
||||||
|
)
|
||||||
|
profile_name: str = Field(
|
||||||
default="default",
|
default="default",
|
||||||
description="The profile on the machine having valid aws credentials. This will ensure separation of creation to invocation",
|
description="The profile on the machine having valid aws credentials. This will ensure separation of creation to invocation",
|
||||||
)
|
)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue