diff --git a/llama_toolchain/agentic_system/meta_reference/safety.py b/llama_toolchain/agentic_system/meta_reference/safety.py index 683ae622d..4bbb1f2f1 100644 --- a/llama_toolchain/agentic_system/meta_reference/safety.py +++ b/llama_toolchain/agentic_system/meta_reference/safety.py @@ -9,12 +9,13 @@ from typing import List from llama_models.llama3.api.datatypes import Message, Role, UserMessage from termcolor import cprint -from llama_toolchain.safety.api.datatypes import ( +from llama_toolchain.safety.api import ( OnViolationAction, + RunShieldRequest, + Safety, ShieldDefinition, ShieldResponse, ) -from llama_toolchain.safety.api.endpoints import RunShieldRequest, Safety class SafetyException(Exception): # noqa: N818 diff --git a/llama_toolchain/distribution/distribution.py b/llama_toolchain/distribution/distribution.py index 7294392a2..df8ee76f1 100644 --- a/llama_toolchain/distribution/distribution.py +++ b/llama_toolchain/distribution/distribution.py @@ -7,13 +7,13 @@ import inspect from typing import Dict, List -from llama_toolchain.agentic_system.api.endpoints import AgenticSystem +from llama_toolchain.agentic_system.api import AgenticSystem from llama_toolchain.agentic_system.providers import available_agentic_system_providers -from llama_toolchain.inference.api.endpoints import Inference +from llama_toolchain.inference.api import Inference from llama_toolchain.inference.providers import available_inference_providers -from llama_toolchain.memory.api.endpoints import Memory +from llama_toolchain.memory.api import Memory from llama_toolchain.memory.providers import available_memory_providers -from llama_toolchain.safety.api.endpoints import Safety +from llama_toolchain.safety.api import Safety from llama_toolchain.safety.providers import available_safety_providers from .datatypes import ( diff --git a/llama_toolchain/tools/safety.py b/llama_toolchain/tools/safety.py index aab67801d..24051af8a 100644 --- a/llama_toolchain/tools/safety.py +++ b/llama_toolchain/tools/safety.py @@ -9,8 +9,7 @@ from typing import List from llama_toolchain.agentic_system.meta_reference.safety import ShieldRunnerMixin from llama_toolchain.inference.api import Message -from llama_toolchain.safety.api.datatypes import ShieldDefinition -from llama_toolchain.safety.api.endpoints import Safety +from llama_toolchain.safety.api import Safety, ShieldDefinition from .builtin import BaseTool