fix safety using inference

This commit is contained in:
Xi Yan 2024-09-25 11:27:57 -07:00
parent d442af0818
commit 0d19a026a8
2 changed files with 4 additions and 0 deletions

View file

@@ -18,6 +18,7 @@ from transformers import (
)
from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse
from llama_models.llama3.api.datatypes import Message, Role
SAFE_RESPONSE = "safe"

View file

@@ -28,6 +28,9 @@ def available_providers() -> List[ProviderSpec]:
],
module="llama_stack.providers.impls.meta_reference.safety",
config_class="llama_stack.providers.impls.meta_reference.safety.SafetyConfig",
api_dependencies=[
Api.inference,
],
),
remote_provider_spec(
api=Api.safety,