From b85d675c6f62ade97bc4fbf19fa7c1204637c319 Mon Sep 17 00:00:00 2001 From: Yogish Baliga Date: Fri, 20 Sep 2024 09:35:01 -0700 Subject: [PATCH] Adding safety adapter for Together --- llama_stack/apis/safety/client.py | 9 ++- llama_stack/distribution/request_headers.py | 23 +++--- llama_stack/distribution/server/server.py | 35 ++++++--- .../templates/local-together-build.yaml | 2 +- .../adapters/safety/together/__init__.py | 18 +++++ .../adapters/safety/together/config.py | 26 +++++++ .../adapters/safety/together/together.py | 78 +++++++++++++++++++ llama_stack/providers/registry/safety.py | 20 ++++- 8 files changed, 188 insertions(+), 23 deletions(-) create mode 100644 llama_stack/providers/adapters/safety/together/__init__.py create mode 100644 llama_stack/providers/adapters/safety/together/config.py create mode 100644 llama_stack/providers/adapters/safety/together/together.py diff --git a/llama_stack/apis/safety/client.py b/llama_stack/apis/safety/client.py index 29bb94420..38af9589c 100644 --- a/llama_stack/apis/safety/client.py +++ b/llama_stack/apis/safety/client.py @@ -49,7 +49,14 @@ class SafetyClient(Safety): shield_type=shield_type, messages=[encodable_dict(m) for m in messages], ), - headers={"Content-Type": "application/json"}, + headers={ + "Content-Type": "application/json", + "X-LlamaStack-ProviderData": json.dumps( + { + "together_api_key": "1882f9a484fc7c6ce3e4dc90272d5db52346c93838daab3d704803181f396b22" + } + ), + }, timeout=20, ) diff --git a/llama_stack/distribution/request_headers.py b/llama_stack/distribution/request_headers.py index 5a4fb19a0..27b8b531f 100644 --- a/llama_stack/distribution/request_headers.py +++ b/llama_stack/distribution/request_headers.py @@ -6,7 +6,7 @@ import json import threading -from typing import Any, Dict, Optional +from typing import Any, Dict, List from .utils.dynamic import instantiate_class_type @@ -17,8 +17,8 @@ def get_request_provider_data() -> Any: return getattr(_THREAD_LOCAL, "provider_data", None) -def 
set_request_provider_data(headers: Dict[str, str], validator_class: Optional[str]): - if not validator_class: +def set_request_provider_data(headers: Dict[str, str], validator_classes: List[str]): + if not validator_classes: return keys = [ @@ -39,11 +39,12 @@ def set_request_provider_data(headers: Dict[str, str], validator_class: Optional print("Provider data not encoded as a JSON object!", val) return - validator = instantiate_class_type(validator_class) - try: - provider_data = validator(**val) - except Exception as e: - print("Error parsing provider data", e) - return - - _THREAD_LOCAL.provider_data = provider_data + for validator_class in validator_classes: + validator = instantiate_class_type(validator_class) + try: + provider_data = validator(**val) + if provider_data: + _THREAD_LOCAL.provider_data = provider_data + return + except Exception as e: + print("Error parsing provider data", e) diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index 1d77e1e4c..7a3e6276c 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -15,6 +15,7 @@ from collections.abc import ( AsyncIterator as AsyncIteratorABC, ) from contextlib import asynccontextmanager +from http import HTTPStatus from ssl import SSLError from typing import ( Any, @@ -88,7 +89,7 @@ async def global_exception_handler(request: Request, exc: Exception): ) -def translate_exception(exc: Exception) -> HTTPException: +def translate_exception(exc: Exception) -> Union[HTTPException, RequestValidationError]: if isinstance(exc, ValidationError): exc = RequestValidationError(exc.raw_errors) @@ -207,7 +208,7 @@ def create_dynamic_passthrough( def create_dynamic_typed_route( - func: Any, method: str, provider_data_validator: Optional[str] + func: Any, method: str, provider_data_validators: List[str] ): hints = get_type_hints(func) response_model = hints.get("return") @@ -223,7 +224,7 @@ def create_dynamic_typed_route( 
async def endpoint(request: Request, **kwargs): await start_trace(func.__name__) - set_request_provider_data(request.headers, provider_data_validator) + set_request_provider_data(request.headers, provider_data_validators) async def sse_generator(event_gen): try: @@ -254,7 +255,7 @@ def create_dynamic_typed_route( async def endpoint(request: Request, **kwargs): await start_trace(func.__name__) - set_request_provider_data(request.headers, provider_data_validator) + set_request_provider_data(request.headers, provider_data_validators) try: return ( @@ -415,6 +416,15 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False): app = FastAPI() + # Health check is added to enable deploying the docker container image on Kubernetes which require + # a health check that can return 200 for readiness and liveness check + class HealthCheck(BaseModel): + status: str = "OK" + + @app.get("/healthcheck", status_code=HTTPStatus.OK, response_model=HealthCheck) + async def healthcheck(): + return HealthCheck(status="OK") + impls, specs = asyncio.run(resolve_impls_with_routing(config)) if Api.telemetry in impls: setup_logger(impls[Api.telemetry]) @@ -454,15 +464,22 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False): ) impl_method = getattr(impl, endpoint.name) + + validators = [] + if isinstance(provider_spec, AutoRoutedProviderSpec): + inner_specs = specs[provider_spec.routing_table_api].inner_specs + for spec in inner_specs: + if spec.provider_data_validator: + validators.append(spec.provider_data_validator) + elif not isinstance(provider_spec, RoutingTableProviderSpec): + if provider_spec.provider_data_validator: + validators.append(provider_spec.provider_data_validator) + getattr(app, endpoint.method)(endpoint.route, response_model=None)( create_dynamic_typed_route( impl_method, endpoint.method, - ( - provider_spec.provider_data_validator - if not isinstance(provider_spec, RoutingTableProviderSpec) - else None - ), + validators, ) ) diff 
--git a/llama_stack/distribution/templates/local-together-build.yaml b/llama_stack/distribution/templates/local-together-build.yaml index 1ab891518..ebf0bf1fb 100644 --- a/llama_stack/distribution/templates/local-together-build.yaml +++ b/llama_stack/distribution/templates/local-together-build.yaml @@ -4,7 +4,7 @@ distribution_spec: providers: inference: remote::together memory: meta-reference - safety: meta-reference + safety: remote::together agents: meta-reference telemetry: meta-reference image_type: conda diff --git a/llama_stack/providers/adapters/safety/together/__init__.py b/llama_stack/providers/adapters/safety/together/__init__.py new file mode 100644 index 000000000..cd7450491 --- /dev/null +++ b/llama_stack/providers/adapters/safety/together/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .config import TogetherProviderDataValidator, TogetherSafetyConfig # noqa: F401 + + +async def get_adapter_impl(config: TogetherSafetyConfig, _deps): + from .together import TogetherSafetyImpl + + assert isinstance( + config, TogetherSafetyConfig + ), f"Unexpected config type: {type(config)}" + impl = TogetherSafetyImpl(config) + await impl.initialize() + return impl diff --git a/llama_stack/providers/adapters/safety/together/config.py b/llama_stack/providers/adapters/safety/together/config.py new file mode 100644 index 000000000..463b929f4 --- /dev/null +++ b/llama_stack/providers/adapters/safety/together/config.py @@ -0,0 +1,26 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from typing import Optional + +from llama_models.schema_utils import json_schema_type +from pydantic import BaseModel, Field + + +class TogetherProviderDataValidator(BaseModel): + together_api_key: str + + +@json_schema_type +class TogetherSafetyConfig(BaseModel): + url: str = Field( + default="https://api.together.xyz/v1", + description="The URL for the Together AI server", + ) + api_key: Optional[str] = Field( + default=None, + description="The Together AI API Key (default for the distribution, if any)", + ) diff --git a/llama_stack/providers/adapters/safety/together/together.py b/llama_stack/providers/adapters/safety/together/together.py new file mode 100644 index 000000000..223377073 --- /dev/null +++ b/llama_stack/providers/adapters/safety/together/together.py @@ -0,0 +1,78 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from together import Together + +from llama_stack.distribution.request_headers import get_request_provider_data + +from .config import TogetherProviderDataValidator, TogetherSafetyConfig + + +class TogetherSafetyImpl(Safety): + def __init__(self, config: TogetherSafetyConfig) -> None: + self.config = config + + async def initialize(self) -> None: + pass + + async def run_shield( + self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None + ) -> RunShieldResponse: + if shield_type != "llama_guard": + raise ValueError(f"shield type {shield_type} is not supported") + + provider_data = get_request_provider_data() + + together_api_key = None + if provider_data is not None: + if not isinstance(provider_data, TogetherProviderDataValidator): + raise ValueError( + 'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": }' + ) + + together_api_key = provider_data.together_api_key + if not together_api_key: + together_api_key = 
self.config.api_key
+
+        if not together_api_key:
+            raise ValueError("The API key must be provided in the header or config")
+
+        # messages can have role assistant or user
+        api_messages = []
+        for message in messages:
+            if message.role in (Role.user.value, Role.assistant.value):
+                api_messages.append({"role": message.role, "content": message.content})
+
+        violation = await get_safety_response(together_api_key, api_messages)
+        return RunShieldResponse(violation=violation)
+
+
+async def get_safety_response(
+    api_key: str, messages: List[Dict[str, str]]
+) -> Optional[SafetyViolation]:
+    client = Together(api_key=api_key)
+    response = client.chat.completions.create(
+        messages=messages, model="meta-llama/Meta-Llama-Guard-3-8B"
+    )
+    if len(response.choices) == 0:
+        return None
+
+    response_text = response.choices[0].message.content
+    if response_text == "safe":
+        return None
+
+    parts = response_text.split("\n")
+    if len(parts) != 2:
+        return None
+
+    if parts[0] == "unsafe":
+        return SafetyViolation(
+            violation_level=ViolationLevel.ERROR,
+            user_message="unsafe",
+            metadata={"violation_type": parts[1]},
+        )
+
+    return None
diff --git a/llama_stack/providers/registry/safety.py b/llama_stack/providers/registry/safety.py
index 6cfc69787..0a012b1df 100644
--- a/llama_stack/providers/registry/safety.py
+++ b/llama_stack/providers/registry/safety.py
@@ -6,7 +6,13 @@
 
 from typing import List
 
-from llama_stack.distribution.datatypes import *  # noqa: F403
+from llama_stack.distribution.datatypes import (
+    AdapterSpec,
+    Api,
+    InlineProviderSpec,
+    ProviderSpec,
+    remote_provider_spec,
+)
 
 
 def available_providers() -> List[ProviderSpec]:
@@ -34,4 +40,16 @@ def available_providers() -> List[ProviderSpec]:
                 config_class="llama_stack.providers.adapters.safety.sample.SampleConfig",
             ),
         ),
+        remote_provider_spec(
+            api=Api.safety,
+            adapter=AdapterSpec(
+                adapter_id="together",
+                pip_packages=[
+                    "together",
+                ],
+                module="llama_stack.providers.adapters.safety.together",
+                
config_class="llama_stack.providers.adapters.safety.together.TogetherSafetyConfig", + provider_data_validator="llama_stack.providers.adapters.safety.together.TogetherProviderDataValidator", + ), + ), ]