mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-30 07:39:38 +00:00
fix session_ttl passing
This commit is contained in:
parent
c649bd9bdf
commit
82bbaec140
6 changed files with 17 additions and 11 deletions
|
@ -43,11 +43,11 @@ class TogetherSafetyImpl(Safety, NeedsRequestProviderData, ShieldsProtocolPrivat
|
|||
]
|
||||
|
||||
async def run_shield(
|
||||
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
|
||||
self, identifier: str, messages: List[Message], params: Dict[str, Any] = None
|
||||
) -> RunShieldResponse:
|
||||
shield_def = await self.shield_store.get_shield(shield_type)
|
||||
shield_def = await self.shield_store.get_shield(identifier)
|
||||
if not shield_def:
|
||||
raise ValueError(f"Unknown shield {shield_type}")
|
||||
raise ValueError(f"Unknown shield {identifier}")
|
||||
|
||||
model = shield_def.params.get("model", "llama_guard")
|
||||
if model not in TOGETHER_SHIELD_MODEL_MAP:
|
||||
|
|
|
@ -124,7 +124,7 @@ def available_providers() -> List[ProviderSpec]:
|
|||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="bedrock",
|
||||
pip_packages=["boto3", "pytz"],
|
||||
pip_packages=["boto3"],
|
||||
module="llama_stack.providers.adapters.inference.bedrock",
|
||||
config_class="llama_stack.providers.adapters.inference.bedrock.BedrockConfig",
|
||||
),
|
||||
|
|
|
@ -43,7 +43,7 @@ def available_providers() -> List[ProviderSpec]:
|
|||
api=Api.safety,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="bedrock",
|
||||
pip_packages=["boto3", "pytz"],
|
||||
pip_packages=["boto3"],
|
||||
module="llama_stack.providers.adapters.safety.bedrock",
|
||||
config_class="llama_stack.providers.adapters.safety.bedrock.BedrockSafetyConfig",
|
||||
),
|
||||
|
|
|
@ -56,6 +56,7 @@ def create_bedrock_client(
|
|||
"aws_session_token": config.aws_session_token,
|
||||
"region_name": config.region_name,
|
||||
"profile_name": config.profile_name,
|
||||
"session_ttl": config.session_ttl,
|
||||
}
|
||||
|
||||
# Remove None values
|
||||
|
@ -66,7 +67,9 @@ def create_bedrock_client(
|
|||
else:
|
||||
return (
|
||||
RefreshableBotoSession(
|
||||
region_name=config.region_name, profile_name=config.profile_name
|
||||
region_name=config.region_name,
|
||||
profile_name=config.profile_name,
|
||||
session_ttl=config.session_ttl,
|
||||
)
|
||||
.refreshable_session()
|
||||
.client(service_name)
|
||||
|
|
|
@ -53,3 +53,7 @@ class BedrockBaseConfig(BaseModel):
|
|||
description="The time in seconds till a timeout exception is thrown when attempting to read from a connection."
|
||||
"The default is 60 seconds.",
|
||||
)
|
||||
session_ttl: Optional[int] = Field(
|
||||
default=3600,
|
||||
description="The time in seconds till a session expires. The default is 3600 seconds (1 hour).",
|
||||
)
|
||||
|
|
|
@ -4,11 +4,10 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from datetime import datetime
|
||||
import datetime
|
||||
from time import time
|
||||
from uuid import uuid4
|
||||
|
||||
import pytz
|
||||
from boto3 import Session
|
||||
from botocore.credentials import RefreshableCredentials
|
||||
from botocore.session import get_session
|
||||
|
@ -90,9 +89,9 @@ class RefreshableBotoSession:
|
|||
"access_key": session_credentials.access_key,
|
||||
"secret_key": session_credentials.secret_key,
|
||||
"token": session_credentials.token,
|
||||
"expiry_time": datetime.fromtimestamp(time() + self.session_ttl)
|
||||
.replace(tzinfo=pytz.utc)
|
||||
.isoformat(),
|
||||
"expiry_time": datetime.datetime.fromtimestamp(
|
||||
time() + self.session_ttl, datetime.timezone.utc
|
||||
).isoformat(),
|
||||
}
|
||||
|
||||
return credentials
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue