From 82bbaec1407a08d171104d020f1308c90c570ffe Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Wed, 6 Nov 2024 14:37:40 -0800 Subject: [PATCH] fix session_ttl passing --- .../providers/adapters/safety/together/together.py | 6 +++--- llama_stack/providers/registry/inference.py | 2 +- llama_stack/providers/registry/safety.py | 2 +- llama_stack/providers/utils/bedrock/client.py | 5 ++++- llama_stack/providers/utils/bedrock/config.py | 4 ++++ .../providers/utils/bedrock/refreshable_boto_session.py | 9 ++++----- 6 files changed, 17 insertions(+), 11 deletions(-) diff --git a/llama_stack/providers/adapters/safety/together/together.py b/llama_stack/providers/adapters/safety/together/together.py index da45ed5b8..9f92626af 100644 --- a/llama_stack/providers/adapters/safety/together/together.py +++ b/llama_stack/providers/adapters/safety/together/together.py @@ -43,11 +43,11 @@ class TogetherSafetyImpl(Safety, NeedsRequestProviderData, ShieldsProtocolPrivat ] async def run_shield( - self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None + self, identifier: str, messages: List[Message], params: Dict[str, Any] = None ) -> RunShieldResponse: - shield_def = await self.shield_store.get_shield(shield_type) + shield_def = await self.shield_store.get_shield(identifier) if not shield_def: - raise ValueError(f"Unknown shield {shield_type}") + raise ValueError(f"Unknown shield {identifier}") model = shield_def.params.get("model", "llama_guard") if model not in TOGETHER_SHIELD_MODEL_MAP: diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py index 22babb076..88265f1b4 100644 --- a/llama_stack/providers/registry/inference.py +++ b/llama_stack/providers/registry/inference.py @@ -124,7 +124,7 @@ def available_providers() -> List[ProviderSpec]: api=Api.inference, adapter=AdapterSpec( adapter_type="bedrock", - pip_packages=["boto3", "pytz"], + pip_packages=["boto3"], module="llama_stack.providers.adapters.inference.bedrock", config_class="llama_stack.providers.adapters.inference.bedrock.BedrockConfig", ), diff --git a/llama_stack/providers/registry/safety.py b/llama_stack/providers/registry/safety.py index 79ac4ca2b..3fa62479a 100644 --- a/llama_stack/providers/registry/safety.py +++ b/llama_stack/providers/registry/safety.py @@ -43,7 +43,7 @@ def available_providers() -> List[ProviderSpec]: api=Api.safety, adapter=AdapterSpec( adapter_type="bedrock", - pip_packages=["boto3", "pytz"], + pip_packages=["boto3"], module="llama_stack.providers.adapters.safety.bedrock", config_class="llama_stack.providers.adapters.safety.bedrock.BedrockSafetyConfig", ), diff --git a/llama_stack/providers/utils/bedrock/client.py b/llama_stack/providers/utils/bedrock/client.py index 725c248c1..77781c729 100644 --- a/llama_stack/providers/utils/bedrock/client.py +++ b/llama_stack/providers/utils/bedrock/client.py @@ -56,6 +56,7 @@ def create_bedrock_client( "aws_session_token": config.aws_session_token, "region_name": config.region_name, "profile_name": config.profile_name, + "session_ttl": config.session_ttl, } # Remove None values @@ -66,7 +67,9 @@ def create_bedrock_client( else: return ( RefreshableBotoSession( - region_name=config.region_name, profile_name=config.profile_name + region_name=config.region_name, + profile_name=config.profile_name, + session_ttl=config.session_ttl, ) .refreshable_session() .client(service_name) diff --git a/llama_stack/providers/utils/bedrock/config.py b/llama_stack/providers/utils/bedrock/config.py index 38f1bd756..55c5582a1 100644 --- a/llama_stack/providers/utils/bedrock/config.py +++ b/llama_stack/providers/utils/bedrock/config.py @@ -53,3 +53,7 @@ class BedrockBaseConfig(BaseModel): description="The time in seconds till a timeout exception is thrown when attempting to read from a connection." "The default is 60 seconds.", ) + session_ttl: Optional[int] = Field( + default=3600, + description="The time in seconds till a session expires. The default is 3600 seconds (1 hour).", + ) diff --git a/llama_stack/providers/utils/bedrock/refreshable_boto_session.py b/llama_stack/providers/utils/bedrock/refreshable_boto_session.py index fedffb618..f37563930 100644 --- a/llama_stack/providers/utils/bedrock/refreshable_boto_session.py +++ b/llama_stack/providers/utils/bedrock/refreshable_boto_session.py @@ -4,11 +4,10 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from datetime import datetime +import datetime from time import time from uuid import uuid4 -import pytz from boto3 import Session from botocore.credentials import RefreshableCredentials from botocore.session import get_session @@ -90,9 +89,9 @@ class RefreshableBotoSession: "access_key": session_credentials.access_key, "secret_key": session_credentials.secret_key, "token": session_credentials.token, - "expiry_time": datetime.fromtimestamp(time() + self.session_ttl) - .replace(tzinfo=pytz.utc) - .isoformat(), + "expiry_time": datetime.datetime.fromtimestamp( + time() + self.session_ttl, datetime.timezone.utc + ).isoformat(), } return credentials