Update bedrock

This commit is contained in:
Ashwin Bharambe 2024-10-10 10:19:06 -07:00
parent fe0dabe596
commit a33cafc2fe
2 changed files with 63 additions and 60 deletions

View file

@ -26,57 +26,16 @@ BEDROCK_SUPPORTED_MODELS = {
} }
# NOTE: this is not quite tested after the recent refactors
class BedrockInferenceAdapter(ModelRegistryHelper, Inference): class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
@staticmethod
def _create_bedrock_client(config: BedrockConfig) -> BaseClient:
    """Build a ``bedrock-runtime`` client from the adapter configuration.

    Retry and timeout settings are passed to ``Config``; credentials,
    region and profile are passed to the boto3 ``Session``. Any setting
    that is ``None`` on ``config`` is left out, so the callee's own
    default applies.

    Args:
        config: Source of the region, retry, timeout and credential values.

    Returns:
        The client returned by ``Session.client("bedrock-runtime", ...)``.
    """

    def without_none(**options):
        # Keep only the options that were actually set.
        return {
            name: value for name, value in options.items() if value is not None
        }

    retry_options = without_none(
        total_max_attempts=config.total_max_attempts,
        mode=config.retry_mode,
    )
    client_config = Config(
        **without_none(
            region_name=config.region_name,
            # An empty retry mapping is dropped rather than forwarded.
            retries=retry_options or None,
            connect_timeout=config.connect_timeout,
            read_timeout=config.read_timeout,
        )
    )
    session = boto3.session.Session(
        **without_none(
            aws_access_key_id=config.aws_access_key_id,
            aws_secret_access_key=config.aws_secret_access_key,
            aws_session_token=config.aws_session_token,
            region_name=config.region_name,
            profile_name=config.profile_name,
        )
    )
    return session.client("bedrock-runtime", config=client_config)
def __init__(self, config: BedrockConfig) -> None: def __init__(self, config: BedrockConfig) -> None:
ModelRegistryHelper.__init__( ModelRegistryHelper.__init__(
self, stack_to_provider_models_map=BEDROCK_SUPPORTED_MODELS self, stack_to_provider_models_map=BEDROCK_SUPPORTED_MODELS
) )
self._config = config self._config = config
self._client = BedrockInferenceAdapter._create_bedrock_client(config) self._client = _create_bedrock_client(config)
tokenizer = Tokenizer.get_instance() self.formatter = ChatFormat(Tokenizer.get_instance())
self.formatter = ChatFormat(tokenizer)
@property @property
def client(self) -> BaseClient: def client(self) -> BaseClient:
@ -88,7 +47,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
async def shutdown(self) -> None: async def shutdown(self) -> None:
self.client.close() self.client.close()
async def completion( def completion(
self, self,
model: str, model: str,
content: InterleavedTextMedia, content: InterleavedTextMedia,
@ -258,7 +217,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
@staticmethod @staticmethod
def _tool_parameters_to_input_schema( def _tool_parameters_to_input_schema(
tool_parameters: Optional[Dict[str, ToolParamDefinition]] tool_parameters: Optional[Dict[str, ToolParamDefinition]],
) -> Dict: ) -> Dict:
input_schema = {"type": "object"} input_schema = {"type": "object"}
if not tool_parameters: if not tool_parameters:
@ -324,7 +283,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
) )
return tool_config return tool_config
async def chat_completion( def chat_completion(
self, self,
model: str, model: str,
messages: List[Message], messages: List[Message],
@ -443,3 +402,50 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
else: else:
# Ignored # Ignored
pass pass
async def embeddings(
    self,
    model: str,
    contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse:
    """Placeholder for embeddings; not implemented by this adapter.

    Raises:
        NotImplementedError: Always, regardless of the arguments.
    """
    raise NotImplementedError
def _create_bedrock_client(config: BedrockConfig) -> BaseClient:
    """Create a ``bedrock-runtime`` client configured from ``config``.

    Three groups of settings are read from ``config``: retry options,
    ``Config`` options (region, retries, timeouts) and boto3 ``Session``
    options (credentials, region, profile). Entries whose value is
    ``None`` are filtered out of each group before it is forwarded.

    Args:
        config: Object exposing the region, retry, timeout and credential
            attributes read below.

    Returns:
        The client produced by ``Session.client("bedrock-runtime", ...)``.
    """
    retry_candidates = {
        "total_max_attempts": config.total_max_attempts,
        "mode": config.retry_mode,
    }
    retry_settings = {
        key: val for key, val in retry_candidates.items() if val is not None
    }

    config_candidates = {
        "region_name": config.region_name,
        # No retry settings at all means the ``retries`` key is omitted.
        "retries": retry_settings or None,
        "connect_timeout": config.connect_timeout,
        "read_timeout": config.read_timeout,
    }
    config_kwargs = {
        key: val for key, val in config_candidates.items() if val is not None
    }
    client_config = Config(**config_kwargs)

    session_candidates = {
        "aws_access_key_id": config.aws_access_key_id,
        "aws_secret_access_key": config.aws_secret_access_key,
        "aws_session_token": config.aws_session_token,
        "region_name": config.region_name,
        "profile_name": config.profile_name,
    }
    session_kwargs = {
        key: val for key, val in session_candidates.items() if val is not None
    }
    session = boto3.session.Session(**session_kwargs)

    return session.client("bedrock-runtime", config=client_config)

View file

@ -13,6 +13,7 @@ import boto3
from llama_stack.apis.safety import * # noqa from llama_stack.apis.safety import * # noqa
from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from .config import BedrockSafetyConfig from .config import BedrockSafetyConfig
@ -25,7 +26,7 @@ BEDROCK_SUPPORTED_SHIELDS = [
] ]
class BedrockSafetyAdapter(Safety): class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
def __init__(self, config: BedrockSafetyConfig) -> None: def __init__(self, config: BedrockSafetyConfig) -> None:
if not config.aws_profile: if not config.aws_profile:
raise ValueError(f"Missing boto_client aws_profile in model info::{config}") raise ValueError(f"Missing boto_client aws_profile in model info::{config}")
@ -45,18 +46,14 @@ class BedrockSafetyAdapter(Safety):
pass pass
async def register_shield(self, shield: ShieldDef) -> None: async def register_shield(self, shield: ShieldDef) -> None:
if shield.type not in BEDROCK_SUPPORTED_SHIELDS: raise ValueError("Registering dynamic shields is not supported")
raise ValueError(f"Unsupported safety shield type: {shield.type}")
shield_params = shield.params async def list_shields(self) -> List[ShieldDef]:
if "guardrailIdentifier" not in shield_params: raise NotImplementedError(
raise ValueError( """
"Error running request for BedrockGaurdrails:Missing GuardrailID in request" `list_shields` not implemented; this should read all guardrails from
) bedrock and populate guardrailId and guardrailVersion in the ShieldDef.
"""
if "guardrailVersion" not in shield_params:
raise ValueError(
"Error running request for BedrockGaurdrails:Missing guardrailVersion in request"
) )
async def run_shield( async def run_shield(