mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-30 23:51:00 +00:00
fix session_ttl passing
This commit is contained in:
parent
c649bd9bdf
commit
82bbaec140
6 changed files with 17 additions and 11 deletions
|
@ -43,11 +43,11 @@ class TogetherSafetyImpl(Safety, NeedsRequestProviderData, ShieldsProtocolPrivat
|
||||||
]
|
]
|
||||||
|
|
||||||
async def run_shield(
|
async def run_shield(
|
||||||
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
|
self, identifier: str, messages: List[Message], params: Dict[str, Any] = None
|
||||||
) -> RunShieldResponse:
|
) -> RunShieldResponse:
|
||||||
shield_def = await self.shield_store.get_shield(shield_type)
|
shield_def = await self.shield_store.get_shield(identifier)
|
||||||
if not shield_def:
|
if not shield_def:
|
||||||
raise ValueError(f"Unknown shield {shield_type}")
|
raise ValueError(f"Unknown shield {identifier}")
|
||||||
|
|
||||||
model = shield_def.params.get("model", "llama_guard")
|
model = shield_def.params.get("model", "llama_guard")
|
||||||
if model not in TOGETHER_SHIELD_MODEL_MAP:
|
if model not in TOGETHER_SHIELD_MODEL_MAP:
|
||||||
|
|
|
@ -124,7 +124,7 @@ def available_providers() -> List[ProviderSpec]:
|
||||||
api=Api.inference,
|
api=Api.inference,
|
||||||
adapter=AdapterSpec(
|
adapter=AdapterSpec(
|
||||||
adapter_type="bedrock",
|
adapter_type="bedrock",
|
||||||
pip_packages=["boto3", "pytz"],
|
pip_packages=["boto3"],
|
||||||
module="llama_stack.providers.adapters.inference.bedrock",
|
module="llama_stack.providers.adapters.inference.bedrock",
|
||||||
config_class="llama_stack.providers.adapters.inference.bedrock.BedrockConfig",
|
config_class="llama_stack.providers.adapters.inference.bedrock.BedrockConfig",
|
||||||
),
|
),
|
||||||
|
|
|
@ -43,7 +43,7 @@ def available_providers() -> List[ProviderSpec]:
|
||||||
api=Api.safety,
|
api=Api.safety,
|
||||||
adapter=AdapterSpec(
|
adapter=AdapterSpec(
|
||||||
adapter_type="bedrock",
|
adapter_type="bedrock",
|
||||||
pip_packages=["boto3", "pytz"],
|
pip_packages=["boto3"],
|
||||||
module="llama_stack.providers.adapters.safety.bedrock",
|
module="llama_stack.providers.adapters.safety.bedrock",
|
||||||
config_class="llama_stack.providers.adapters.safety.bedrock.BedrockSafetyConfig",
|
config_class="llama_stack.providers.adapters.safety.bedrock.BedrockSafetyConfig",
|
||||||
),
|
),
|
||||||
|
|
|
@ -56,6 +56,7 @@ def create_bedrock_client(
|
||||||
"aws_session_token": config.aws_session_token,
|
"aws_session_token": config.aws_session_token,
|
||||||
"region_name": config.region_name,
|
"region_name": config.region_name,
|
||||||
"profile_name": config.profile_name,
|
"profile_name": config.profile_name,
|
||||||
|
"session_ttl": config.session_ttl,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Remove None values
|
# Remove None values
|
||||||
|
@ -66,7 +67,9 @@ def create_bedrock_client(
|
||||||
else:
|
else:
|
||||||
return (
|
return (
|
||||||
RefreshableBotoSession(
|
RefreshableBotoSession(
|
||||||
region_name=config.region_name, profile_name=config.profile_name
|
region_name=config.region_name,
|
||||||
|
profile_name=config.profile_name,
|
||||||
|
session_ttl=config.session_ttl,
|
||||||
)
|
)
|
||||||
.refreshable_session()
|
.refreshable_session()
|
||||||
.client(service_name)
|
.client(service_name)
|
||||||
|
|
|
@ -53,3 +53,7 @@ class BedrockBaseConfig(BaseModel):
|
||||||
description="The time in seconds till a timeout exception is thrown when attempting to read from a connection."
|
description="The time in seconds till a timeout exception is thrown when attempting to read from a connection."
|
||||||
"The default is 60 seconds.",
|
"The default is 60 seconds.",
|
||||||
)
|
)
|
||||||
|
session_ttl: Optional[int] = Field(
|
||||||
|
default=3600,
|
||||||
|
description="The time in seconds till a session expires. The default is 3600 seconds (1 hour).",
|
||||||
|
)
|
||||||
|
|
|
@ -4,11 +4,10 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from datetime import datetime
|
import datetime
|
||||||
from time import time
|
from time import time
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
import pytz
|
|
||||||
from boto3 import Session
|
from boto3 import Session
|
||||||
from botocore.credentials import RefreshableCredentials
|
from botocore.credentials import RefreshableCredentials
|
||||||
from botocore.session import get_session
|
from botocore.session import get_session
|
||||||
|
@ -90,9 +89,9 @@ class RefreshableBotoSession:
|
||||||
"access_key": session_credentials.access_key,
|
"access_key": session_credentials.access_key,
|
||||||
"secret_key": session_credentials.secret_key,
|
"secret_key": session_credentials.secret_key,
|
||||||
"token": session_credentials.token,
|
"token": session_credentials.token,
|
||||||
"expiry_time": datetime.fromtimestamp(time() + self.session_ttl)
|
"expiry_time": datetime.datetime.fromtimestamp(
|
||||||
.replace(tzinfo=pytz.utc)
|
time() + self.session_ttl, datetime.timezone.utc
|
||||||
.isoformat(),
|
).isoformat(),
|
||||||
}
|
}
|
||||||
|
|
||||||
return credentials
|
return credentials
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue