fix session_ttl passing

This commit is contained in:
Dinesh Yeduguru 2024-11-06 14:37:40 -08:00
parent c649bd9bdf
commit 82bbaec140
6 changed files with 17 additions and 11 deletions

View file

@ -43,11 +43,11 @@ class TogetherSafetyImpl(Safety, NeedsRequestProviderData, ShieldsProtocolPrivat
]
async def run_shield(
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
self, identifier: str, messages: List[Message], params: Dict[str, Any] = None
) -> RunShieldResponse:
shield_def = await self.shield_store.get_shield(shield_type)
shield_def = await self.shield_store.get_shield(identifier)
if not shield_def:
raise ValueError(f"Unknown shield {shield_type}")
raise ValueError(f"Unknown shield {identifier}")
model = shield_def.params.get("model", "llama_guard")
if model not in TOGETHER_SHIELD_MODEL_MAP:

View file

@ -124,7 +124,7 @@ def available_providers() -> List[ProviderSpec]:
api=Api.inference,
adapter=AdapterSpec(
adapter_type="bedrock",
pip_packages=["boto3", "pytz"],
pip_packages=["boto3"],
module="llama_stack.providers.adapters.inference.bedrock",
config_class="llama_stack.providers.adapters.inference.bedrock.BedrockConfig",
),

View file

@ -43,7 +43,7 @@ def available_providers() -> List[ProviderSpec]:
api=Api.safety,
adapter=AdapterSpec(
adapter_type="bedrock",
pip_packages=["boto3", "pytz"],
pip_packages=["boto3"],
module="llama_stack.providers.adapters.safety.bedrock",
config_class="llama_stack.providers.adapters.safety.bedrock.BedrockSafetyConfig",
),

View file

@ -56,6 +56,7 @@ def create_bedrock_client(
"aws_session_token": config.aws_session_token,
"region_name": config.region_name,
"profile_name": config.profile_name,
"session_ttl": config.session_ttl,
}
# Remove None values
@ -66,7 +67,9 @@ def create_bedrock_client(
else:
return (
RefreshableBotoSession(
region_name=config.region_name, profile_name=config.profile_name
region_name=config.region_name,
profile_name=config.profile_name,
session_ttl=config.session_ttl,
)
.refreshable_session()
.client(service_name)

View file

@ -53,3 +53,7 @@ class BedrockBaseConfig(BaseModel):
description="The time in seconds till a timeout exception is thrown when attempting to read from a connection."
"The default is 60 seconds.",
)
session_ttl: Optional[int] = Field(
default=3600,
description="The time in seconds till a session expires. The default is 3600 seconds (1 hour).",
)

View file

@ -4,11 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from datetime import datetime
import datetime
from time import time
from uuid import uuid4
import pytz
from boto3 import Session
from botocore.credentials import RefreshableCredentials
from botocore.session import get_session
@ -90,9 +89,9 @@ class RefreshableBotoSession:
"access_key": session_credentials.access_key,
"secret_key": session_credentials.secret_key,
"token": session_credentials.token,
"expiry_time": datetime.fromtimestamp(time() + self.session_ttl)
.replace(tzinfo=pytz.utc)
.isoformat(),
"expiry_time": datetime.datetime.fromtimestamp(
time() + self.session_ttl, datetime.timezone.utc
).isoformat(),
}
return credentials