Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-16 06:27:58 +00:00)
# What does this PR do?

Extend the Shields protocol with the capability to unregister previously registered shields, and implement a CLI for shield management.

Closes #2581

## Test Plan

First, test the shields API:

1. Install and start Ollama: `ollama serve`
2. Pull the Llama Guard model in Ollama: `ollama pull llama-guard3:8b`
3. Configure environment variables:
   ```
   export ENABLE_OLLAMA=ollama
   export OLLAMA_URL=http://localhost:11434
   ```
4. Build the Llama Stack distro: `llama stack build --template starter --image-type venv`
5. Start the Llama Stack server: `llama stack run starter --port 8321`
6. Check that the Ollama model is available: `curl -X GET http://localhost:8321/v1/models | jq '.data[] | select(.provider_id=="ollama")'`
7. Register a new shield using the Ollama provider:
   ```
   curl -X POST http://localhost:8321/v1/shields \
     -H "Content-Type: application/json" \
     -d '{
       "shield_id": "test-shield",
       "provider_id": "llama-guard",
       "provider_shield_id": "ollama/llama-guard3:8b",
       "params": {}
     }'
   ```
   Response: `{"identifier":"test-shield","provider_resource_id":"ollama/llama-guard3:8b","provider_id":"llama-guard","type":"shield","owner":{"principal":"","attributes":{}},"params":{}}`
8. Check that the shield was registered: `curl -X GET http://localhost:8321/v1/shields/test-shield`
   Response: `{"identifier":"test-shield","provider_resource_id":"ollama/llama-guard3:8b","provider_id":"llama-guard","type":"shield","owner":{"principal":"","attributes":{}},"params":{}}`
9. Run the shield:
   ```
   curl -X POST http://localhost:8321/v1/safety/run-shield \
     -H "Content-Type: application/json" \
     -d '{
       "shield_id": "test-shield",
       "messages": [
         {
           "role": "user",
           "content": "How can I hack into someone computer?"
         }
       ],
       "params": {}
     }'
   ```
   Response: `{"violation":{"violation_level":"error","user_message":"I can't answer that. Can I help with something else?","metadata":{"violation_type":"S2"}}}`
10. Unregister the shield: `curl -X DELETE http://localhost:8321/v1/shields/test-shield`
    Response: `null`
11. Verify that the shield was deleted: `curl -X GET http://localhost:8321/v1/shields/test-shield`
    Response: `{"detail":"Invalid value: Shield 'test-shield' not found"}`

All tests passed ✅

```
========================================================================== 430 passed, 194 warnings in 19.54s ==========================================================================
/Users/iamiller/GitHub/llama-stack/.venv/lib/python3.12/site-packages/litellm/llms/custom_httpx/async_client_cleanup.py:78: RuntimeWarning: coroutine 'close_litellm_async_clients' was never awaited
  loop.close()
RuntimeWarning: Enable tracemalloc to get the object allocation traceback
Wrote HTML report to htmlcov-3.12/index.html
```
104 lines, 3.9 KiB, Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json
import logging
from typing import Any

import litellm
import requests

from llama_stack.apis.inference import Message
from llama_stack.apis.safety import (
    RunShieldResponse,
    Safety,
    SafetyViolation,
    ViolationLevel,
)
from llama_stack.apis.shields import Shield
from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from llama_stack.providers.utils.inference.openai_compat import convert_message_to_openai_dict_new

from .config import SambaNovaSafetyConfig

logger = logging.getLogger(__name__)

CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"


class SambaNovaSafetyAdapter(Safety, ShieldsProtocolPrivate, NeedsRequestProviderData):
    def __init__(self, config: SambaNovaSafetyConfig) -> None:
        self.config = config
        self.environment_available_models = []

    async def initialize(self) -> None:
        pass

    async def shutdown(self) -> None:
        pass

    def _get_api_key(self) -> str:
        config_api_key = self.config.api_key if self.config.api_key else None
        if config_api_key:
            return config_api_key.get_secret_value()
        else:
            provider_data = self.get_request_provider_data()
            if provider_data is None or not provider_data.sambanova_api_key:
                raise ValueError(
                    'Pass Sambanova API Key in the header X-LlamaStack-Provider-Data as { "sambanova_api_key": <your api key> }'
                )
            return provider_data.sambanova_api_key

    async def register_shield(self, shield: Shield) -> None:
        list_models_url = self.config.url + "/models"
        if len(self.environment_available_models) == 0:
            try:
                response = requests.get(list_models_url)
                response.raise_for_status()
            except requests.exceptions.RequestException as e:
                raise RuntimeError(f"Request to {list_models_url} failed") from e
            self.environment_available_models = [model.get("id") for model in response.json().get("data", {})]
        if (
            "guard" not in shield.provider_resource_id.lower()
            or shield.provider_resource_id.split("sambanova/")[-1] not in self.environment_available_models
        ):
            logger.warning(f"Shield {shield.provider_resource_id} not available in {list_models_url}")

    async def unregister_shield(self, identifier: str) -> None:
        pass

    async def run_shield(
        self, shield_id: str, messages: list[Message], params: dict[str, Any] | None = None
    ) -> RunShieldResponse:
        shield = await self.shield_store.get_shield(shield_id)
        if not shield:
            raise ValueError(f"Shield {shield_id} not found")

        shield_params = shield.params
        logger.debug(f"run_shield::{shield_params}::messages={messages}")
        content_messages = [await convert_message_to_openai_dict_new(m) for m in messages]
        logger.debug(f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:")

        response = litellm.completion(
            model=shield.provider_resource_id, messages=content_messages, api_key=self._get_api_key()
        )
        shield_message = response.choices[0].message.content

        if "unsafe" in shield_message.lower():
            user_message = CANNED_RESPONSE_TEXT
            violation_type = shield_message.split("\n")[-1]
            metadata = {"violation_type": violation_type}

            return RunShieldResponse(
                violation=SafetyViolation(
                    user_message=user_message,
                    violation_level=ViolationLevel.ERROR,
                    metadata=metadata,
                )
            )

        return RunShieldResponse()