From a33cafc2fe7f5232875bd45fcb746495258aa16c Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Thu, 10 Oct 2024 10:19:06 -0700 Subject: [PATCH] Update bedrock --- .../adapters/inference/bedrock/bedrock.py | 100 ++++++++++-------- .../adapters/safety/bedrock/bedrock.py | 23 ++-- 2 files changed, 63 insertions(+), 60 deletions(-) diff --git a/llama_stack/providers/adapters/inference/bedrock/bedrock.py b/llama_stack/providers/adapters/inference/bedrock/bedrock.py index 7f51894bc..22f87ef6b 100644 --- a/llama_stack/providers/adapters/inference/bedrock/bedrock.py +++ b/llama_stack/providers/adapters/inference/bedrock/bedrock.py @@ -26,57 +26,16 @@ BEDROCK_SUPPORTED_MODELS = { } +# NOTE: this is not quite tested after the recent refactors class BedrockInferenceAdapter(ModelRegistryHelper, Inference): - - @staticmethod - def _create_bedrock_client(config: BedrockConfig) -> BaseClient: - retries_config = { - k: v - for k, v in dict( - total_max_attempts=config.total_max_attempts, - mode=config.retry_mode, - ).items() - if v is not None - } - - config_args = { - k: v - for k, v in dict( - region_name=config.region_name, - retries=retries_config if retries_config else None, - connect_timeout=config.connect_timeout, - read_timeout=config.read_timeout, - ).items() - if v is not None - } - - boto3_config = Config(**config_args) - - session_args = { - k: v - for k, v in dict( - aws_access_key_id=config.aws_access_key_id, - aws_secret_access_key=config.aws_secret_access_key, - aws_session_token=config.aws_session_token, - region_name=config.region_name, - profile_name=config.profile_name, - ).items() - if v is not None - } - - boto3_session = boto3.session.Session(**session_args) - - return boto3_session.client("bedrock-runtime", config=boto3_config) - def __init__(self, config: BedrockConfig) -> None: ModelRegistryHelper.__init__( self, stack_to_provider_models_map=BEDROCK_SUPPORTED_MODELS ) self._config = config - self._client = BedrockInferenceAdapter._create_bedrock_client(config) - tokenizer = Tokenizer.get_instance() - self.formatter = ChatFormat(tokenizer) + self._client = _create_bedrock_client(config) + self.formatter = ChatFormat(Tokenizer.get_instance()) @property def client(self) -> BaseClient: @@ -88,7 +47,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference): async def shutdown(self) -> None: self.client.close() - async def completion( + def completion( self, model: str, content: InterleavedTextMedia, @@ -258,7 +217,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference): @staticmethod def _tool_parameters_to_input_schema( - tool_parameters: Optional[Dict[str, ToolParamDefinition]] + tool_parameters: Optional[Dict[str, ToolParamDefinition]], ) -> Dict: input_schema = {"type": "object"} if not tool_parameters: @@ -324,7 +283,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference): ) return tool_config - async def chat_completion( + def chat_completion( self, model: str, messages: List[Message], @@ -443,3 +402,50 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference): else: # Ignored pass + + async def embeddings( + self, + model: str, + contents: List[InterleavedTextMedia], + ) -> EmbeddingsResponse: + raise NotImplementedError() + + +def _create_bedrock_client(config: BedrockConfig) -> BaseClient: + retries_config = { + k: v + for k, v in dict( + total_max_attempts=config.total_max_attempts, + mode=config.retry_mode, + ).items() + if v is not None + } + + config_args = { + k: v + for k, v in dict( + region_name=config.region_name, + retries=retries_config if retries_config else None, + connect_timeout=config.connect_timeout, + read_timeout=config.read_timeout, + ).items() + if v is not None + } + + boto3_config = Config(**config_args) + + session_args = { + k: v + for k, v in dict( + aws_access_key_id=config.aws_access_key_id, + aws_secret_access_key=config.aws_secret_access_key, + aws_session_token=config.aws_session_token, + region_name=config.region_name, + profile_name=config.profile_name, + ).items() + if v is not None + } + + boto3_session = boto3.session.Session(**session_args) + + return boto3_session.client("bedrock-runtime", config=boto3_config) diff --git a/llama_stack/providers/adapters/safety/bedrock/bedrock.py b/llama_stack/providers/adapters/safety/bedrock/bedrock.py index 7fbac2e4b..3203e36f4 100644 --- a/llama_stack/providers/adapters/safety/bedrock/bedrock.py +++ b/llama_stack/providers/adapters/safety/bedrock/bedrock.py @@ -13,6 +13,7 @@ import boto3 from llama_stack.apis.safety import * # noqa from llama_models.llama3.api.datatypes import * # noqa: F403 +from llama_stack.providers.datatypes import ShieldsProtocolPrivate from .config import BedrockSafetyConfig @@ -25,7 +26,7 @@ BEDROCK_SUPPORTED_SHIELDS = [ ] -class BedrockSafetyAdapter(Safety): +class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate): def __init__(self, config: BedrockSafetyConfig) -> None: if not config.aws_profile: raise ValueError(f"Missing boto_client aws_profile in model info::{config}") @@ -45,19 +46,15 @@ class BedrockSafetyAdapter(Safety): pass async def register_shield(self, shield: ShieldDef) -> None: - if shield.type not in BEDROCK_SUPPORTED_SHIELDS: - raise ValueError(f"Unsupported safety shield type: {shield.type}") + raise ValueError("Registering dynamic shields is not supported") - shield_params = shield.params - if "guardrailIdentifier" not in shield_params: - raise ValueError( - "Error running request for BedrockGaurdrails:Missing GuardrailID in request" - ) - - if "guardrailVersion" not in shield_params: - raise ValueError( - "Error running request for BedrockGaurdrails:Missing guardrailVersion in request" - ) + async def list_shields(self) -> List[ShieldDef]: + raise NotImplementedError( + """ + `list_shields` not implemented; this should read all guardrails from + bedrock and populate guardrailId and guardrailVersion in the ShieldDef. + """ + ) async def run_shield( self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None