diff --git a/docs/source/distribution_dev/building_distro.md b/docs/source/distribution_dev/building_distro.md index 82724c40d..314792e41 100644 --- a/docs/source/distribution_dev/building_distro.md +++ b/docs/source/distribution_dev/building_distro.md @@ -76,7 +76,7 @@ llama stack build --list-templates | | "meta-reference", | | | | "remote::weaviate" | | | | ], | | -| | "safety": "remote::together", | | +| | "safety": "meta-reference", | | | | "agents": "meta-reference", | | | | "telemetry": "meta-reference" | | | | } | | diff --git a/llama_stack/providers/registry/safety.py b/llama_stack/providers/registry/safety.py index 9279d8df9..fdaa33192 100644 --- a/llama_stack/providers/registry/safety.py +++ b/llama_stack/providers/registry/safety.py @@ -48,18 +48,6 @@ def available_providers() -> List[ProviderSpec]: config_class="llama_stack.providers.remote.safety.bedrock.BedrockSafetyConfig", ), ), - remote_provider_spec( - api=Api.safety, - adapter=AdapterSpec( - adapter_type="together", - pip_packages=[ - "together", - ], - module="llama_stack.providers.remote.safety.together", - config_class="llama_stack.providers.remote.safety.together.TogetherSafetyConfig", - provider_data_validator="llama_stack.providers.remote.safety.together.TogetherProviderDataValidator", - ), - ), InlineProviderSpec( api=Api.safety, provider_type="meta-reference/codeshield", diff --git a/llama_stack/providers/remote/safety/together/__init__.py b/llama_stack/providers/remote/safety/together/__init__.py deleted file mode 100644 index cd7450491..000000000 --- a/llama_stack/providers/remote/safety/together/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from .config import TogetherProviderDataValidator, TogetherSafetyConfig # noqa: F401 - - -async def get_adapter_impl(config: TogetherSafetyConfig, _deps): - from .together import TogetherSafetyImpl - - assert isinstance( - config, TogetherSafetyConfig - ), f"Unexpected config type: {type(config)}" - impl = TogetherSafetyImpl(config) - await impl.initialize() - return impl diff --git a/llama_stack/providers/remote/safety/together/config.py b/llama_stack/providers/remote/safety/together/config.py deleted file mode 100644 index 463b929f4..000000000 --- a/llama_stack/providers/remote/safety/together/config.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from typing import Optional - -from llama_models.schema_utils import json_schema_type -from pydantic import BaseModel, Field - - -class TogetherProviderDataValidator(BaseModel): - together_api_key: str - - -@json_schema_type -class TogetherSafetyConfig(BaseModel): - url: str = Field( - default="https://api.together.xyz/v1", - description="The URL for the Together AI server", - ) - api_key: Optional[str] = Field( - default=None, - description="The Together AI API Key (default for the distribution, if any)", - ) diff --git a/llama_stack/providers/remote/safety/together/together.py b/llama_stack/providers/remote/safety/together/together.py deleted file mode 100644 index 9f92626af..000000000 --- a/llama_stack/providers/remote/safety/together/together.py +++ /dev/null @@ -1,101 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. -from together import Together - -from llama_models.llama3.api.datatypes import * # noqa: F403 -from llama_stack.apis.safety import * # noqa: F403 -from llama_stack.distribution.request_headers import NeedsRequestProviderData -from llama_stack.providers.datatypes import ShieldsProtocolPrivate - -from .config import TogetherSafetyConfig - - -TOGETHER_SHIELD_MODEL_MAP = { - "llama_guard": "meta-llama/Meta-Llama-Guard-3-8B", - "Llama-Guard-3-8B": "meta-llama/Meta-Llama-Guard-3-8B", - "Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision-Turbo", -} - - -class TogetherSafetyImpl(Safety, NeedsRequestProviderData, ShieldsProtocolPrivate): - def __init__(self, config: TogetherSafetyConfig) -> None: - self.config = config - - async def initialize(self) -> None: - pass - - async def shutdown(self) -> None: - pass - - async def register_shield(self, shield: ShieldDef) -> None: - raise ValueError("Registering dynamic shields is not supported") - - async def list_shields(self) -> List[ShieldDef]: - return [ - ShieldDef( - identifier=ShieldType.llama_guard.value, - shield_type=ShieldType.llama_guard.value, - params={}, - ) - ] - - async def run_shield( - self, identifier: str, messages: List[Message], params: Dict[str, Any] = None - ) -> RunShieldResponse: - shield_def = await self.shield_store.get_shield(identifier) - if not shield_def: - raise ValueError(f"Unknown shield {identifier}") - - model = shield_def.params.get("model", "llama_guard") - if model not in TOGETHER_SHIELD_MODEL_MAP: - raise ValueError(f"Unsupported safety model: {model}") - - together_api_key = None - if self.config.api_key is not None: - together_api_key = self.config.api_key - else: - provider_data = self.get_request_provider_data() - if provider_data is None or not provider_data.together_api_key: - raise ValueError( - 'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": }' - ) - together_api_key = provider_data.together_api_key - - # messages can have role assistant or user - api_messages = [] - for message in messages: - if message.role in (Role.user.value, Role.assistant.value): - api_messages.append({"role": message.role, "content": message.content}) - - violation = await get_safety_response( - together_api_key, TOGETHER_SHIELD_MODEL_MAP[model], api_messages - ) - return RunShieldResponse(violation=violation) - - -async def get_safety_response( - api_key: str, model_name: str, messages: List[Dict[str, str]] -) -> Optional[SafetyViolation]: - client = Together(api_key=api_key) - response = client.chat.completions.create(messages=messages, model=model_name) - if len(response.choices) == 0: - return None - - response_text = response.choices[0].message.content - if response_text == "safe": - return None - - parts = response_text.split("\n") - if len(parts) != 2: - return None - - if parts[0] == "unsafe": - return SafetyViolation( - violation_level=ViolationLevel.ERROR, - metadata={"violation_type": parts[1]}, - ) - - return None diff --git a/llama_stack/providers/tests/safety/conftest.py b/llama_stack/providers/tests/safety/conftest.py index fb47b290d..88fe3d2ca 100644 --- a/llama_stack/providers/tests/safety/conftest.py +++ b/llama_stack/providers/tests/safety/conftest.py @@ -32,7 +32,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [ pytest.param( { "inference": "together", - "safety": "together", + "safety": "meta_reference", }, id="together", marks=pytest.mark.together, diff --git a/llama_stack/providers/tests/safety/fixtures.py b/llama_stack/providers/tests/safety/fixtures.py index 4789558ff..de1829355 100644 --- a/llama_stack/providers/tests/safety/fixtures.py +++ b/llama_stack/providers/tests/safety/fixtures.py @@ -12,12 +12,10 @@ from llama_stack.providers.inline.meta_reference.safety import ( LlamaGuardShieldConfig, SafetyConfig, ) -from llama_stack.providers.remote.safety.together import TogetherSafetyConfig from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2 from ..conftest import ProviderFixture, remote_stack_fixture -from ..env import get_env_or_fail @pytest.fixture(scope="session") @@ -49,23 +47,7 @@ def safety_meta_reference(safety_model) -> ProviderFixture: ) -@pytest.fixture(scope="session") -def safety_together() -> ProviderFixture: - return ProviderFixture( - providers=[ - Provider( - provider_id="together", - provider_type="remote::together", - config=TogetherSafetyConfig().model_dump(), - ) - ], - provider_data=dict( - together_api_key=get_env_or_fail("TOGETHER_API_KEY"), - ), - ) - - -SAFETY_FIXTURES = ["meta_reference", "together", "remote"] +SAFETY_FIXTURES = ["meta_reference", "remote"] @pytest_asyncio.fixture(scope="session") diff --git a/llama_stack/templates/together/build.yaml b/llama_stack/templates/together/build.yaml index fe48e4586..05e59f677 100644 --- a/llama_stack/templates/together/build.yaml +++ b/llama_stack/templates/together/build.yaml @@ -6,6 +6,6 @@ distribution_spec: memory: - meta-reference - remote::weaviate - safety: remote::together + safety: meta-reference agents: meta-reference telemetry: meta-reference