add bedrock distribution code (#358)

* add bedrock distribution code

* fix linter error

* add bedrock shields support

* linter fixes

* working bedrock safety

* change to return only one violation

* remove env var reading

* refereshable boto credentials

* remove env vars

* address raghu's feedback

* fix session_ttl passing

---------

Co-authored-by: Dinesh Yeduguru <dineshyv@fb.com>
This commit is contained in:
Dinesh Yeduguru 2024-11-06 14:39:11 -08:00 committed by GitHub
parent 6ebd553da5
commit 093c9f1987
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 429 additions and 135 deletions

View file

@ -6,9 +6,7 @@
from typing import * # noqa: F403
import boto3
from botocore.client import BaseClient
from botocore.config import Config
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer
@ -16,7 +14,9 @@ from llama_models.llama3.api.tokenizer import Tokenizer
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.adapters.inference.bedrock.config import BedrockConfig
from llama_stack.providers.utils.bedrock.client import create_bedrock_client
BEDROCK_SUPPORTED_MODELS = {
@ -34,7 +34,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
)
self._config = config
self._client = _create_bedrock_client(config)
self._client = create_bedrock_client(config)
self.formatter = ChatFormat(Tokenizer.get_instance())
@property
@ -437,43 +437,3 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse:
raise NotImplementedError()
def _create_bedrock_client(config: BedrockConfig) -> BaseClient:
retries_config = {
k: v
for k, v in dict(
total_max_attempts=config.total_max_attempts,
mode=config.retry_mode,
).items()
if v is not None
}
config_args = {
k: v
for k, v in dict(
region_name=config.region_name,
retries=retries_config if retries_config else None,
connect_timeout=config.connect_timeout,
read_timeout=config.read_timeout,
).items()
if v is not None
}
boto3_config = Config(**config_args)
session_args = {
k: v
for k, v in dict(
aws_access_key_id=config.aws_access_key_id,
aws_secret_access_key=config.aws_secret_access_key,
aws_session_token=config.aws_session_token,
region_name=config.region_name,
profile_name=config.profile_name,
).items()
if v is not None
}
boto3_session = boto3.session.Session(**session_args)
return boto3_session.client("bedrock-runtime", config=boto3_config)

View file

@ -3,53 +3,12 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import * # noqa: F403
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
from llama_stack.providers.utils.bedrock.config import BedrockBaseConfig
@json_schema_type
class BedrockConfig(BaseModel):
aws_access_key_id: Optional[str] = Field(
default=None,
description="The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID",
)
aws_secret_access_key: Optional[str] = Field(
default=None,
description="The AWS secret access key to use. Default use environment variable: AWS_SECRET_ACCESS_KEY",
)
aws_session_token: Optional[str] = Field(
default=None,
description="The AWS session token to use. Default use environment variable: AWS_SESSION_TOKEN",
)
region_name: Optional[str] = Field(
default=None,
description="The default AWS Region to use, for example, us-west-1 or us-west-2."
"Default use environment variable: AWS_DEFAULT_REGION",
)
profile_name: Optional[str] = Field(
default=None,
description="The profile name that contains credentials to use."
"Default use environment variable: AWS_PROFILE",
)
total_max_attempts: Optional[int] = Field(
default=None,
description="An integer representing the maximum number of attempts that will be made for a single request, "
"including the initial attempt. Default use environment variable: AWS_MAX_ATTEMPTS",
)
retry_mode: Optional[str] = Field(
default=None,
description="A string representing the type of retries Boto3 will perform."
"Default use environment variable: AWS_RETRY_MODE",
)
connect_timeout: Optional[float] = Field(
default=60,
description="The time in seconds till a timeout exception is thrown when attempting to make a connection. "
"The default is 60 seconds.",
)
read_timeout: Optional[float] = Field(
default=60,
description="The time in seconds till a timeout exception is thrown when attempting to read from a connection."
"The default is 60 seconds.",
)
class BedrockConfig(BedrockBaseConfig):
pass

View file

@ -9,11 +9,10 @@ import logging
from typing import Any, Dict, List
import boto3
from llama_stack.apis.safety import * # noqa
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from llama_stack.providers.utils.bedrock.client import create_bedrock_client
from .config import BedrockSafetyConfig
@ -28,17 +27,13 @@ BEDROCK_SUPPORTED_SHIELDS = [
class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
def __init__(self, config: BedrockSafetyConfig) -> None:
if not config.aws_profile:
raise ValueError(f"Missing boto_client aws_profile in model info::{config}")
self.config = config
self.registered_shields = []
async def initialize(self) -> None:
try:
print(f"initializing with profile --- > {self.config}")
self.boto_client = boto3.Session(
profile_name=self.config.aws_profile
).client("bedrock-runtime")
self.bedrock_runtime_client = create_bedrock_client(self.config)
self.bedrock_client = create_bedrock_client(self.config, "bedrock")
except Exception as e:
raise RuntimeError("Error initializing BedrockSafetyAdapter") from e
@ -49,19 +44,28 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
raise ValueError("Registering dynamic shields is not supported")
async def list_shields(self) -> List[ShieldDef]:
raise NotImplementedError(
"""
`list_shields` not implemented; this should read all guardrails from
bedrock and populate guardrailId and guardrailVersion in the ShieldDef.
"""
)
response = self.bedrock_client.list_guardrails()
shields = []
for guardrail in response["guardrails"]:
# populate the shield def with the guardrail id and version
shield_def = ShieldDef(
identifier=guardrail["id"],
shield_type=ShieldType.generic_content_shield.value,
params={
"guardrailIdentifier": guardrail["id"],
"guardrailVersion": guardrail["version"],
},
)
self.registered_shields.append(shield_def)
shields.append(shield_def)
return shields
async def run_shield(
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
self, identifier: str, messages: List[Message], params: Dict[str, Any] = None
) -> RunShieldResponse:
shield_def = await self.shield_store.get_shield(shield_type)
shield_def = await self.shield_store.get_shield(identifier)
if not shield_def:
raise ValueError(f"Unknown shield {shield_type}")
raise ValueError(f"Unknown shield {identifier}")
"""This is the implementation for the bedrock guardrails. The input to the guardrails is to be of this format
```content = [
@ -88,7 +92,7 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:"
)
response = self.boto_client.apply_guardrail(
response = self.bedrock_runtime_client.apply_guardrail(
guardrailIdentifier=shield_params["guardrailIdentifier"],
guardrailVersion=shield_params["guardrailVersion"],
source="OUTPUT", # or 'INPUT' depending on your use case
@ -104,10 +108,12 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
# guardrails returns a list - however for this implementation we will leverage the last values
metadata = dict(assessment)
return SafetyViolation(
user_message=user_message,
violation_level=ViolationLevel.ERROR,
metadata=metadata,
return RunShieldResponse(
violation=SafetyViolation(
user_message=user_message,
violation_level=ViolationLevel.ERROR,
metadata=metadata,
)
)
return None
return RunShieldResponse()

View file

@ -4,13 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel, Field
from llama_models.schema_utils import json_schema_type
from llama_stack.providers.utils.bedrock.config import BedrockBaseConfig
class BedrockSafetyConfig(BaseModel):
"""Configuration information for a guardrail that you want to use in the request."""
aws_profile: str = Field(
default="default",
description="The profile on the machine having valid aws credentials. This will ensure separation of creation to invocation",
)
@json_schema_type
class BedrockSafetyConfig(BedrockBaseConfig):
pass

View file

@ -43,11 +43,11 @@ class TogetherSafetyImpl(Safety, NeedsRequestProviderData, ShieldsProtocolPrivat
]
async def run_shield(
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
self, identifier: str, messages: List[Message], params: Dict[str, Any] = None
) -> RunShieldResponse:
shield_def = await self.shield_store.get_shield(shield_type)
shield_def = await self.shield_store.get_shield(identifier)
if not shield_def:
raise ValueError(f"Unknown shield {shield_type}")
raise ValueError(f"Unknown shield {identifier}")
model = shield_def.params.get("model", "llama_guard")
if model not in TOGETHER_SHIELD_MODEL_MAP: