mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-30 23:51:00 +00:00
address raghu's feedback
This commit is contained in:
parent
2101cb08c7
commit
c649bd9bdf
7 changed files with 144 additions and 154 deletions
|
@ -6,9 +6,7 @@
|
||||||
|
|
||||||
from typing import * # noqa: F403
|
from typing import * # noqa: F403
|
||||||
|
|
||||||
import boto3
|
|
||||||
from botocore.client import BaseClient
|
from botocore.client import BaseClient
|
||||||
from botocore.config import Config
|
|
||||||
|
|
||||||
from llama_models.llama3.api.chat_format import ChatFormat
|
from llama_models.llama3.api.chat_format import ChatFormat
|
||||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||||
|
@ -18,7 +16,7 @@ from llama_stack.providers.utils.inference.model_registry import ModelRegistryHe
|
||||||
from llama_stack.apis.inference import * # noqa: F403
|
from llama_stack.apis.inference import * # noqa: F403
|
||||||
|
|
||||||
from llama_stack.providers.adapters.inference.bedrock.config import BedrockConfig
|
from llama_stack.providers.adapters.inference.bedrock.config import BedrockConfig
|
||||||
from llama_stack.providers.utils.refreshable_boto_session import RefreshableBotoSession
|
from llama_stack.providers.utils.bedrock.client import create_bedrock_client
|
||||||
|
|
||||||
|
|
||||||
BEDROCK_SUPPORTED_MODELS = {
|
BEDROCK_SUPPORTED_MODELS = {
|
||||||
|
@ -36,7 +34,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
|
||||||
)
|
)
|
||||||
self._config = config
|
self._config = config
|
||||||
|
|
||||||
self._client = _create_bedrock_client(config)
|
self._client = create_bedrock_client(config)
|
||||||
self.formatter = ChatFormat(Tokenizer.get_instance())
|
self.formatter = ChatFormat(Tokenizer.get_instance())
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -439,50 +437,3 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
|
||||||
contents: List[InterleavedTextMedia],
|
contents: List[InterleavedTextMedia],
|
||||||
) -> EmbeddingsResponse:
|
) -> EmbeddingsResponse:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
def _create_bedrock_client(config: BedrockConfig) -> BaseClient:
|
|
||||||
if config.aws_access_key_id and config.aws_secret_access_key:
|
|
||||||
retries_config = {
|
|
||||||
k: v
|
|
||||||
for k, v in dict(
|
|
||||||
total_max_attempts=config.total_max_attempts,
|
|
||||||
mode=config.retry_mode,
|
|
||||||
).items()
|
|
||||||
if v is not None
|
|
||||||
}
|
|
||||||
|
|
||||||
config_args = {
|
|
||||||
k: v
|
|
||||||
for k, v in dict(
|
|
||||||
region_name=config.region_name,
|
|
||||||
retries=retries_config if retries_config else None,
|
|
||||||
connect_timeout=config.connect_timeout,
|
|
||||||
read_timeout=config.read_timeout,
|
|
||||||
).items()
|
|
||||||
if v is not None
|
|
||||||
}
|
|
||||||
|
|
||||||
boto3_config = Config(**config_args)
|
|
||||||
|
|
||||||
session_args = {
|
|
||||||
"aws_access_key_id": config.aws_access_key_id,
|
|
||||||
"aws_secret_access_key": config.aws_secret_access_key,
|
|
||||||
"aws_session_token": config.aws_session_token,
|
|
||||||
"region_name": config.region_name,
|
|
||||||
"profile_name": config.profile_name,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Remove None values
|
|
||||||
session_args = {k: v for k, v in session_args.items() if v is not None}
|
|
||||||
|
|
||||||
boto3_session = boto3.session.Session(**session_args)
|
|
||||||
return boto3_session.client("bedrock-runtime", config=boto3_config)
|
|
||||||
else:
|
|
||||||
return (
|
|
||||||
RefreshableBotoSession(
|
|
||||||
region_name=config.region_name, profile_name=config.profile_name
|
|
||||||
)
|
|
||||||
.refreshable_session()
|
|
||||||
.client("bedrock-runtime")
|
|
||||||
)
|
|
||||||
|
|
|
@ -3,53 +3,12 @@
|
||||||
#
|
#
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
from typing import * # noqa: F403
|
|
||||||
|
|
||||||
from llama_models.schema_utils import json_schema_type
|
from llama_models.schema_utils import json_schema_type
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
from llama_stack.providers.utils.bedrock.config import BedrockBaseConfig
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class BedrockConfig(BaseModel):
|
class BedrockConfig(BedrockBaseConfig):
|
||||||
aws_access_key_id: Optional[str] = Field(
|
pass
|
||||||
default=None,
|
|
||||||
description="The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID",
|
|
||||||
)
|
|
||||||
aws_secret_access_key: Optional[str] = Field(
|
|
||||||
default=None,
|
|
||||||
description="The AWS secret access key to use. Default use environment variable: AWS_SECRET_ACCESS_KEY",
|
|
||||||
)
|
|
||||||
aws_session_token: Optional[str] = Field(
|
|
||||||
default=None,
|
|
||||||
description="The AWS session token to use. Default use environment variable: AWS_SESSION_TOKEN",
|
|
||||||
)
|
|
||||||
region_name: Optional[str] = Field(
|
|
||||||
default=None,
|
|
||||||
description="The default AWS Region to use, for example, us-west-1 or us-west-2."
|
|
||||||
"Default use environment variable: AWS_DEFAULT_REGION",
|
|
||||||
)
|
|
||||||
profile_name: Optional[str] = Field(
|
|
||||||
default=None,
|
|
||||||
description="The profile name that contains credentials to use."
|
|
||||||
"Default use environment variable: AWS_PROFILE",
|
|
||||||
)
|
|
||||||
total_max_attempts: Optional[int] = Field(
|
|
||||||
default=None,
|
|
||||||
description="An integer representing the maximum number of attempts that will be made for a single request, "
|
|
||||||
"including the initial attempt. Default use environment variable: AWS_MAX_ATTEMPTS",
|
|
||||||
)
|
|
||||||
retry_mode: Optional[str] = Field(
|
|
||||||
default=None,
|
|
||||||
description="A string representing the type of retries Boto3 will perform."
|
|
||||||
"Default use environment variable: AWS_RETRY_MODE",
|
|
||||||
)
|
|
||||||
connect_timeout: Optional[float] = Field(
|
|
||||||
default=60,
|
|
||||||
description="The time in seconds till a timeout exception is thrown when attempting to make a connection. "
|
|
||||||
"The default is 60 seconds.",
|
|
||||||
)
|
|
||||||
read_timeout: Optional[float] = Field(
|
|
||||||
default=60,
|
|
||||||
description="The time in seconds till a timeout exception is thrown when attempting to read from a connection."
|
|
||||||
"The default is 60 seconds.",
|
|
||||||
)
|
|
||||||
|
|
|
@ -9,12 +9,10 @@ import logging
|
||||||
|
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
import boto3
|
|
||||||
|
|
||||||
from llama_stack.apis.safety import * # noqa
|
from llama_stack.apis.safety import * # noqa
|
||||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||||
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
|
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
|
||||||
from llama_stack.providers.utils.refreshable_boto_session import RefreshableBotoSession
|
from llama_stack.providers.utils.bedrock.client import create_bedrock_client
|
||||||
|
|
||||||
from .config import BedrockSafetyConfig
|
from .config import BedrockSafetyConfig
|
||||||
|
|
||||||
|
@ -27,30 +25,6 @@ BEDROCK_SUPPORTED_SHIELDS = [
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def _create_bedrock_client(config: BedrockSafetyConfig, name: str):
|
|
||||||
if config.aws_access_key_id and config.aws_secret_access_key:
|
|
||||||
session_args = {
|
|
||||||
"aws_access_key_id": config.aws_access_key_id,
|
|
||||||
"aws_secret_access_key": config.aws_secret_access_key,
|
|
||||||
"aws_session_token": config.aws_session_token,
|
|
||||||
"region_name": config.region_name,
|
|
||||||
"profile_name": config.profile_name,
|
|
||||||
}
|
|
||||||
# Remove None values
|
|
||||||
session_args = {k: v for k, v in session_args.items() if v is not None}
|
|
||||||
|
|
||||||
boto3_session = boto3.session.Session(**session_args)
|
|
||||||
return boto3_session.client(name)
|
|
||||||
else:
|
|
||||||
return (
|
|
||||||
RefreshableBotoSession(
|
|
||||||
region_name=config.region_name, profile_name=config.profile_name
|
|
||||||
)
|
|
||||||
.refreshable_session()
|
|
||||||
.client(name)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
|
class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
|
||||||
def __init__(self, config: BedrockSafetyConfig) -> None:
|
def __init__(self, config: BedrockSafetyConfig) -> None:
|
||||||
self.config = config
|
self.config = config
|
||||||
|
@ -58,10 +32,8 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
try:
|
try:
|
||||||
self.bedrock_runtime_client = _create_bedrock_client(
|
self.bedrock_runtime_client = create_bedrock_client(self.config)
|
||||||
self.config, "bedrock-runtime"
|
self.bedrock_client = create_bedrock_client(self.config, "bedrock")
|
||||||
)
|
|
||||||
self.bedrock_client = _create_bedrock_client(self.config, "bedrock")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise RuntimeError("Error initializing BedrockSafetyAdapter") from e
|
raise RuntimeError("Error initializing BedrockSafetyAdapter") from e
|
||||||
|
|
||||||
|
|
|
@ -4,32 +4,12 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from llama_models.schema_utils import json_schema_type
|
||||||
|
|
||||||
|
from llama_stack.providers.utils.bedrock.config import BedrockBaseConfig
|
||||||
|
|
||||||
|
|
||||||
class BedrockSafetyConfig(BaseModel):
|
@json_schema_type
|
||||||
"""Configuration information for a guardrail that you want to use in the request."""
|
class BedrockSafetyConfig(BedrockBaseConfig):
|
||||||
|
pass
|
||||||
aws_access_key_id: Optional[str] = Field(
|
|
||||||
default=None,
|
|
||||||
description="The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID",
|
|
||||||
)
|
|
||||||
aws_secret_access_key: Optional[str] = Field(
|
|
||||||
default=None,
|
|
||||||
description="The AWS secret access key to use. Default use environment variable: AWS_SECRET_ACCESS_KEY",
|
|
||||||
)
|
|
||||||
aws_session_token: Optional[str] = Field(
|
|
||||||
default=None,
|
|
||||||
description="The AWS session token to use. Default use environment variable: AWS_SESSION_TOKEN",
|
|
||||||
)
|
|
||||||
region_name: Optional[str] = Field(
|
|
||||||
default=None,
|
|
||||||
description="The default AWS Region to use, for example, us-west-1 or us-west-2."
|
|
||||||
"Default use environment variable: AWS_DEFAULT_REGION",
|
|
||||||
)
|
|
||||||
profile_name: str = Field(
|
|
||||||
default="default",
|
|
||||||
description="The profile on the machine having valid aws credentials. This will ensure separation of creation to invocation",
|
|
||||||
)
|
|
||||||
|
|
73
llama_stack/providers/utils/bedrock/client.py
Normal file
73
llama_stack/providers/utils/bedrock/client.py
Normal file
|
@ -0,0 +1,73 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
|
||||||
|
import boto3
|
||||||
|
from botocore.client import BaseClient
|
||||||
|
from botocore.config import Config
|
||||||
|
|
||||||
|
from llama_stack.providers.utils.bedrock.config import BedrockBaseConfig
|
||||||
|
from llama_stack.providers.utils.bedrock.refreshable_boto_session import (
|
||||||
|
RefreshableBotoSession,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_boto3_config(config: BedrockBaseConfig) -> Config:
    """Translate the region, retry and timeout settings of ``config`` into a botocore ``Config``.

    Settings whose value is ``None`` are omitted so that botocore applies its
    own defaults for them.
    """
    retries = {
        key: value
        for key, value in (
            ("total_max_attempts", config.total_max_attempts),
            ("mode", config.retry_mode),
        )
        if value is not None
    }
    candidate_args = {
        "region_name": config.region_name,
        # An empty retries dict carries no information; drop it entirely.
        "retries": retries or None,
        "connect_timeout": config.connect_timeout,
        "read_timeout": config.read_timeout,
    }
    return Config(**{k: v for k, v in candidate_args.items() if v is not None})


def create_bedrock_client(
    config: BedrockBaseConfig, service_name: str = "bedrock-runtime"
) -> BaseClient:
    """Creates a boto3 client for Bedrock services with the given configuration.

    When both ``aws_access_key_id`` and ``aws_secret_access_key`` are set, a
    plain ``boto3`` session is built from the explicit credentials. Otherwise
    the client comes from ``RefreshableBotoSession`` using only the region and
    profile name. In both cases the region, retry and timeout settings from
    ``config`` are applied to the client.

    Args:
        config: The Bedrock configuration containing AWS credentials and settings
        service_name: The AWS service name to create client for (default: "bedrock-runtime")

    Returns:
        A configured boto3 client
    """
    # Built up front so that retries/timeouts are honoured on BOTH credential
    # paths; previously the fallback path silently ignored them.
    boto3_config = _build_boto3_config(config)

    if config.aws_access_key_id and config.aws_secret_access_key:
        session_args = {
            "aws_access_key_id": config.aws_access_key_id,
            "aws_secret_access_key": config.aws_secret_access_key,
            "aws_session_token": config.aws_session_token,
            "region_name": config.region_name,
            "profile_name": config.profile_name,
        }

        # Remove None values so boto3 falls back to its own resolution for them.
        session_args = {k: v for k, v in session_args.items() if v is not None}

        boto3_session = boto3.session.Session(**session_args)
        return boto3_session.client(service_name, config=boto3_config)

    # NOTE(review): assumes refreshable_session() returns a boto3 Session, whose
    # client() accepts ``config=`` — confirm against RefreshableBotoSession.
    refreshable_session = RefreshableBotoSession(
        region_name=config.region_name, profile_name=config.profile_name
    ).refreshable_session()
    return refreshable_session.client(service_name, config=boto3_config)
|
55
llama_stack/providers/utils/bedrock/config.py
Normal file
55
llama_stack/providers/utils/bedrock/config.py
Normal file
|
@ -0,0 +1,55 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from llama_models.schema_utils import json_schema_type
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
class BedrockBaseConfig(BaseModel):
    """Connection settings shared by the Bedrock provider configurations.

    Groups AWS credentials, region/profile selection, retry behaviour and
    socket timeouts. All credential, region, profile and retry fields default
    to ``None``; the two timeouts default to 60 seconds.
    """

    # --- Credentials -------------------------------------------------------
    aws_access_key_id: Optional[str] = Field(
        default=None,
        description="The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID",
    )
    aws_secret_access_key: Optional[str] = Field(
        default=None,
        description="The AWS secret access key to use. Default use environment variable: AWS_SECRET_ACCESS_KEY",
    )
    aws_session_token: Optional[str] = Field(
        default=None,
        description="The AWS session token to use. Default use environment variable: AWS_SESSION_TOKEN",
    )

    # --- Region / profile --------------------------------------------------
    # The description fragments below are joined by implicit string
    # concatenation, so each leading fragment must end with a space.
    region_name: Optional[str] = Field(
        default=None,
        description="The default AWS Region to use, for example, us-west-1 or us-west-2. "
        "Default use environment variable: AWS_DEFAULT_REGION",
    )
    profile_name: Optional[str] = Field(
        default=None,
        description="The profile name that contains credentials to use. "
        "Default use environment variable: AWS_PROFILE",
    )

    # --- Retries -----------------------------------------------------------
    total_max_attempts: Optional[int] = Field(
        default=None,
        description="An integer representing the maximum number of attempts that will be made for a single request, "
        "including the initial attempt. Default use environment variable: AWS_MAX_ATTEMPTS",
    )
    retry_mode: Optional[str] = Field(
        default=None,
        description="A string representing the type of retries Boto3 will perform. "
        "Default use environment variable: AWS_RETRY_MODE",
    )

    # --- Timeouts (seconds) --------------------------------------------------
    connect_timeout: Optional[float] = Field(
        default=60,
        description="The time in seconds till a timeout exception is thrown when attempting to make a connection. "
        "The default is 60 seconds.",
    )
    read_timeout: Optional[float] = Field(
        default=60,
        description="The time in seconds till a timeout exception is thrown when attempting to read from a connection. "
        "The default is 60 seconds.",
    )
|
|
@ -31,7 +31,7 @@ class RefreshableBotoSession:
|
||||||
profile_name: str = None,
|
profile_name: str = None,
|
||||||
sts_arn: str = None,
|
sts_arn: str = None,
|
||||||
session_name: str = None,
|
session_name: str = None,
|
||||||
session_ttl: int = 3000,
|
session_ttl: int = 30000,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize `RefreshableBotoSession`
|
Initialize `RefreshableBotoSession`
|
Loading…
Add table
Add a link
Reference in a new issue