forked from phoenix-oss/llama-stack-mirror
add bedrock distribution code (#358)
* add bedrock distribution code * fix linter error * add bedrock shields support * linter fixes * working bedrock safety * change to return only one violation * remove env var reading * refereshable boto credentials * remove env vars * address raghu's feedback * fix session_ttl passing --------- Co-authored-by: Dinesh Yeduguru <dineshyv@fb.com>
This commit is contained in:
parent
6ebd553da5
commit
093c9f1987
16 changed files with 429 additions and 135 deletions
15
distributions/bedrock/compose.yaml
Normal file
15
distributions/bedrock/compose.yaml
Normal file
|
@ -0,0 +1,15 @@
|
||||||
|
services:
|
||||||
|
llamastack:
|
||||||
|
image: distribution-bedrock
|
||||||
|
volumes:
|
||||||
|
- ~/.llama:/root/.llama
|
||||||
|
- ./run.yaml:/root/llamastack-run-bedrock.yaml
|
||||||
|
ports:
|
||||||
|
- "5000:5000"
|
||||||
|
entrypoint: bash -c "python -m llama_stack.distribution.server.server --yaml_config /root/llamastack-run-bedrock.yaml"
|
||||||
|
deploy:
|
||||||
|
restart_policy:
|
||||||
|
condition: on-failure
|
||||||
|
delay: 3s
|
||||||
|
max_attempts: 5
|
||||||
|
window: 60s
|
46
distributions/bedrock/run.yaml
Normal file
46
distributions/bedrock/run.yaml
Normal file
|
@ -0,0 +1,46 @@
|
||||||
|
version: '2'
|
||||||
|
built_at: '2024-11-01T17:40:45.325529'
|
||||||
|
image_name: local
|
||||||
|
name: bedrock
|
||||||
|
docker_image: null
|
||||||
|
conda_env: local
|
||||||
|
apis:
|
||||||
|
- shields
|
||||||
|
- agents
|
||||||
|
- models
|
||||||
|
- memory
|
||||||
|
- memory_banks
|
||||||
|
- inference
|
||||||
|
- safety
|
||||||
|
providers:
|
||||||
|
inference:
|
||||||
|
- provider_id: bedrock0
|
||||||
|
provider_type: remote::bedrock
|
||||||
|
config:
|
||||||
|
aws_access_key_id: <AWS_ACCESS_KEY_ID>
|
||||||
|
aws_secret_access_key: <AWS_SECRET_ACCESS_KEY>
|
||||||
|
aws_session_token: <AWS_SESSION_TOKEN>
|
||||||
|
region_name: <AWS_REGION>
|
||||||
|
memory:
|
||||||
|
- provider_id: meta0
|
||||||
|
provider_type: meta-reference
|
||||||
|
config: {}
|
||||||
|
safety:
|
||||||
|
- provider_id: bedrock0
|
||||||
|
provider_type: remote::bedrock
|
||||||
|
config:
|
||||||
|
aws_access_key_id: <AWS_ACCESS_KEY_ID>
|
||||||
|
aws_secret_access_key: <AWS_SECRET_ACCESS_KEY>
|
||||||
|
aws_session_token: <AWS_SESSION_TOKEN>
|
||||||
|
region_name: <AWS_REGION>
|
||||||
|
agents:
|
||||||
|
- provider_id: meta0
|
||||||
|
provider_type: meta-reference
|
||||||
|
config:
|
||||||
|
persistence_store:
|
||||||
|
type: sqlite
|
||||||
|
db_path: ~/.llama/runtime/kvstore.db
|
||||||
|
telemetry:
|
||||||
|
- provider_id: meta0
|
||||||
|
provider_type: meta-reference
|
||||||
|
config: {}
|
|
@ -0,0 +1,58 @@
|
||||||
|
# Bedrock Distribution
|
||||||
|
|
||||||
|
### Connect to a Llama Stack Bedrock Endpoint
|
||||||
|
- You may connect to Amazon Bedrock APIs for running LLM inference
|
||||||
|
|
||||||
|
The `llamastack/distribution-bedrock` distribution consists of the following provider configurations.
|
||||||
|
|
||||||
|
|
||||||
|
| **API** | **Inference** | **Agents** | **Memory** | **Safety** | **Telemetry** |
|
||||||
|
|----------------- |--------------- |---------------- |---------------- |---------------- |---------------- |
|
||||||
|
| **Provider(s)** | remote::bedrock | meta-reference | meta-reference | remote::bedrock | meta-reference |
|
||||||
|
|
||||||
|
|
||||||
|
### Docker: Start the Distribution (Single Node CPU)
|
||||||
|
|
||||||
|
> [!NOTE]
|
||||||
|
> This assumes you have valid AWS credentials configured with access to Amazon Bedrock.
|
||||||
|
|
||||||
|
```
|
||||||
|
$ cd distributions/bedrock && docker compose up
|
||||||
|
```
|
||||||
|
|
||||||
|
Make sure in your `run.yaml` file, your inference provider is pointing to the correct AWS configuration. E.g.
|
||||||
|
```
|
||||||
|
inference:
|
||||||
|
- provider_id: bedrock0
|
||||||
|
provider_type: remote::bedrock
|
||||||
|
config:
|
||||||
|
aws_access_key_id: <AWS_ACCESS_KEY_ID>
|
||||||
|
aws_secret_access_key: <AWS_SECRET_ACCESS_KEY>
|
||||||
|
aws_session_token: <AWS_SESSION_TOKEN>
|
||||||
|
region_name: <AWS_REGION>
|
||||||
|
```
|
||||||
|
|
||||||
|
### Conda llama stack run (Single Node CPU)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
llama stack build --template bedrock --image-type conda
|
||||||
|
# -- modify run.yaml with valid AWS credentials
|
||||||
|
llama stack run ./run.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
### (Optional) Update Model Serving Configuration
|
||||||
|
|
||||||
|
Use `llama-stack-client models list` to check the available models served by Amazon Bedrock.
|
||||||
|
|
||||||
|
```
|
||||||
|
$ llama-stack-client models list
|
||||||
|
+------------------------------+------------------------------+---------------+------------+
|
||||||
|
| identifier | llama_model | provider_id | metadata |
|
||||||
|
+==============================+==============================+===============+============+
|
||||||
|
| Llama3.1-8B-Instruct | meta.llama3-1-8b-instruct-v1:0 | bedrock0 | {} |
|
||||||
|
+------------------------------+------------------------------+---------------+------------+
|
||||||
|
| Llama3.1-70B-Instruct | meta.llama3-1-70b-instruct-v1:0 | bedrock0 | {} |
|
||||||
|
+------------------------------+------------------------------+---------------+------------+
|
||||||
|
| Llama3.1-405B-Instruct | meta.llama3-1-405b-instruct-v1:0 | bedrock0 | {} |
|
||||||
|
+------------------------------+------------------------------+---------------+------------+
|
||||||
|
```
|
|
@ -39,7 +39,7 @@ class RunShieldResponse(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class ShieldStore(Protocol):
|
class ShieldStore(Protocol):
|
||||||
def get_shield(self, identifier: str) -> ShieldDef: ...
|
async def get_shield(self, identifier: str) -> ShieldDef: ...
|
||||||
|
|
||||||
|
|
||||||
@runtime_checkable
|
@runtime_checkable
|
||||||
|
@ -48,5 +48,5 @@ class Safety(Protocol):
|
||||||
|
|
||||||
@webmethod(route="/safety/run_shield")
|
@webmethod(route="/safety/run_shield")
|
||||||
async def run_shield(
|
async def run_shield(
|
||||||
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
|
self, identifier: str, messages: List[Message], params: Dict[str, Any] = None
|
||||||
) -> RunShieldResponse: ...
|
) -> RunShieldResponse: ...
|
||||||
|
|
|
@ -46,7 +46,7 @@ class Shields(Protocol):
|
||||||
async def list_shields(self) -> List[ShieldDefWithProvider]: ...
|
async def list_shields(self) -> List[ShieldDefWithProvider]: ...
|
||||||
|
|
||||||
@webmethod(route="/shields/get", method="GET")
|
@webmethod(route="/shields/get", method="GET")
|
||||||
async def get_shield(self, shield_type: str) -> Optional[ShieldDefWithProvider]: ...
|
async def get_shield(self, identifier: str) -> Optional[ShieldDefWithProvider]: ...
|
||||||
|
|
||||||
@webmethod(route="/shields/register", method="POST")
|
@webmethod(route="/shields/register", method="POST")
|
||||||
async def register_shield(self, shield: ShieldDefWithProvider) -> None: ...
|
async def register_shield(self, shield: ShieldDefWithProvider) -> None: ...
|
||||||
|
|
|
@ -154,12 +154,12 @@ class SafetyRouter(Safety):
|
||||||
|
|
||||||
async def run_shield(
|
async def run_shield(
|
||||||
self,
|
self,
|
||||||
shield_type: str,
|
identifier: str,
|
||||||
messages: List[Message],
|
messages: List[Message],
|
||||||
params: Dict[str, Any] = None,
|
params: Dict[str, Any] = None,
|
||||||
) -> RunShieldResponse:
|
) -> RunShieldResponse:
|
||||||
return await self.routing_table.get_provider_impl(shield_type).run_shield(
|
return await self.routing_table.get_provider_impl(identifier).run_shield(
|
||||||
shield_type=shield_type,
|
identifier=identifier,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
params=params,
|
params=params,
|
||||||
)
|
)
|
||||||
|
|
|
@ -204,8 +204,8 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
||||||
async def list_shields(self) -> List[ShieldDef]:
|
async def list_shields(self) -> List[ShieldDef]:
|
||||||
return await self.get_all_with_type("shield")
|
return await self.get_all_with_type("shield")
|
||||||
|
|
||||||
async def get_shield(self, shield_type: str) -> Optional[ShieldDefWithProvider]:
|
async def get_shield(self, identifier: str) -> Optional[ShieldDefWithProvider]:
|
||||||
return await self.get_object_by_identifier(shield_type)
|
return await self.get_object_by_identifier(identifier)
|
||||||
|
|
||||||
async def register_shield(self, shield: ShieldDefWithProvider) -> None:
|
async def register_shield(self, shield: ShieldDefWithProvider) -> None:
|
||||||
await self.register_object(shield)
|
await self.register_object(shield)
|
||||||
|
|
|
@ -6,9 +6,7 @@
|
||||||
|
|
||||||
from typing import * # noqa: F403
|
from typing import * # noqa: F403
|
||||||
|
|
||||||
import boto3
|
|
||||||
from botocore.client import BaseClient
|
from botocore.client import BaseClient
|
||||||
from botocore.config import Config
|
|
||||||
|
|
||||||
from llama_models.llama3.api.chat_format import ChatFormat
|
from llama_models.llama3.api.chat_format import ChatFormat
|
||||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||||
|
@ -16,7 +14,9 @@ from llama_models.llama3.api.tokenizer import Tokenizer
|
||||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||||
|
|
||||||
from llama_stack.apis.inference import * # noqa: F403
|
from llama_stack.apis.inference import * # noqa: F403
|
||||||
|
|
||||||
from llama_stack.providers.adapters.inference.bedrock.config import BedrockConfig
|
from llama_stack.providers.adapters.inference.bedrock.config import BedrockConfig
|
||||||
|
from llama_stack.providers.utils.bedrock.client import create_bedrock_client
|
||||||
|
|
||||||
|
|
||||||
BEDROCK_SUPPORTED_MODELS = {
|
BEDROCK_SUPPORTED_MODELS = {
|
||||||
|
@ -34,7 +34,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
|
||||||
)
|
)
|
||||||
self._config = config
|
self._config = config
|
||||||
|
|
||||||
self._client = _create_bedrock_client(config)
|
self._client = create_bedrock_client(config)
|
||||||
self.formatter = ChatFormat(Tokenizer.get_instance())
|
self.formatter = ChatFormat(Tokenizer.get_instance())
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -437,43 +437,3 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
|
||||||
contents: List[InterleavedTextMedia],
|
contents: List[InterleavedTextMedia],
|
||||||
) -> EmbeddingsResponse:
|
) -> EmbeddingsResponse:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
def _create_bedrock_client(config: BedrockConfig) -> BaseClient:
|
|
||||||
retries_config = {
|
|
||||||
k: v
|
|
||||||
for k, v in dict(
|
|
||||||
total_max_attempts=config.total_max_attempts,
|
|
||||||
mode=config.retry_mode,
|
|
||||||
).items()
|
|
||||||
if v is not None
|
|
||||||
}
|
|
||||||
|
|
||||||
config_args = {
|
|
||||||
k: v
|
|
||||||
for k, v in dict(
|
|
||||||
region_name=config.region_name,
|
|
||||||
retries=retries_config if retries_config else None,
|
|
||||||
connect_timeout=config.connect_timeout,
|
|
||||||
read_timeout=config.read_timeout,
|
|
||||||
).items()
|
|
||||||
if v is not None
|
|
||||||
}
|
|
||||||
|
|
||||||
boto3_config = Config(**config_args)
|
|
||||||
|
|
||||||
session_args = {
|
|
||||||
k: v
|
|
||||||
for k, v in dict(
|
|
||||||
aws_access_key_id=config.aws_access_key_id,
|
|
||||||
aws_secret_access_key=config.aws_secret_access_key,
|
|
||||||
aws_session_token=config.aws_session_token,
|
|
||||||
region_name=config.region_name,
|
|
||||||
profile_name=config.profile_name,
|
|
||||||
).items()
|
|
||||||
if v is not None
|
|
||||||
}
|
|
||||||
|
|
||||||
boto3_session = boto3.session.Session(**session_args)
|
|
||||||
|
|
||||||
return boto3_session.client("bedrock-runtime", config=boto3_config)
|
|
||||||
|
|
|
@ -3,53 +3,12 @@
|
||||||
#
|
#
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
from typing import * # noqa: F403
|
|
||||||
|
|
||||||
from llama_models.schema_utils import json_schema_type
|
from llama_models.schema_utils import json_schema_type
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
from llama_stack.providers.utils.bedrock.config import BedrockBaseConfig
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class BedrockConfig(BaseModel):
|
class BedrockConfig(BedrockBaseConfig):
|
||||||
aws_access_key_id: Optional[str] = Field(
|
pass
|
||||||
default=None,
|
|
||||||
description="The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID",
|
|
||||||
)
|
|
||||||
aws_secret_access_key: Optional[str] = Field(
|
|
||||||
default=None,
|
|
||||||
description="The AWS secret access key to use. Default use environment variable: AWS_SECRET_ACCESS_KEY",
|
|
||||||
)
|
|
||||||
aws_session_token: Optional[str] = Field(
|
|
||||||
default=None,
|
|
||||||
description="The AWS session token to use. Default use environment variable: AWS_SESSION_TOKEN",
|
|
||||||
)
|
|
||||||
region_name: Optional[str] = Field(
|
|
||||||
default=None,
|
|
||||||
description="The default AWS Region to use, for example, us-west-1 or us-west-2."
|
|
||||||
"Default use environment variable: AWS_DEFAULT_REGION",
|
|
||||||
)
|
|
||||||
profile_name: Optional[str] = Field(
|
|
||||||
default=None,
|
|
||||||
description="The profile name that contains credentials to use."
|
|
||||||
"Default use environment variable: AWS_PROFILE",
|
|
||||||
)
|
|
||||||
total_max_attempts: Optional[int] = Field(
|
|
||||||
default=None,
|
|
||||||
description="An integer representing the maximum number of attempts that will be made for a single request, "
|
|
||||||
"including the initial attempt. Default use environment variable: AWS_MAX_ATTEMPTS",
|
|
||||||
)
|
|
||||||
retry_mode: Optional[str] = Field(
|
|
||||||
default=None,
|
|
||||||
description="A string representing the type of retries Boto3 will perform."
|
|
||||||
"Default use environment variable: AWS_RETRY_MODE",
|
|
||||||
)
|
|
||||||
connect_timeout: Optional[float] = Field(
|
|
||||||
default=60,
|
|
||||||
description="The time in seconds till a timeout exception is thrown when attempting to make a connection. "
|
|
||||||
"The default is 60 seconds.",
|
|
||||||
)
|
|
||||||
read_timeout: Optional[float] = Field(
|
|
||||||
default=60,
|
|
||||||
description="The time in seconds till a timeout exception is thrown when attempting to read from a connection."
|
|
||||||
"The default is 60 seconds.",
|
|
||||||
)
|
|
||||||
|
|
|
@ -9,11 +9,10 @@ import logging
|
||||||
|
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
import boto3
|
|
||||||
|
|
||||||
from llama_stack.apis.safety import * # noqa
|
from llama_stack.apis.safety import * # noqa
|
||||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||||
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
|
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
|
||||||
|
from llama_stack.providers.utils.bedrock.client import create_bedrock_client
|
||||||
|
|
||||||
from .config import BedrockSafetyConfig
|
from .config import BedrockSafetyConfig
|
||||||
|
|
||||||
|
@ -28,17 +27,13 @@ BEDROCK_SUPPORTED_SHIELDS = [
|
||||||
|
|
||||||
class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
|
class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
|
||||||
def __init__(self, config: BedrockSafetyConfig) -> None:
|
def __init__(self, config: BedrockSafetyConfig) -> None:
|
||||||
if not config.aws_profile:
|
|
||||||
raise ValueError(f"Missing boto_client aws_profile in model info::{config}")
|
|
||||||
self.config = config
|
self.config = config
|
||||||
self.registered_shields = []
|
self.registered_shields = []
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
try:
|
try:
|
||||||
print(f"initializing with profile --- > {self.config}")
|
self.bedrock_runtime_client = create_bedrock_client(self.config)
|
||||||
self.boto_client = boto3.Session(
|
self.bedrock_client = create_bedrock_client(self.config, "bedrock")
|
||||||
profile_name=self.config.aws_profile
|
|
||||||
).client("bedrock-runtime")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise RuntimeError("Error initializing BedrockSafetyAdapter") from e
|
raise RuntimeError("Error initializing BedrockSafetyAdapter") from e
|
||||||
|
|
||||||
|
@ -49,19 +44,28 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
|
||||||
raise ValueError("Registering dynamic shields is not supported")
|
raise ValueError("Registering dynamic shields is not supported")
|
||||||
|
|
||||||
async def list_shields(self) -> List[ShieldDef]:
|
async def list_shields(self) -> List[ShieldDef]:
|
||||||
raise NotImplementedError(
|
response = self.bedrock_client.list_guardrails()
|
||||||
"""
|
shields = []
|
||||||
`list_shields` not implemented; this should read all guardrails from
|
for guardrail in response["guardrails"]:
|
||||||
bedrock and populate guardrailId and guardrailVersion in the ShieldDef.
|
# populate the shield def with the guardrail id and version
|
||||||
"""
|
shield_def = ShieldDef(
|
||||||
)
|
identifier=guardrail["id"],
|
||||||
|
shield_type=ShieldType.generic_content_shield.value,
|
||||||
|
params={
|
||||||
|
"guardrailIdentifier": guardrail["id"],
|
||||||
|
"guardrailVersion": guardrail["version"],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
self.registered_shields.append(shield_def)
|
||||||
|
shields.append(shield_def)
|
||||||
|
return shields
|
||||||
|
|
||||||
async def run_shield(
|
async def run_shield(
|
||||||
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
|
self, identifier: str, messages: List[Message], params: Dict[str, Any] = None
|
||||||
) -> RunShieldResponse:
|
) -> RunShieldResponse:
|
||||||
shield_def = await self.shield_store.get_shield(shield_type)
|
shield_def = await self.shield_store.get_shield(identifier)
|
||||||
if not shield_def:
|
if not shield_def:
|
||||||
raise ValueError(f"Unknown shield {shield_type}")
|
raise ValueError(f"Unknown shield {identifier}")
|
||||||
|
|
||||||
"""This is the implementation for the bedrock guardrails. The input to the guardrails is to be of this format
|
"""This is the implementation for the bedrock guardrails. The input to the guardrails is to be of this format
|
||||||
```content = [
|
```content = [
|
||||||
|
@ -88,7 +92,7 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
|
||||||
f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:"
|
f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:"
|
||||||
)
|
)
|
||||||
|
|
||||||
response = self.boto_client.apply_guardrail(
|
response = self.bedrock_runtime_client.apply_guardrail(
|
||||||
guardrailIdentifier=shield_params["guardrailIdentifier"],
|
guardrailIdentifier=shield_params["guardrailIdentifier"],
|
||||||
guardrailVersion=shield_params["guardrailVersion"],
|
guardrailVersion=shield_params["guardrailVersion"],
|
||||||
source="OUTPUT", # or 'INPUT' depending on your use case
|
source="OUTPUT", # or 'INPUT' depending on your use case
|
||||||
|
@ -104,10 +108,12 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
|
||||||
# guardrails returns a list - however for this implementation we will leverage the last values
|
# guardrails returns a list - however for this implementation we will leverage the last values
|
||||||
metadata = dict(assessment)
|
metadata = dict(assessment)
|
||||||
|
|
||||||
return SafetyViolation(
|
return RunShieldResponse(
|
||||||
user_message=user_message,
|
violation=SafetyViolation(
|
||||||
violation_level=ViolationLevel.ERROR,
|
user_message=user_message,
|
||||||
metadata=metadata,
|
violation_level=ViolationLevel.ERROR,
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
return None
|
return RunShieldResponse()
|
||||||
|
|
|
@ -4,13 +4,12 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
from llama_models.schema_utils import json_schema_type
|
||||||
|
|
||||||
|
from llama_stack.providers.utils.bedrock.config import BedrockBaseConfig
|
||||||
|
|
||||||
|
|
||||||
class BedrockSafetyConfig(BaseModel):
|
@json_schema_type
|
||||||
"""Configuration information for a guardrail that you want to use in the request."""
|
class BedrockSafetyConfig(BedrockBaseConfig):
|
||||||
|
pass
|
||||||
aws_profile: str = Field(
|
|
||||||
default="default",
|
|
||||||
description="The profile on the machine having valid aws credentials. This will ensure separation of creation to invocation",
|
|
||||||
)
|
|
||||||
|
|
|
@ -43,11 +43,11 @@ class TogetherSafetyImpl(Safety, NeedsRequestProviderData, ShieldsProtocolPrivat
|
||||||
]
|
]
|
||||||
|
|
||||||
async def run_shield(
|
async def run_shield(
|
||||||
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
|
self, identifier: str, messages: List[Message], params: Dict[str, Any] = None
|
||||||
) -> RunShieldResponse:
|
) -> RunShieldResponse:
|
||||||
shield_def = await self.shield_store.get_shield(shield_type)
|
shield_def = await self.shield_store.get_shield(identifier)
|
||||||
if not shield_def:
|
if not shield_def:
|
||||||
raise ValueError(f"Unknown shield {shield_type}")
|
raise ValueError(f"Unknown shield {identifier}")
|
||||||
|
|
||||||
model = shield_def.params.get("model", "llama_guard")
|
model = shield_def.params.get("model", "llama_guard")
|
||||||
if model not in TOGETHER_SHIELD_MODEL_MAP:
|
if model not in TOGETHER_SHIELD_MODEL_MAP:
|
||||||
|
|
|
@ -32,18 +32,18 @@ class ShieldRunnerMixin:
|
||||||
self.output_shields = output_shields
|
self.output_shields = output_shields
|
||||||
|
|
||||||
async def run_multiple_shields(
|
async def run_multiple_shields(
|
||||||
self, messages: List[Message], shield_types: List[str]
|
self, messages: List[Message], identifiers: List[str]
|
||||||
) -> None:
|
) -> None:
|
||||||
responses = await asyncio.gather(
|
responses = await asyncio.gather(
|
||||||
*[
|
*[
|
||||||
self.safety_api.run_shield(
|
self.safety_api.run_shield(
|
||||||
shield_type=shield_type,
|
identifier=identifier,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
)
|
)
|
||||||
for shield_type in shield_types
|
for identifier in identifiers
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
for shield_type, response in zip(shield_types, responses):
|
for identifier, response in zip(identifiers, responses):
|
||||||
if not response.violation:
|
if not response.violation:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
@ -52,6 +52,6 @@ class ShieldRunnerMixin:
|
||||||
raise SafetyException(violation)
|
raise SafetyException(violation)
|
||||||
elif violation.violation_level == ViolationLevel.WARN:
|
elif violation.violation_level == ViolationLevel.WARN:
|
||||||
cprint(
|
cprint(
|
||||||
f"[Warn]{shield_type} raised a warning",
|
f"[Warn]{identifier} raised a warning",
|
||||||
color="red",
|
color="red",
|
||||||
)
|
)
|
||||||
|
|
76
llama_stack/providers/utils/bedrock/client.py
Normal file
76
llama_stack/providers/utils/bedrock/client.py
Normal file
|
@ -0,0 +1,76 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
|
||||||
|
import boto3
|
||||||
|
from botocore.client import BaseClient
|
||||||
|
from botocore.config import Config
|
||||||
|
|
||||||
|
from llama_stack.providers.utils.bedrock.config import BedrockBaseConfig
|
||||||
|
from llama_stack.providers.utils.bedrock.refreshable_boto_session import (
|
||||||
|
RefreshableBotoSession,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def create_bedrock_client(
|
||||||
|
config: BedrockBaseConfig, service_name: str = "bedrock-runtime"
|
||||||
|
) -> BaseClient:
|
||||||
|
"""Creates a boto3 client for Bedrock services with the given configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: The Bedrock configuration containing AWS credentials and settings
|
||||||
|
service_name: The AWS service name to create client for (default: "bedrock-runtime")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A configured boto3 client
|
||||||
|
"""
|
||||||
|
if config.aws_access_key_id and config.aws_secret_access_key:
|
||||||
|
retries_config = {
|
||||||
|
k: v
|
||||||
|
for k, v in dict(
|
||||||
|
total_max_attempts=config.total_max_attempts,
|
||||||
|
mode=config.retry_mode,
|
||||||
|
).items()
|
||||||
|
if v is not None
|
||||||
|
}
|
||||||
|
|
||||||
|
config_args = {
|
||||||
|
k: v
|
||||||
|
for k, v in dict(
|
||||||
|
region_name=config.region_name,
|
||||||
|
retries=retries_config if retries_config else None,
|
||||||
|
connect_timeout=config.connect_timeout,
|
||||||
|
read_timeout=config.read_timeout,
|
||||||
|
).items()
|
||||||
|
if v is not None
|
||||||
|
}
|
||||||
|
|
||||||
|
boto3_config = Config(**config_args)
|
||||||
|
|
||||||
|
session_args = {
|
||||||
|
"aws_access_key_id": config.aws_access_key_id,
|
||||||
|
"aws_secret_access_key": config.aws_secret_access_key,
|
||||||
|
"aws_session_token": config.aws_session_token,
|
||||||
|
"region_name": config.region_name,
|
||||||
|
"profile_name": config.profile_name,
|
||||||
|
"session_ttl": config.session_ttl,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Remove None values
|
||||||
|
session_args = {k: v for k, v in session_args.items() if v is not None}
|
||||||
|
|
||||||
|
boto3_session = boto3.session.Session(**session_args)
|
||||||
|
return boto3_session.client(service_name, config=boto3_config)
|
||||||
|
else:
|
||||||
|
return (
|
||||||
|
RefreshableBotoSession(
|
||||||
|
region_name=config.region_name,
|
||||||
|
profile_name=config.profile_name,
|
||||||
|
session_ttl=config.session_ttl,
|
||||||
|
)
|
||||||
|
.refreshable_session()
|
||||||
|
.client(service_name)
|
||||||
|
)
|
59
llama_stack/providers/utils/bedrock/config.py
Normal file
59
llama_stack/providers/utils/bedrock/config.py
Normal file
|
@ -0,0 +1,59 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from llama_models.schema_utils import json_schema_type
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class BedrockBaseConfig(BaseModel):
|
||||||
|
aws_access_key_id: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description="The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID",
|
||||||
|
)
|
||||||
|
aws_secret_access_key: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description="The AWS secret access key to use. Default use environment variable: AWS_SECRET_ACCESS_KEY",
|
||||||
|
)
|
||||||
|
aws_session_token: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description="The AWS session token to use. Default use environment variable: AWS_SESSION_TOKEN",
|
||||||
|
)
|
||||||
|
region_name: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description="The default AWS Region to use, for example, us-west-1 or us-west-2."
|
||||||
|
"Default use environment variable: AWS_DEFAULT_REGION",
|
||||||
|
)
|
||||||
|
profile_name: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description="The profile name that contains credentials to use."
|
||||||
|
"Default use environment variable: AWS_PROFILE",
|
||||||
|
)
|
||||||
|
total_max_attempts: Optional[int] = Field(
|
||||||
|
default=None,
|
||||||
|
description="An integer representing the maximum number of attempts that will be made for a single request, "
|
||||||
|
"including the initial attempt. Default use environment variable: AWS_MAX_ATTEMPTS",
|
||||||
|
)
|
||||||
|
retry_mode: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description="A string representing the type of retries Boto3 will perform."
|
||||||
|
"Default use environment variable: AWS_RETRY_MODE",
|
||||||
|
)
|
||||||
|
connect_timeout: Optional[float] = Field(
|
||||||
|
default=60,
|
||||||
|
description="The time in seconds till a timeout exception is thrown when attempting to make a connection. "
|
||||||
|
"The default is 60 seconds.",
|
||||||
|
)
|
||||||
|
read_timeout: Optional[float] = Field(
|
||||||
|
default=60,
|
||||||
|
description="The time in seconds till a timeout exception is thrown when attempting to read from a connection."
|
||||||
|
"The default is 60 seconds.",
|
||||||
|
)
|
||||||
|
session_ttl: Optional[int] = Field(
|
||||||
|
default=3600,
|
||||||
|
description="The time in seconds till a session expires. The default is 3600 seconds (1 hour).",
|
||||||
|
)
|
116
llama_stack/providers/utils/bedrock/refreshable_boto_session.py
Normal file
116
llama_stack/providers/utils/bedrock/refreshable_boto_session.py
Normal file
|
@ -0,0 +1,116 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import datetime
|
||||||
|
from time import time
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
from boto3 import Session
|
||||||
|
from botocore.credentials import RefreshableCredentials
|
||||||
|
from botocore.session import get_session
|
||||||
|
|
||||||
|
|
||||||
|
class RefreshableBotoSession:
|
||||||
|
"""
|
||||||
|
Boto Helper class which lets us create a refreshable session so that we can cache the client or resource.
|
||||||
|
|
||||||
|
Usage
|
||||||
|
-----
|
||||||
|
session = RefreshableBotoSession().refreshable_session()
|
||||||
|
|
||||||
|
client = session.client("s3") # we now can cache this client object without worrying about expiring credentials
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
region_name: str = None,
|
||||||
|
profile_name: str = None,
|
||||||
|
sts_arn: str = None,
|
||||||
|
session_name: str = None,
|
||||||
|
session_ttl: int = 30000,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize `RefreshableBotoSession`
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
region_name : str (optional)
|
||||||
|
Default region when creating a new connection.
|
||||||
|
|
||||||
|
profile_name : str (optional)
|
||||||
|
The name of a profile to use.
|
||||||
|
|
||||||
|
sts_arn : str (optional)
|
||||||
|
The role arn to sts before creating a session.
|
||||||
|
|
||||||
|
session_name : str (optional)
|
||||||
|
An identifier for the assumed role session. (required when `sts_arn` is given)
|
||||||
|
|
||||||
|
session_ttl : int (optional)
|
||||||
|
An integer number to set the TTL for each session. Beyond this session, it will renew the token.
|
||||||
|
50 minutes by default which is before the default role expiration of 1 hour
|
||||||
|
"""
|
||||||
|
|
||||||
|
self.region_name = region_name
|
||||||
|
self.profile_name = profile_name
|
||||||
|
self.sts_arn = sts_arn
|
||||||
|
self.session_name = session_name or uuid4().hex
|
||||||
|
self.session_ttl = session_ttl
|
||||||
|
|
||||||
|
def __get_session_credentials(self):
|
||||||
|
"""
|
||||||
|
Get session credentials
|
||||||
|
"""
|
||||||
|
session = Session(region_name=self.region_name, profile_name=self.profile_name)
|
||||||
|
|
||||||
|
# if sts_arn is given, get credential by assuming the given role
|
||||||
|
if self.sts_arn:
|
||||||
|
sts_client = session.client(
|
||||||
|
service_name="sts", region_name=self.region_name
|
||||||
|
)
|
||||||
|
response = sts_client.assume_role(
|
||||||
|
RoleArn=self.sts_arn,
|
||||||
|
RoleSessionName=self.session_name,
|
||||||
|
DurationSeconds=self.session_ttl,
|
||||||
|
).get("Credentials")
|
||||||
|
|
||||||
|
credentials = {
|
||||||
|
"access_key": response.get("AccessKeyId"),
|
||||||
|
"secret_key": response.get("SecretAccessKey"),
|
||||||
|
"token": response.get("SessionToken"),
|
||||||
|
"expiry_time": response.get("Expiration").isoformat(),
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
session_credentials = session.get_credentials().get_frozen_credentials()
|
||||||
|
credentials = {
|
||||||
|
"access_key": session_credentials.access_key,
|
||||||
|
"secret_key": session_credentials.secret_key,
|
||||||
|
"token": session_credentials.token,
|
||||||
|
"expiry_time": datetime.datetime.fromtimestamp(
|
||||||
|
time() + self.session_ttl, datetime.timezone.utc
|
||||||
|
).isoformat(),
|
||||||
|
}
|
||||||
|
|
||||||
|
return credentials
|
||||||
|
|
||||||
|
def refreshable_session(self) -> Session:
|
||||||
|
"""
|
||||||
|
Get refreshable boto3 session.
|
||||||
|
"""
|
||||||
|
# Get refreshable credentials
|
||||||
|
refreshable_credentials = RefreshableCredentials.create_from_metadata(
|
||||||
|
metadata=self.__get_session_credentials(),
|
||||||
|
refresh_using=self.__get_session_credentials,
|
||||||
|
method="sts-assume-role",
|
||||||
|
)
|
||||||
|
|
||||||
|
# attach refreshable credentials current session
|
||||||
|
session = get_session()
|
||||||
|
session._credentials = refreshable_credentials
|
||||||
|
session.set_config_variable("region", self.region_name)
|
||||||
|
autorefresh_session = Session(botocore_session=session)
|
||||||
|
|
||||||
|
return autorefresh_session
|
Loading…
Add table
Add a link
Reference in a new issue