mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
Adding safety adapter for Together
This commit is contained in:
parent
0d2eb3bd25
commit
b85d675c6f
8 changed files with 188 additions and 23 deletions
|
@ -49,7 +49,14 @@ class SafetyClient(Safety):
|
||||||
shield_type=shield_type,
|
shield_type=shield_type,
|
||||||
messages=[encodable_dict(m) for m in messages],
|
messages=[encodable_dict(m) for m in messages],
|
||||||
),
|
),
|
||||||
headers={"Content-Type": "application/json"},
|
headers={
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"X-LlamaStack-ProviderData": json.dumps(
|
||||||
|
{
|
||||||
|
"together_api_key": "1882f9a484fc7c6ce3e4dc90272d5db52346c93838daab3d704803181f396b22"
|
||||||
|
}
|
||||||
|
),
|
||||||
|
},
|
||||||
timeout=20,
|
timeout=20,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -6,7 +6,7 @@
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import threading
|
import threading
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
from .utils.dynamic import instantiate_class_type
|
from .utils.dynamic import instantiate_class_type
|
||||||
|
|
||||||
|
@ -17,8 +17,8 @@ def get_request_provider_data() -> Any:
|
||||||
return getattr(_THREAD_LOCAL, "provider_data", None)
|
return getattr(_THREAD_LOCAL, "provider_data", None)
|
||||||
|
|
||||||
|
|
||||||
def set_request_provider_data(headers: Dict[str, str], validator_class: Optional[str]):
|
def set_request_provider_data(headers: Dict[str, str], validator_classes: List[str]):
|
||||||
if not validator_class:
|
if not validator_classes:
|
||||||
return
|
return
|
||||||
|
|
||||||
keys = [
|
keys = [
|
||||||
|
@ -39,11 +39,12 @@ def set_request_provider_data(headers: Dict[str, str], validator_class: Optional
|
||||||
print("Provider data not encoded as a JSON object!", val)
|
print("Provider data not encoded as a JSON object!", val)
|
||||||
return
|
return
|
||||||
|
|
||||||
validator = instantiate_class_type(validator_class)
|
for validator_class in validator_classes:
|
||||||
try:
|
validator = instantiate_class_type(validator_class)
|
||||||
provider_data = validator(**val)
|
try:
|
||||||
except Exception as e:
|
provider_data = validator(**val)
|
||||||
print("Error parsing provider data", e)
|
if provider_data:
|
||||||
return
|
_THREAD_LOCAL.provider_data = provider_data
|
||||||
|
return
|
||||||
_THREAD_LOCAL.provider_data = provider_data
|
except Exception as e:
|
||||||
|
print("Error parsing provider data", e)
|
||||||
|
|
|
@ -15,6 +15,7 @@ from collections.abc import (
|
||||||
AsyncIterator as AsyncIteratorABC,
|
AsyncIterator as AsyncIteratorABC,
|
||||||
)
|
)
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
from http import HTTPStatus
|
||||||
from ssl import SSLError
|
from ssl import SSLError
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
|
@ -88,7 +89,7 @@ async def global_exception_handler(request: Request, exc: Exception):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def translate_exception(exc: Exception) -> HTTPException:
|
def translate_exception(exc: Exception) -> Union[HTTPException, RequestValidationError]:
|
||||||
if isinstance(exc, ValidationError):
|
if isinstance(exc, ValidationError):
|
||||||
exc = RequestValidationError(exc.raw_errors)
|
exc = RequestValidationError(exc.raw_errors)
|
||||||
|
|
||||||
|
@ -207,7 +208,7 @@ def create_dynamic_passthrough(
|
||||||
|
|
||||||
|
|
||||||
def create_dynamic_typed_route(
|
def create_dynamic_typed_route(
|
||||||
func: Any, method: str, provider_data_validator: Optional[str]
|
func: Any, method: str, provider_data_validators: List[str]
|
||||||
):
|
):
|
||||||
hints = get_type_hints(func)
|
hints = get_type_hints(func)
|
||||||
response_model = hints.get("return")
|
response_model = hints.get("return")
|
||||||
|
@ -223,7 +224,7 @@ def create_dynamic_typed_route(
|
||||||
async def endpoint(request: Request, **kwargs):
|
async def endpoint(request: Request, **kwargs):
|
||||||
await start_trace(func.__name__)
|
await start_trace(func.__name__)
|
||||||
|
|
||||||
set_request_provider_data(request.headers, provider_data_validator)
|
set_request_provider_data(request.headers, provider_data_validators)
|
||||||
|
|
||||||
async def sse_generator(event_gen):
|
async def sse_generator(event_gen):
|
||||||
try:
|
try:
|
||||||
|
@ -254,7 +255,7 @@ def create_dynamic_typed_route(
|
||||||
async def endpoint(request: Request, **kwargs):
|
async def endpoint(request: Request, **kwargs):
|
||||||
await start_trace(func.__name__)
|
await start_trace(func.__name__)
|
||||||
|
|
||||||
set_request_provider_data(request.headers, provider_data_validator)
|
set_request_provider_data(request.headers, provider_data_validators)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return (
|
return (
|
||||||
|
@ -415,6 +416,15 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
|
# Health check is added to enable deploying the docker container image on Kubernetes which require
|
||||||
|
# a health check that can return 200 for readiness and liveness check
|
||||||
|
class HealthCheck(BaseModel):
|
||||||
|
status: str = "OK"
|
||||||
|
|
||||||
|
@app.get("/healthcheck", status_code=HTTPStatus.OK, response_model=HealthCheck)
|
||||||
|
async def healthcheck():
|
||||||
|
return HealthCheck(status="OK")
|
||||||
|
|
||||||
impls, specs = asyncio.run(resolve_impls_with_routing(config))
|
impls, specs = asyncio.run(resolve_impls_with_routing(config))
|
||||||
if Api.telemetry in impls:
|
if Api.telemetry in impls:
|
||||||
setup_logger(impls[Api.telemetry])
|
setup_logger(impls[Api.telemetry])
|
||||||
|
@ -454,15 +464,22 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
|
||||||
)
|
)
|
||||||
|
|
||||||
impl_method = getattr(impl, endpoint.name)
|
impl_method = getattr(impl, endpoint.name)
|
||||||
|
|
||||||
|
validators = []
|
||||||
|
if isinstance(provider_spec, AutoRoutedProviderSpec):
|
||||||
|
inner_specs = specs[provider_spec.routing_table_api].inner_specs
|
||||||
|
for spec in inner_specs:
|
||||||
|
if spec.provider_data_validator:
|
||||||
|
validators.append(spec.provider_data_validator)
|
||||||
|
elif not isinstance(provider_spec, RoutingTableProviderSpec):
|
||||||
|
if provider_spec.provider_data_validator:
|
||||||
|
validators.append(provider_spec.provider_data_validator)
|
||||||
|
|
||||||
getattr(app, endpoint.method)(endpoint.route, response_model=None)(
|
getattr(app, endpoint.method)(endpoint.route, response_model=None)(
|
||||||
create_dynamic_typed_route(
|
create_dynamic_typed_route(
|
||||||
impl_method,
|
impl_method,
|
||||||
endpoint.method,
|
endpoint.method,
|
||||||
(
|
validators,
|
||||||
provider_spec.provider_data_validator
|
|
||||||
if not isinstance(provider_spec, RoutingTableProviderSpec)
|
|
||||||
else None
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -4,7 +4,7 @@ distribution_spec:
|
||||||
providers:
|
providers:
|
||||||
inference: remote::together
|
inference: remote::together
|
||||||
memory: meta-reference
|
memory: meta-reference
|
||||||
safety: meta-reference
|
safety: remote::together
|
||||||
agents: meta-reference
|
agents: meta-reference
|
||||||
telemetry: meta-reference
|
telemetry: meta-reference
|
||||||
image_type: conda
|
image_type: conda
|
||||||
|
|
18
llama_stack/providers/adapters/safety/together/__init__.py
Normal file
18
llama_stack/providers/adapters/safety/together/__init__.py
Normal file
|
@ -0,0 +1,18 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from .config import TogetherProviderDataValidator, TogetherSafetyConfig # noqa: F401
|
||||||
|
|
||||||
|
|
||||||
|
async def get_adapter_impl(config: TogetherSafetyConfig, _deps):
|
||||||
|
from .together import TogetherSafetyImpl
|
||||||
|
|
||||||
|
assert isinstance(
|
||||||
|
config, TogetherSafetyConfig
|
||||||
|
), f"Unexpected config type: {type(config)}"
|
||||||
|
impl = TogetherSafetyImpl(config)
|
||||||
|
await impl.initialize()
|
||||||
|
return impl
|
26
llama_stack/providers/adapters/safety/together/config.py
Normal file
26
llama_stack/providers/adapters/safety/together/config.py
Normal file
|
@ -0,0 +1,26 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from llama_models.schema_utils import json_schema_type
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class TogetherProviderDataValidator(BaseModel):
|
||||||
|
together_api_key: str
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class TogetherSafetyConfig(BaseModel):
|
||||||
|
url: str = Field(
|
||||||
|
default="https://api.together.xyz/v1",
|
||||||
|
description="The URL for the Together AI server",
|
||||||
|
)
|
||||||
|
api_key: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description="The Together AI API Key (default for the distribution, if any)",
|
||||||
|
)
|
78
llama_stack/providers/adapters/safety/together/together.py
Normal file
78
llama_stack/providers/adapters/safety/together/together.py
Normal file
|
@ -0,0 +1,78 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from together import Together
|
||||||
|
|
||||||
|
from llama_stack.distribution.request_headers import get_request_provider_data
|
||||||
|
|
||||||
|
from .config import TogetherProviderDataValidator, TogetherSafetyConfig
|
||||||
|
|
||||||
|
|
||||||
|
class TogetherSafetyImpl(Safety):
|
||||||
|
def __init__(self, config: TogetherSafetyConfig) -> None:
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
async def initialize(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def run_shield(
|
||||||
|
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
|
||||||
|
) -> RunShieldResponse:
|
||||||
|
if shield_type != "llama_guard":
|
||||||
|
raise ValueError(f"shield type {shield_type} is not supported")
|
||||||
|
|
||||||
|
provider_data = get_request_provider_data()
|
||||||
|
|
||||||
|
together_api_key = None
|
||||||
|
if provider_data is not None:
|
||||||
|
if not isinstance(provider_data, TogetherProviderDataValidator):
|
||||||
|
raise ValueError(
|
||||||
|
'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": <your api key>}'
|
||||||
|
)
|
||||||
|
|
||||||
|
together_api_key = provider_data.together_api_key
|
||||||
|
if not together_api_key:
|
||||||
|
together_api_key = self.config.api_key
|
||||||
|
|
||||||
|
if not together_api_key:
|
||||||
|
raise ValueError("The API key must be provider in the header or config")
|
||||||
|
|
||||||
|
# messages can have role assistant or user
|
||||||
|
api_messages = []
|
||||||
|
for message in messages:
|
||||||
|
if message.role in (Role.user.value, Role.assistant.value):
|
||||||
|
api_messages.append({"role": message.role, "content": message.content})
|
||||||
|
|
||||||
|
violation = await get_safety_response(together_api_key, api_messages)
|
||||||
|
return RunShieldResponse(violation=violation)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_safety_response(
|
||||||
|
api_key: str, messages: List[Dict[str, str]]
|
||||||
|
) -> Optional[SafetyViolation]:
|
||||||
|
client = Together(api_key=api_key)
|
||||||
|
response = client.chat.completions.create(
|
||||||
|
messages=messages, model="meta-llama/Meta-Llama-Guard-3-8B"
|
||||||
|
)
|
||||||
|
if len(response.choices) == 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
response_text = response.choices[0].message.content
|
||||||
|
if response_text == "safe":
|
||||||
|
return None
|
||||||
|
|
||||||
|
parts = response_text.split("\n")
|
||||||
|
if len(parts) != 2:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if parts[0] == "unsafe":
|
||||||
|
return SafetyViolation(
|
||||||
|
violation_level=ViolationLevel.ERROR,
|
||||||
|
user_message="unsafe",
|
||||||
|
metadata={"violation_type": parts[1]},
|
||||||
|
)
|
||||||
|
|
||||||
|
return None
|
|
@ -6,7 +6,13 @@
|
||||||
|
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
from llama_stack.distribution.datatypes import (
|
||||||
|
AdapterSpec,
|
||||||
|
Api,
|
||||||
|
InlineProviderSpec,
|
||||||
|
ProviderSpec,
|
||||||
|
remote_provider_spec,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def available_providers() -> List[ProviderSpec]:
|
def available_providers() -> List[ProviderSpec]:
|
||||||
|
@ -34,4 +40,16 @@ def available_providers() -> List[ProviderSpec]:
|
||||||
config_class="llama_stack.providers.adapters.safety.sample.SampleConfig",
|
config_class="llama_stack.providers.adapters.safety.sample.SampleConfig",
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
|
remote_provider_spec(
|
||||||
|
api=Api.safety,
|
||||||
|
adapter=AdapterSpec(
|
||||||
|
adapter_id="together",
|
||||||
|
pip_packages=[
|
||||||
|
"together",
|
||||||
|
],
|
||||||
|
module="llama_stack.providers.adapters.safety.together",
|
||||||
|
config_class="llama_stack.providers.adapters.safety.together.TogetherSafetyConfig",
|
||||||
|
provider_data_validator="llama_stack.providers.adapters.safety.together.TogetherProviderDataValidator",
|
||||||
|
),
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue