mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
add bedrock distribution code (#358)
* add bedrock distribution code * fix linter error * add bedrock shields support * linter fixes * working bedrock safety * change to return only one violation * remove env var reading * refereshable boto credentials * remove env vars * address raghu's feedback * fix session_ttl passing --------- Co-authored-by: Dinesh Yeduguru <dineshyv@fb.com>
This commit is contained in:
parent
6ebd553da5
commit
093c9f1987
16 changed files with 429 additions and 135 deletions
15
distributions/bedrock/compose.yaml
Normal file
15
distributions/bedrock/compose.yaml
Normal file
|
@ -0,0 +1,15 @@
|
|||
services:
|
||||
llamastack:
|
||||
image: distribution-bedrock
|
||||
volumes:
|
||||
- ~/.llama:/root/.llama
|
||||
- ./run.yaml:/root/llamastack-run-bedrock.yaml
|
||||
ports:
|
||||
- "5000:5000"
|
||||
entrypoint: bash -c "python -m llama_stack.distribution.server.server --yaml_config /root/llamastack-run-bedrock.yaml"
|
||||
deploy:
|
||||
restart_policy:
|
||||
condition: on-failure
|
||||
delay: 3s
|
||||
max_attempts: 5
|
||||
window: 60s
|
46
distributions/bedrock/run.yaml
Normal file
46
distributions/bedrock/run.yaml
Normal file
|
@ -0,0 +1,46 @@
|
|||
version: '2'
|
||||
built_at: '2024-11-01T17:40:45.325529'
|
||||
image_name: local
|
||||
name: bedrock
|
||||
docker_image: null
|
||||
conda_env: local
|
||||
apis:
|
||||
- shields
|
||||
- agents
|
||||
- models
|
||||
- memory
|
||||
- memory_banks
|
||||
- inference
|
||||
- safety
|
||||
providers:
|
||||
inference:
|
||||
- provider_id: bedrock0
|
||||
provider_type: remote::bedrock
|
||||
config:
|
||||
aws_access_key_id: <AWS_ACCESS_KEY_ID>
|
||||
aws_secret_access_key: <AWS_SECRET_ACCESS_KEY>
|
||||
aws_session_token: <AWS_SESSION_TOKEN>
|
||||
region_name: <AWS_REGION>
|
||||
memory:
|
||||
- provider_id: meta0
|
||||
provider_type: meta-reference
|
||||
config: {}
|
||||
safety:
|
||||
- provider_id: bedrock0
|
||||
provider_type: remote::bedrock
|
||||
config:
|
||||
aws_access_key_id: <AWS_ACCESS_KEY_ID>
|
||||
aws_secret_access_key: <AWS_SECRET_ACCESS_KEY>
|
||||
aws_session_token: <AWS_SESSION_TOKEN>
|
||||
region_name: <AWS_REGION>
|
||||
agents:
|
||||
- provider_id: meta0
|
||||
provider_type: meta-reference
|
||||
config:
|
||||
persistence_store:
|
||||
type: sqlite
|
||||
db_path: ~/.llama/runtime/kvstore.db
|
||||
telemetry:
|
||||
- provider_id: meta0
|
||||
provider_type: meta-reference
|
||||
config: {}
|
|
@ -0,0 +1,58 @@
|
|||
# Bedrock Distribution
|
||||
|
||||
### Connect to a Llama Stack Bedrock Endpoint
|
||||
- You may connect to Amazon Bedrock APIs for running LLM inference
|
||||
|
||||
The `llamastack/distribution-bedrock` distribution consists of the following provider configurations.
|
||||
|
||||
|
||||
| **API** | **Inference** | **Agents** | **Memory** | **Safety** | **Telemetry** |
|
||||
|----------------- |--------------- |---------------- |---------------- |---------------- |---------------- |
|
||||
| **Provider(s)** | remote::bedrock | meta-reference | meta-reference | remote::bedrock | meta-reference |
|
||||
|
||||
|
||||
### Docker: Start the Distribution (Single Node CPU)
|
||||
|
||||
> [!NOTE]
|
||||
> This assumes you have valid AWS credentials configured with access to Amazon Bedrock.
|
||||
|
||||
```
|
||||
$ cd distributions/bedrock && docker compose up
|
||||
```
|
||||
|
||||
Make sure in your `run.yaml` file, your inference provider is pointing to the correct AWS configuration. E.g.
|
||||
```
|
||||
inference:
|
||||
- provider_id: bedrock0
|
||||
provider_type: remote::bedrock
|
||||
config:
|
||||
aws_access_key_id: <AWS_ACCESS_KEY_ID>
|
||||
aws_secret_access_key: <AWS_SECRET_ACCESS_KEY>
|
||||
aws_session_token: <AWS_SESSION_TOKEN>
|
||||
region_name: <AWS_REGION>
|
||||
```
|
||||
|
||||
### Conda llama stack run (Single Node CPU)
|
||||
|
||||
```bash
|
||||
llama stack build --template bedrock --image-type conda
|
||||
# -- modify run.yaml with valid AWS credentials
|
||||
llama stack run ./run.yaml
|
||||
```
|
||||
|
||||
### (Optional) Update Model Serving Configuration
|
||||
|
||||
Use `llama-stack-client models list` to check the available models served by Amazon Bedrock.
|
||||
|
||||
```
|
||||
$ llama-stack-client models list
|
||||
+------------------------------+------------------------------+---------------+------------+
|
||||
| identifier | llama_model | provider_id | metadata |
|
||||
+==============================+==============================+===============+============+
|
||||
| Llama3.1-8B-Instruct | meta.llama3-1-8b-instruct-v1:0 | bedrock0 | {} |
|
||||
+------------------------------+------------------------------+---------------+------------+
|
||||
| Llama3.1-70B-Instruct | meta.llama3-1-70b-instruct-v1:0 | bedrock0 | {} |
|
||||
+------------------------------+------------------------------+---------------+------------+
|
||||
| Llama3.1-405B-Instruct | meta.llama3-1-405b-instruct-v1:0 | bedrock0 | {} |
|
||||
+------------------------------+------------------------------+---------------+------------+
|
||||
```
|
|
@ -39,7 +39,7 @@ class RunShieldResponse(BaseModel):
|
|||
|
||||
|
||||
class ShieldStore(Protocol):
|
||||
def get_shield(self, identifier: str) -> ShieldDef: ...
|
||||
async def get_shield(self, identifier: str) -> ShieldDef: ...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
|
@ -48,5 +48,5 @@ class Safety(Protocol):
|
|||
|
||||
@webmethod(route="/safety/run_shield")
|
||||
async def run_shield(
|
||||
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
|
||||
self, identifier: str, messages: List[Message], params: Dict[str, Any] = None
|
||||
) -> RunShieldResponse: ...
|
||||
|
|
|
@ -46,7 +46,7 @@ class Shields(Protocol):
|
|||
async def list_shields(self) -> List[ShieldDefWithProvider]: ...
|
||||
|
||||
@webmethod(route="/shields/get", method="GET")
|
||||
async def get_shield(self, shield_type: str) -> Optional[ShieldDefWithProvider]: ...
|
||||
async def get_shield(self, identifier: str) -> Optional[ShieldDefWithProvider]: ...
|
||||
|
||||
@webmethod(route="/shields/register", method="POST")
|
||||
async def register_shield(self, shield: ShieldDefWithProvider) -> None: ...
|
||||
|
|
|
@ -154,12 +154,12 @@ class SafetyRouter(Safety):
|
|||
|
||||
async def run_shield(
|
||||
self,
|
||||
shield_type: str,
|
||||
identifier: str,
|
||||
messages: List[Message],
|
||||
params: Dict[str, Any] = None,
|
||||
) -> RunShieldResponse:
|
||||
return await self.routing_table.get_provider_impl(shield_type).run_shield(
|
||||
shield_type=shield_type,
|
||||
return await self.routing_table.get_provider_impl(identifier).run_shield(
|
||||
identifier=identifier,
|
||||
messages=messages,
|
||||
params=params,
|
||||
)
|
||||
|
|
|
@ -204,8 +204,8 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
|||
async def list_shields(self) -> List[ShieldDef]:
|
||||
return await self.get_all_with_type("shield")
|
||||
|
||||
async def get_shield(self, shield_type: str) -> Optional[ShieldDefWithProvider]:
|
||||
return await self.get_object_by_identifier(shield_type)
|
||||
async def get_shield(self, identifier: str) -> Optional[ShieldDefWithProvider]:
|
||||
return await self.get_object_by_identifier(identifier)
|
||||
|
||||
async def register_shield(self, shield: ShieldDefWithProvider) -> None:
|
||||
await self.register_object(shield)
|
||||
|
|
|
@ -6,9 +6,7 @@
|
|||
|
||||
from typing import * # noqa: F403
|
||||
|
||||
import boto3
|
||||
from botocore.client import BaseClient
|
||||
from botocore.config import Config
|
||||
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
|
@ -16,7 +14,9 @@ from llama_models.llama3.api.tokenizer import Tokenizer
|
|||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
|
||||
from llama_stack.providers.adapters.inference.bedrock.config import BedrockConfig
|
||||
from llama_stack.providers.utils.bedrock.client import create_bedrock_client
|
||||
|
||||
|
||||
BEDROCK_SUPPORTED_MODELS = {
|
||||
|
@ -34,7 +34,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
)
|
||||
self._config = config
|
||||
|
||||
self._client = _create_bedrock_client(config)
|
||||
self._client = create_bedrock_client(config)
|
||||
self.formatter = ChatFormat(Tokenizer.get_instance())
|
||||
|
||||
@property
|
||||
|
@ -437,43 +437,3 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
contents: List[InterleavedTextMedia],
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
def _create_bedrock_client(config: BedrockConfig) -> BaseClient:
|
||||
retries_config = {
|
||||
k: v
|
||||
for k, v in dict(
|
||||
total_max_attempts=config.total_max_attempts,
|
||||
mode=config.retry_mode,
|
||||
).items()
|
||||
if v is not None
|
||||
}
|
||||
|
||||
config_args = {
|
||||
k: v
|
||||
for k, v in dict(
|
||||
region_name=config.region_name,
|
||||
retries=retries_config if retries_config else None,
|
||||
connect_timeout=config.connect_timeout,
|
||||
read_timeout=config.read_timeout,
|
||||
).items()
|
||||
if v is not None
|
||||
}
|
||||
|
||||
boto3_config = Config(**config_args)
|
||||
|
||||
session_args = {
|
||||
k: v
|
||||
for k, v in dict(
|
||||
aws_access_key_id=config.aws_access_key_id,
|
||||
aws_secret_access_key=config.aws_secret_access_key,
|
||||
aws_session_token=config.aws_session_token,
|
||||
region_name=config.region_name,
|
||||
profile_name=config.profile_name,
|
||||
).items()
|
||||
if v is not None
|
||||
}
|
||||
|
||||
boto3_session = boto3.session.Session(**session_args)
|
||||
|
||||
return boto3_session.client("bedrock-runtime", config=boto3_config)
|
||||
|
|
|
@ -3,53 +3,12 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import * # noqa: F403
|
||||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.utils.bedrock.config import BedrockBaseConfig
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class BedrockConfig(BaseModel):
|
||||
aws_access_key_id: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID",
|
||||
)
|
||||
aws_secret_access_key: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The AWS secret access key to use. Default use environment variable: AWS_SECRET_ACCESS_KEY",
|
||||
)
|
||||
aws_session_token: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The AWS session token to use. Default use environment variable: AWS_SESSION_TOKEN",
|
||||
)
|
||||
region_name: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The default AWS Region to use, for example, us-west-1 or us-west-2."
|
||||
"Default use environment variable: AWS_DEFAULT_REGION",
|
||||
)
|
||||
profile_name: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The profile name that contains credentials to use."
|
||||
"Default use environment variable: AWS_PROFILE",
|
||||
)
|
||||
total_max_attempts: Optional[int] = Field(
|
||||
default=None,
|
||||
description="An integer representing the maximum number of attempts that will be made for a single request, "
|
||||
"including the initial attempt. Default use environment variable: AWS_MAX_ATTEMPTS",
|
||||
)
|
||||
retry_mode: Optional[str] = Field(
|
||||
default=None,
|
||||
description="A string representing the type of retries Boto3 will perform."
|
||||
"Default use environment variable: AWS_RETRY_MODE",
|
||||
)
|
||||
connect_timeout: Optional[float] = Field(
|
||||
default=60,
|
||||
description="The time in seconds till a timeout exception is thrown when attempting to make a connection. "
|
||||
"The default is 60 seconds.",
|
||||
)
|
||||
read_timeout: Optional[float] = Field(
|
||||
default=60,
|
||||
description="The time in seconds till a timeout exception is thrown when attempting to read from a connection."
|
||||
"The default is 60 seconds.",
|
||||
)
|
||||
class BedrockConfig(BedrockBaseConfig):
|
||||
pass
|
||||
|
|
|
@ -9,11 +9,10 @@ import logging
|
|||
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import boto3
|
||||
|
||||
from llama_stack.apis.safety import * # noqa
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
|
||||
from llama_stack.providers.utils.bedrock.client import create_bedrock_client
|
||||
|
||||
from .config import BedrockSafetyConfig
|
||||
|
||||
|
@ -28,17 +27,13 @@ BEDROCK_SUPPORTED_SHIELDS = [
|
|||
|
||||
class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
|
||||
def __init__(self, config: BedrockSafetyConfig) -> None:
|
||||
if not config.aws_profile:
|
||||
raise ValueError(f"Missing boto_client aws_profile in model info::{config}")
|
||||
self.config = config
|
||||
self.registered_shields = []
|
||||
|
||||
async def initialize(self) -> None:
|
||||
try:
|
||||
print(f"initializing with profile --- > {self.config}")
|
||||
self.boto_client = boto3.Session(
|
||||
profile_name=self.config.aws_profile
|
||||
).client("bedrock-runtime")
|
||||
self.bedrock_runtime_client = create_bedrock_client(self.config)
|
||||
self.bedrock_client = create_bedrock_client(self.config, "bedrock")
|
||||
except Exception as e:
|
||||
raise RuntimeError("Error initializing BedrockSafetyAdapter") from e
|
||||
|
||||
|
@ -49,19 +44,28 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
|
|||
raise ValueError("Registering dynamic shields is not supported")
|
||||
|
||||
async def list_shields(self) -> List[ShieldDef]:
|
||||
raise NotImplementedError(
|
||||
"""
|
||||
`list_shields` not implemented; this should read all guardrails from
|
||||
bedrock and populate guardrailId and guardrailVersion in the ShieldDef.
|
||||
"""
|
||||
response = self.bedrock_client.list_guardrails()
|
||||
shields = []
|
||||
for guardrail in response["guardrails"]:
|
||||
# populate the shield def with the guardrail id and version
|
||||
shield_def = ShieldDef(
|
||||
identifier=guardrail["id"],
|
||||
shield_type=ShieldType.generic_content_shield.value,
|
||||
params={
|
||||
"guardrailIdentifier": guardrail["id"],
|
||||
"guardrailVersion": guardrail["version"],
|
||||
},
|
||||
)
|
||||
self.registered_shields.append(shield_def)
|
||||
shields.append(shield_def)
|
||||
return shields
|
||||
|
||||
async def run_shield(
|
||||
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
|
||||
self, identifier: str, messages: List[Message], params: Dict[str, Any] = None
|
||||
) -> RunShieldResponse:
|
||||
shield_def = await self.shield_store.get_shield(shield_type)
|
||||
shield_def = await self.shield_store.get_shield(identifier)
|
||||
if not shield_def:
|
||||
raise ValueError(f"Unknown shield {shield_type}")
|
||||
raise ValueError(f"Unknown shield {identifier}")
|
||||
|
||||
"""This is the implementation for the bedrock guardrails. The input to the guardrails is to be of this format
|
||||
```content = [
|
||||
|
@ -88,7 +92,7 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
|
|||
f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:"
|
||||
)
|
||||
|
||||
response = self.boto_client.apply_guardrail(
|
||||
response = self.bedrock_runtime_client.apply_guardrail(
|
||||
guardrailIdentifier=shield_params["guardrailIdentifier"],
|
||||
guardrailVersion=shield_params["guardrailVersion"],
|
||||
source="OUTPUT", # or 'INPUT' depending on your use case
|
||||
|
@ -104,10 +108,12 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
|
|||
# guardrails returns a list - however for this implementation we will leverage the last values
|
||||
metadata = dict(assessment)
|
||||
|
||||
return SafetyViolation(
|
||||
return RunShieldResponse(
|
||||
violation=SafetyViolation(
|
||||
user_message=user_message,
|
||||
violation_level=ViolationLevel.ERROR,
|
||||
metadata=metadata,
|
||||
)
|
||||
)
|
||||
|
||||
return None
|
||||
return RunShieldResponse()
|
||||
|
|
|
@ -4,13 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
|
||||
from llama_stack.providers.utils.bedrock.config import BedrockBaseConfig
|
||||
|
||||
|
||||
class BedrockSafetyConfig(BaseModel):
|
||||
"""Configuration information for a guardrail that you want to use in the request."""
|
||||
|
||||
aws_profile: str = Field(
|
||||
default="default",
|
||||
description="The profile on the machine having valid aws credentials. This will ensure separation of creation to invocation",
|
||||
)
|
||||
@json_schema_type
|
||||
class BedrockSafetyConfig(BedrockBaseConfig):
|
||||
pass
|
||||
|
|
|
@ -43,11 +43,11 @@ class TogetherSafetyImpl(Safety, NeedsRequestProviderData, ShieldsProtocolPrivat
|
|||
]
|
||||
|
||||
async def run_shield(
|
||||
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
|
||||
self, identifier: str, messages: List[Message], params: Dict[str, Any] = None
|
||||
) -> RunShieldResponse:
|
||||
shield_def = await self.shield_store.get_shield(shield_type)
|
||||
shield_def = await self.shield_store.get_shield(identifier)
|
||||
if not shield_def:
|
||||
raise ValueError(f"Unknown shield {shield_type}")
|
||||
raise ValueError(f"Unknown shield {identifier}")
|
||||
|
||||
model = shield_def.params.get("model", "llama_guard")
|
||||
if model not in TOGETHER_SHIELD_MODEL_MAP:
|
||||
|
|
|
@ -32,18 +32,18 @@ class ShieldRunnerMixin:
|
|||
self.output_shields = output_shields
|
||||
|
||||
async def run_multiple_shields(
|
||||
self, messages: List[Message], shield_types: List[str]
|
||||
self, messages: List[Message], identifiers: List[str]
|
||||
) -> None:
|
||||
responses = await asyncio.gather(
|
||||
*[
|
||||
self.safety_api.run_shield(
|
||||
shield_type=shield_type,
|
||||
identifier=identifier,
|
||||
messages=messages,
|
||||
)
|
||||
for shield_type in shield_types
|
||||
for identifier in identifiers
|
||||
]
|
||||
)
|
||||
for shield_type, response in zip(shield_types, responses):
|
||||
for identifier, response in zip(identifiers, responses):
|
||||
if not response.violation:
|
||||
continue
|
||||
|
||||
|
@ -52,6 +52,6 @@ class ShieldRunnerMixin:
|
|||
raise SafetyException(violation)
|
||||
elif violation.violation_level == ViolationLevel.WARN:
|
||||
cprint(
|
||||
f"[Warn]{shield_type} raised a warning",
|
||||
f"[Warn]{identifier} raised a warning",
|
||||
color="red",
|
||||
)
|
||||
|
|
76
llama_stack/providers/utils/bedrock/client.py
Normal file
76
llama_stack/providers/utils/bedrock/client.py
Normal file
|
@ -0,0 +1,76 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
import boto3
|
||||
from botocore.client import BaseClient
|
||||
from botocore.config import Config
|
||||
|
||||
from llama_stack.providers.utils.bedrock.config import BedrockBaseConfig
|
||||
from llama_stack.providers.utils.bedrock.refreshable_boto_session import (
|
||||
RefreshableBotoSession,
|
||||
)
|
||||
|
||||
|
||||
def create_bedrock_client(
|
||||
config: BedrockBaseConfig, service_name: str = "bedrock-runtime"
|
||||
) -> BaseClient:
|
||||
"""Creates a boto3 client for Bedrock services with the given configuration.
|
||||
|
||||
Args:
|
||||
config: The Bedrock configuration containing AWS credentials and settings
|
||||
service_name: The AWS service name to create client for (default: "bedrock-runtime")
|
||||
|
||||
Returns:
|
||||
A configured boto3 client
|
||||
"""
|
||||
if config.aws_access_key_id and config.aws_secret_access_key:
|
||||
retries_config = {
|
||||
k: v
|
||||
for k, v in dict(
|
||||
total_max_attempts=config.total_max_attempts,
|
||||
mode=config.retry_mode,
|
||||
).items()
|
||||
if v is not None
|
||||
}
|
||||
|
||||
config_args = {
|
||||
k: v
|
||||
for k, v in dict(
|
||||
region_name=config.region_name,
|
||||
retries=retries_config if retries_config else None,
|
||||
connect_timeout=config.connect_timeout,
|
||||
read_timeout=config.read_timeout,
|
||||
).items()
|
||||
if v is not None
|
||||
}
|
||||
|
||||
boto3_config = Config(**config_args)
|
||||
|
||||
session_args = {
|
||||
"aws_access_key_id": config.aws_access_key_id,
|
||||
"aws_secret_access_key": config.aws_secret_access_key,
|
||||
"aws_session_token": config.aws_session_token,
|
||||
"region_name": config.region_name,
|
||||
"profile_name": config.profile_name,
|
||||
"session_ttl": config.session_ttl,
|
||||
}
|
||||
|
||||
# Remove None values
|
||||
session_args = {k: v for k, v in session_args.items() if v is not None}
|
||||
|
||||
boto3_session = boto3.session.Session(**session_args)
|
||||
return boto3_session.client(service_name, config=boto3_config)
|
||||
else:
|
||||
return (
|
||||
RefreshableBotoSession(
|
||||
region_name=config.region_name,
|
||||
profile_name=config.profile_name,
|
||||
session_ttl=config.session_ttl,
|
||||
)
|
||||
.refreshable_session()
|
||||
.client(service_name)
|
||||
)
|
59
llama_stack/providers/utils/bedrock/config.py
Normal file
59
llama_stack/providers/utils/bedrock/config.py
Normal file
|
@ -0,0 +1,59 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Optional
|
||||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class BedrockBaseConfig(BaseModel):
|
||||
aws_access_key_id: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID",
|
||||
)
|
||||
aws_secret_access_key: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The AWS secret access key to use. Default use environment variable: AWS_SECRET_ACCESS_KEY",
|
||||
)
|
||||
aws_session_token: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The AWS session token to use. Default use environment variable: AWS_SESSION_TOKEN",
|
||||
)
|
||||
region_name: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The default AWS Region to use, for example, us-west-1 or us-west-2."
|
||||
"Default use environment variable: AWS_DEFAULT_REGION",
|
||||
)
|
||||
profile_name: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The profile name that contains credentials to use."
|
||||
"Default use environment variable: AWS_PROFILE",
|
||||
)
|
||||
total_max_attempts: Optional[int] = Field(
|
||||
default=None,
|
||||
description="An integer representing the maximum number of attempts that will be made for a single request, "
|
||||
"including the initial attempt. Default use environment variable: AWS_MAX_ATTEMPTS",
|
||||
)
|
||||
retry_mode: Optional[str] = Field(
|
||||
default=None,
|
||||
description="A string representing the type of retries Boto3 will perform."
|
||||
"Default use environment variable: AWS_RETRY_MODE",
|
||||
)
|
||||
connect_timeout: Optional[float] = Field(
|
||||
default=60,
|
||||
description="The time in seconds till a timeout exception is thrown when attempting to make a connection. "
|
||||
"The default is 60 seconds.",
|
||||
)
|
||||
read_timeout: Optional[float] = Field(
|
||||
default=60,
|
||||
description="The time in seconds till a timeout exception is thrown when attempting to read from a connection."
|
||||
"The default is 60 seconds.",
|
||||
)
|
||||
session_ttl: Optional[int] = Field(
|
||||
default=3600,
|
||||
description="The time in seconds till a session expires. The default is 3600 seconds (1 hour).",
|
||||
)
|
116
llama_stack/providers/utils/bedrock/refreshable_boto_session.py
Normal file
116
llama_stack/providers/utils/bedrock/refreshable_boto_session.py
Normal file
|
@ -0,0 +1,116 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import datetime
|
||||
from time import time
|
||||
from uuid import uuid4
|
||||
|
||||
from boto3 import Session
|
||||
from botocore.credentials import RefreshableCredentials
|
||||
from botocore.session import get_session
|
||||
|
||||
|
||||
class RefreshableBotoSession:
|
||||
"""
|
||||
Boto Helper class which lets us create a refreshable session so that we can cache the client or resource.
|
||||
|
||||
Usage
|
||||
-----
|
||||
session = RefreshableBotoSession().refreshable_session()
|
||||
|
||||
client = session.client("s3") # we now can cache this client object without worrying about expiring credentials
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
region_name: str = None,
|
||||
profile_name: str = None,
|
||||
sts_arn: str = None,
|
||||
session_name: str = None,
|
||||
session_ttl: int = 30000,
|
||||
):
|
||||
"""
|
||||
Initialize `RefreshableBotoSession`
|
||||
|
||||
Parameters
|
||||
----------
|
||||
region_name : str (optional)
|
||||
Default region when creating a new connection.
|
||||
|
||||
profile_name : str (optional)
|
||||
The name of a profile to use.
|
||||
|
||||
sts_arn : str (optional)
|
||||
The role arn to sts before creating a session.
|
||||
|
||||
session_name : str (optional)
|
||||
An identifier for the assumed role session. (required when `sts_arn` is given)
|
||||
|
||||
session_ttl : int (optional)
|
||||
An integer number to set the TTL for each session. Beyond this session, it will renew the token.
|
||||
50 minutes by default which is before the default role expiration of 1 hour
|
||||
"""
|
||||
|
||||
self.region_name = region_name
|
||||
self.profile_name = profile_name
|
||||
self.sts_arn = sts_arn
|
||||
self.session_name = session_name or uuid4().hex
|
||||
self.session_ttl = session_ttl
|
||||
|
||||
def __get_session_credentials(self):
|
||||
"""
|
||||
Get session credentials
|
||||
"""
|
||||
session = Session(region_name=self.region_name, profile_name=self.profile_name)
|
||||
|
||||
# if sts_arn is given, get credential by assuming the given role
|
||||
if self.sts_arn:
|
||||
sts_client = session.client(
|
||||
service_name="sts", region_name=self.region_name
|
||||
)
|
||||
response = sts_client.assume_role(
|
||||
RoleArn=self.sts_arn,
|
||||
RoleSessionName=self.session_name,
|
||||
DurationSeconds=self.session_ttl,
|
||||
).get("Credentials")
|
||||
|
||||
credentials = {
|
||||
"access_key": response.get("AccessKeyId"),
|
||||
"secret_key": response.get("SecretAccessKey"),
|
||||
"token": response.get("SessionToken"),
|
||||
"expiry_time": response.get("Expiration").isoformat(),
|
||||
}
|
||||
else:
|
||||
session_credentials = session.get_credentials().get_frozen_credentials()
|
||||
credentials = {
|
||||
"access_key": session_credentials.access_key,
|
||||
"secret_key": session_credentials.secret_key,
|
||||
"token": session_credentials.token,
|
||||
"expiry_time": datetime.datetime.fromtimestamp(
|
||||
time() + self.session_ttl, datetime.timezone.utc
|
||||
).isoformat(),
|
||||
}
|
||||
|
||||
return credentials
|
||||
|
||||
def refreshable_session(self) -> Session:
|
||||
"""
|
||||
Get refreshable boto3 session.
|
||||
"""
|
||||
# Get refreshable credentials
|
||||
refreshable_credentials = RefreshableCredentials.create_from_metadata(
|
||||
metadata=self.__get_session_credentials(),
|
||||
refresh_using=self.__get_session_credentials,
|
||||
method="sts-assume-role",
|
||||
)
|
||||
|
||||
# attach refreshable credentials current session
|
||||
session = get_session()
|
||||
session._credentials = refreshable_credentials
|
||||
session.set_config_variable("region", self.region_name)
|
||||
autorefresh_session = Session(botocore_session=session)
|
||||
|
||||
return autorefresh_session
|
Loading…
Add table
Add a link
Reference in a new issue