mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-31 16:01:46 +00:00
rebase on top of Dinesh's refactor
This commit is contained in:
parent
a7f728e41c
commit
7507cd487f
2 changed files with 16 additions and 32 deletions
|
@ -127,27 +127,19 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def register_shield(self, shield: ShieldDef) -> None:
|
async def register_shield(self, shield: Shield) -> None:
|
||||||
raise ValueError("Registering dynamic shields is not supported")
|
if shield.shield_type != ShieldType.llama_guard.value:
|
||||||
|
raise ValueError(f"Unsupported shield type: {shield.shield_type}")
|
||||||
async def list_shields(self) -> List[ShieldDef]:
|
|
||||||
return [
|
|
||||||
ShieldDef(
|
|
||||||
identifier=ShieldType.llama_guard.value,
|
|
||||||
shield_type=ShieldType.llama_guard.value,
|
|
||||||
params={},
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
async def run_shield(
|
async def run_shield(
|
||||||
self,
|
self,
|
||||||
identifier: str,
|
shield_id: str,
|
||||||
messages: List[Message],
|
messages: List[Message],
|
||||||
params: Dict[str, Any] = None,
|
params: Dict[str, Any] = None,
|
||||||
) -> RunShieldResponse:
|
) -> RunShieldResponse:
|
||||||
shield_def = await self.shield_store.get_shield(identifier)
|
shield = await self.shield_store.get_shield(shield_id)
|
||||||
if not shield_def:
|
if not shield:
|
||||||
raise ValueError(f"Unknown shield {identifier}")
|
raise ValueError(f"Unknown shield {shield_id}")
|
||||||
|
|
||||||
messages = messages.copy()
|
messages = messages.copy()
|
||||||
# some shields like llama-guard require the first message to be a user message
|
# some shields like llama-guard require the first message to be a user message
|
||||||
|
|
|
@ -7,11 +7,11 @@
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from llama_stack.distribution.utils.model_utils import model_local_dir
|
|
||||||
from termcolor import cprint
|
from termcolor import cprint
|
||||||
|
|
||||||
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
||||||
|
|
||||||
|
from llama_stack.distribution.utils.model_utils import model_local_dir
|
||||||
from llama_stack.apis.inference import * # noqa: F403
|
from llama_stack.apis.inference import * # noqa: F403
|
||||||
from llama_stack.apis.safety import * # noqa: F403
|
from llama_stack.apis.safety import * # noqa: F403
|
||||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||||
|
@ -35,27 +35,19 @@ class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def register_shield(self, shield: ShieldDef) -> None:
|
async def register_shield(self, shield: Shield) -> None:
|
||||||
raise ValueError("Registering dynamic shields is not supported")
|
if shield.shield_type != ShieldType.prompt_guard.value:
|
||||||
|
raise ValueError(f"Unsupported shield type: {shield.shield_type}")
|
||||||
async def list_shields(self) -> List[ShieldDef]:
|
|
||||||
return [
|
|
||||||
ShieldDef(
|
|
||||||
identifier=ShieldType.prompt_guard.value,
|
|
||||||
shield_type=ShieldType.prompt_guard.value,
|
|
||||||
params={},
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
async def run_shield(
|
async def run_shield(
|
||||||
self,
|
self,
|
||||||
identifier: str,
|
shield_id: str,
|
||||||
messages: List[Message],
|
messages: List[Message],
|
||||||
params: Dict[str, Any] = None,
|
params: Dict[str, Any] = None,
|
||||||
) -> RunShieldResponse:
|
) -> RunShieldResponse:
|
||||||
shield_def = await self.shield_store.get_shield(identifier)
|
shield = await self.shield_store.get_shield(shield_id)
|
||||||
if not shield_def:
|
if not shield:
|
||||||
raise ValueError(f"Unknown shield {identifier}")
|
raise ValueError(f"Unknown shield {shield_id}")
|
||||||
|
|
||||||
return await self.shield.run(messages)
|
return await self.shield.run(messages)
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue