mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 07:14:20 +00:00
Update bedrock
This commit is contained in:
parent
fe0dabe596
commit
a33cafc2fe
2 changed files with 63 additions and 60 deletions
|
@ -26,57 +26,16 @@ BEDROCK_SUPPORTED_MODELS = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# NOTE: this is not quite tested after the recent refactors
|
||||||
class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
|
class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _create_bedrock_client(config: BedrockConfig) -> BaseClient:
|
|
||||||
retries_config = {
|
|
||||||
k: v
|
|
||||||
for k, v in dict(
|
|
||||||
total_max_attempts=config.total_max_attempts,
|
|
||||||
mode=config.retry_mode,
|
|
||||||
).items()
|
|
||||||
if v is not None
|
|
||||||
}
|
|
||||||
|
|
||||||
config_args = {
|
|
||||||
k: v
|
|
||||||
for k, v in dict(
|
|
||||||
region_name=config.region_name,
|
|
||||||
retries=retries_config if retries_config else None,
|
|
||||||
connect_timeout=config.connect_timeout,
|
|
||||||
read_timeout=config.read_timeout,
|
|
||||||
).items()
|
|
||||||
if v is not None
|
|
||||||
}
|
|
||||||
|
|
||||||
boto3_config = Config(**config_args)
|
|
||||||
|
|
||||||
session_args = {
|
|
||||||
k: v
|
|
||||||
for k, v in dict(
|
|
||||||
aws_access_key_id=config.aws_access_key_id,
|
|
||||||
aws_secret_access_key=config.aws_secret_access_key,
|
|
||||||
aws_session_token=config.aws_session_token,
|
|
||||||
region_name=config.region_name,
|
|
||||||
profile_name=config.profile_name,
|
|
||||||
).items()
|
|
||||||
if v is not None
|
|
||||||
}
|
|
||||||
|
|
||||||
boto3_session = boto3.session.Session(**session_args)
|
|
||||||
|
|
||||||
return boto3_session.client("bedrock-runtime", config=boto3_config)
|
|
||||||
|
|
||||||
def __init__(self, config: BedrockConfig) -> None:
|
def __init__(self, config: BedrockConfig) -> None:
|
||||||
ModelRegistryHelper.__init__(
|
ModelRegistryHelper.__init__(
|
||||||
self, stack_to_provider_models_map=BEDROCK_SUPPORTED_MODELS
|
self, stack_to_provider_models_map=BEDROCK_SUPPORTED_MODELS
|
||||||
)
|
)
|
||||||
self._config = config
|
self._config = config
|
||||||
|
|
||||||
self._client = BedrockInferenceAdapter._create_bedrock_client(config)
|
self._client = _create_bedrock_client(config)
|
||||||
tokenizer = Tokenizer.get_instance()
|
self.formatter = ChatFormat(Tokenizer.get_instance())
|
||||||
self.formatter = ChatFormat(tokenizer)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def client(self) -> BaseClient:
|
def client(self) -> BaseClient:
|
||||||
|
@ -88,7 +47,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
self.client.close()
|
self.client.close()
|
||||||
|
|
||||||
async def completion(
|
def completion(
|
||||||
self,
|
self,
|
||||||
model: str,
|
model: str,
|
||||||
content: InterleavedTextMedia,
|
content: InterleavedTextMedia,
|
||||||
|
@ -258,7 +217,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _tool_parameters_to_input_schema(
|
def _tool_parameters_to_input_schema(
|
||||||
tool_parameters: Optional[Dict[str, ToolParamDefinition]]
|
tool_parameters: Optional[Dict[str, ToolParamDefinition]],
|
||||||
) -> Dict:
|
) -> Dict:
|
||||||
input_schema = {"type": "object"}
|
input_schema = {"type": "object"}
|
||||||
if not tool_parameters:
|
if not tool_parameters:
|
||||||
|
@ -324,7 +283,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
|
||||||
)
|
)
|
||||||
return tool_config
|
return tool_config
|
||||||
|
|
||||||
async def chat_completion(
|
def chat_completion(
|
||||||
self,
|
self,
|
||||||
model: str,
|
model: str,
|
||||||
messages: List[Message],
|
messages: List[Message],
|
||||||
|
@ -443,3 +402,50 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
|
||||||
else:
|
else:
|
||||||
# Ignored
|
# Ignored
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
async def embeddings(
    self,
    model: str,
    contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse:
    """Embeddings entry point; not implemented by this adapter.

    Args:
        model: Identifier of the model to embed with (unused).
        contents: Items to embed (unused).

    Raises:
        NotImplementedError: Always; the body does nothing else.
    """
    raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
def _create_bedrock_client(config: BedrockConfig) -> BaseClient:
    """Build a boto3 ``bedrock-runtime`` client from ``config``.

    Only settings whose value is not ``None`` are forwarded, so anything
    left unset in ``config`` is simply not passed to ``Config`` or
    ``Session``.

    Args:
        config: Provides ``total_max_attempts``, ``retry_mode``,
            ``region_name``, ``connect_timeout``, ``read_timeout``,
            ``aws_access_key_id``, ``aws_secret_access_key``,
            ``aws_session_token`` and ``profile_name``.

    Returns:
        The client returned by ``Session.client("bedrock-runtime", ...)``.
    """

    def _without_none(**options):
        # Keep only the options that were explicitly set.
        return {
            name: value for name, value in options.items() if value is not None
        }

    retries_config = _without_none(
        total_max_attempts=config.total_max_attempts,
        mode=config.retry_mode,
    )

    # NOTE(review): `Config` is presumably botocore.config.Config; its import
    # is not visible in this excerpt - confirm.
    config_args = _without_none(
        region_name=config.region_name,
        # An empty retries dict is treated as "not set" and omitted entirely.
        retries=retries_config or None,
        connect_timeout=config.connect_timeout,
        read_timeout=config.read_timeout,
    )
    boto3_config = Config(**config_args)

    # region_name is deliberately supplied to both the Config and the Session,
    # exactly as the original code did.
    session_args = _without_none(
        aws_access_key_id=config.aws_access_key_id,
        aws_secret_access_key=config.aws_secret_access_key,
        aws_session_token=config.aws_session_token,
        region_name=config.region_name,
        profile_name=config.profile_name,
    )
    boto3_session = boto3.session.Session(**session_args)

    return boto3_session.client("bedrock-runtime", config=boto3_config)
|
||||||
|
|
|
@ -13,6 +13,7 @@ import boto3
|
||||||
|
|
||||||
from llama_stack.apis.safety import * # noqa
|
from llama_stack.apis.safety import * # noqa
|
||||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||||
|
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
|
||||||
|
|
||||||
from .config import BedrockSafetyConfig
|
from .config import BedrockSafetyConfig
|
||||||
|
|
||||||
|
@ -25,7 +26,7 @@ BEDROCK_SUPPORTED_SHIELDS = [
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class BedrockSafetyAdapter(Safety):
|
class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
|
||||||
def __init__(self, config: BedrockSafetyConfig) -> None:
|
def __init__(self, config: BedrockSafetyConfig) -> None:
|
||||||
if not config.aws_profile:
|
if not config.aws_profile:
|
||||||
raise ValueError(f"Missing boto_client aws_profile in model info::{config}")
|
raise ValueError(f"Missing boto_client aws_profile in model info::{config}")
|
||||||
|
@ -45,19 +46,15 @@ class BedrockSafetyAdapter(Safety):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def register_shield(self, shield: ShieldDef) -> None:
|
async def register_shield(self, shield: ShieldDef) -> None:
|
||||||
if shield.type not in BEDROCK_SUPPORTED_SHIELDS:
|
raise ValueError("Registering dynamic shields is not supported")
|
||||||
raise ValueError(f"Unsupported safety shield type: {shield.type}")
|
|
||||||
|
|
||||||
shield_params = shield.params
|
async def list_shields(self) -> List[ShieldDef]:
|
||||||
if "guardrailIdentifier" not in shield_params:
|
raise NotImplementedError(
|
||||||
raise ValueError(
|
"""
|
||||||
"Error running request for BedrockGaurdrails:Missing GuardrailID in request"
|
`list_shields` not implemented; this should read all guardrails from
|
||||||
)
|
bedrock and populate guardrailId and guardrailVersion in the ShieldDef.
|
||||||
|
"""
|
||||||
if "guardrailVersion" not in shield_params:
|
)
|
||||||
raise ValueError(
|
|
||||||
"Error running request for BedrockGaurdrails:Missing guardrailVersion in request"
|
|
||||||
)
|
|
||||||
|
|
||||||
async def run_shield(
|
async def run_shield(
|
||||||
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
|
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue