From 4971113f923597a39738c66f9b2e578d975089cd Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Mon, 11 Nov 2024 09:12:17 -0800
Subject: [PATCH] Update provider_type -> inline::llama-guard in templates, update run.yaml

---
 distributions/dell-tgi/run.yaml               | 15 ++++++------
 distributions/fireworks/run.yaml              | 16 ++++++-------
 distributions/inline-vllm/run.yaml            | 17 ++++++-------
 distributions/meta-reference-gpu/run.yaml     | 15 +++++++-----
 .../meta-reference-quantized-gpu/run.yaml     | 24 ++++++++++++-------
 distributions/ollama-gpu/run.yaml             | 15 ++++++------
 distributions/ollama/run.yaml                 | 15 ++++++------
 distributions/remote-vllm/run.yaml            | 15 ++++++------
 distributions/tgi/run.yaml                    | 15 ++++++------
 distributions/together/run.yaml               | 15 ++++++------
 .../distribution_dev/building_distro.md       |  8 +++----
 .../providers/tests/safety/fixtures.py        | 17 +++++++++++++
 llama_stack/templates/bedrock/build.yaml      |  4 ++--
 llama_stack/templates/databricks/build.yaml   |  4 ++--
 llama_stack/templates/fireworks/build.yaml    |  2 +-
 llama_stack/templates/hf-endpoint/build.yaml  |  4 ++--
 .../templates/hf-serverless/build.yaml        |  4 ++--
 llama_stack/templates/inline-vllm/build.yaml  |  2 +-
 .../templates/meta-reference-gpu/build.yaml   |  2 +-
 .../meta-reference-quantized-gpu/build.yaml   |  2 +-
 llama_stack/templates/ollama/build.yaml       |  2 +-
 llama_stack/templates/remote-vllm/build.yaml  |  2 +-
 llama_stack/templates/tgi/build.yaml          |  2 +-
 llama_stack/templates/together/build.yaml     |  2 +-
 24 files changed, 121 insertions(+), 98 deletions(-)

diff --git a/distributions/dell-tgi/run.yaml b/distributions/dell-tgi/run.yaml
index c5f6d0aaa..779750c58 100644
--- a/distributions/dell-tgi/run.yaml
+++ b/distributions/dell-tgi/run.yaml
@@ -19,15 +19,14 @@ providers:
       url: http://127.0.0.1:80
   safety:
   - provider_id: meta0
-    provider_type: meta-reference
+    provider_type: inline::llama-guard
     config:
-      llama_guard_shield:
-        model: Llama-Guard-3-1B
-        excluded_categories: []
-        disable_input_check: false
-        disable_output_check: false
-      prompt_guard_shield:
-        model: Prompt-Guard-86M
+      model: Llama-Guard-3-1B
+      excluded_categories: []
+  - provider_id: meta1
+    provider_type: inline::prompt-guard
+    config:
+      model: Prompt-Guard-86M
   memory:
   - provider_id: meta0
     provider_type: meta-reference
diff --git a/distributions/fireworks/run.yaml b/distributions/fireworks/run.yaml
index 4363d86f3..1259c9493 100644
--- a/distributions/fireworks/run.yaml
+++ b/distributions/fireworks/run.yaml
@@ -19,16 +19,16 @@ providers:
       url: https://api.fireworks.ai/inference
       # api_key:
   safety:
+  safety:
   - provider_id: meta0
-    provider_type: meta-reference
+    provider_type: inline::llama-guard
     config:
-      llama_guard_shield:
-        model: Llama-Guard-3-1B
-        excluded_categories: []
-        disable_input_check: false
-        disable_output_check: false
-      prompt_guard_shield:
-        model: Prompt-Guard-86M
+      model: Llama-Guard-3-1B
+      excluded_categories: []
+  - provider_id: meta1
+    provider_type: inline::prompt-guard
+    config:
+      model: Prompt-Guard-86M
   memory:
   - provider_id: meta0
     provider_type: meta-reference
diff --git a/distributions/inline-vllm/run.yaml b/distributions/inline-vllm/run.yaml
index aadf5c0ce..02499b49a 100644
--- a/distributions/inline-vllm/run.yaml
+++ b/distributions/inline-vllm/run.yaml
@@ -21,7 +21,7 @@ providers:
       gpu_memory_utilization: 0.4
       enforce_eager: true
       max_tokens: 4096
-  - provider_id: vllm-safety
+  - provider_id: vllm-inference-safety
     provider_type: inline::vllm
     config:
       model: Llama-Guard-3-1B
@@ -31,14 +31,15 @@ providers:
       max_tokens: 4096
   safety:
   - provider_id: meta0
-    provider_type: meta-reference
+    provider_type: inline::llama-guard
     config:
-      llama_guard_shield:
-        model: Llama-Guard-3-1B
-        excluded_categories: []
-# Uncomment to use prompt guard
-#   prompt_guard_shield:
-#     model: Prompt-Guard-86M
+      model: Llama-Guard-3-1B
+      excluded_categories: []
+  # Uncomment to use prompt guard
+  # - provider_id: meta1
+  #   provider_type: inline::prompt-guard
+  #   config:
+  #     model: Prompt-Guard-86M
   memory:
   - provider_id: meta0
     provider_type: meta-reference
diff --git a/distributions/meta-reference-gpu/run.yaml b/distributions/meta-reference-gpu/run.yaml
index ad3187aa1..98a52bed1 100644
--- a/distributions/meta-reference-gpu/run.yaml
+++ b/distributions/meta-reference-gpu/run.yaml
@@ -13,7 +13,7 @@ apis:
 - safety
 providers:
   inference:
-  - provider_id: meta-reference-inference
+  - provider_id: inference0
     provider_type: meta-reference
     config:
       model: Llama3.2-3B-Instruct
@@ -21,7 +21,7 @@ providers:
       torch_seed: null
       max_seq_len: 4096
       max_batch_size: 1
-  - provider_id: meta-reference-safety
+  - provider_id: inference1
     provider_type: meta-reference
     config:
       model: Llama-Guard-3-1B
@@ -31,11 +31,14 @@ providers:
       max_batch_size: 1
   safety:
   - provider_id: meta0
-    provider_type: meta-reference
+    provider_type: inline::llama-guard
     config:
-      llama_guard_shield:
-        model: Llama-Guard-3-1B
-        excluded_categories: []
+      model: Llama-Guard-3-1B
+      excluded_categories: []
+  - provider_id: meta1
+    provider_type: inline::prompt-guard
+    config:
+      model: Prompt-Guard-86M
 # Uncomment to use prompt guard
 #   prompt_guard_shield:
 #     model: Prompt-Guard-86M
diff --git a/distributions/meta-reference-quantized-gpu/run.yaml b/distributions/meta-reference-quantized-gpu/run.yaml
index f162502c5..fa8be277d 100644
--- a/distributions/meta-reference-quantized-gpu/run.yaml
+++ b/distributions/meta-reference-quantized-gpu/run.yaml
@@ -22,17 +22,25 @@ providers:
       torch_seed: null
       max_seq_len: 2048
       max_batch_size: 1
+  - provider_id: meta1
+    provider_type: meta-reference-quantized
+    config:
+      # not a quantized model !
+      model: Llama-Guard-3-1B
+      quantization: null
+      torch_seed: null
+      max_seq_len: 2048
+      max_batch_size: 1
   safety:
   - provider_id: meta0
-    provider_type: meta-reference
+    provider_type: inline::llama-guard
     config:
-      llama_guard_shield:
-        model: Llama-Guard-3-1B
-        excluded_categories: []
-        disable_input_check: false
-        disable_output_check: false
-      prompt_guard_shield:
-        model: Prompt-Guard-86M
+      model: Llama-Guard-3-1B
+      excluded_categories: []
+  - provider_id: meta1
+    provider_type: inline::prompt-guard
+    config:
+      model: Prompt-Guard-86M
   memory:
   - provider_id: meta0
     provider_type: meta-reference
diff --git a/distributions/ollama-gpu/run.yaml b/distributions/ollama-gpu/run.yaml
index 798dabc0b..46c67a1e5 100644
--- a/distributions/ollama-gpu/run.yaml
+++ b/distributions/ollama-gpu/run.yaml
@@ -19,15 +19,14 @@ providers:
       url: http://127.0.0.1:14343
   safety:
   - provider_id: meta0
-    provider_type: meta-reference
+    provider_type: inline::llama-guard
     config:
-      llama_guard_shield:
-        model: Llama-Guard-3-1B
-        excluded_categories: []
-        disable_input_check: false
-        disable_output_check: false
-      prompt_guard_shield:
-        model: Prompt-Guard-86M
+      model: Llama-Guard-3-1B
+      excluded_categories: []
+  - provider_id: meta1
+    provider_type: inline::prompt-guard
+    config:
+      model: Prompt-Guard-86M
   memory:
   - provider_id: meta0
     provider_type: meta-reference
diff --git a/distributions/ollama/run.yaml b/distributions/ollama/run.yaml
index 798dabc0b..46c67a1e5 100644
--- a/distributions/ollama/run.yaml
+++ b/distributions/ollama/run.yaml
@@ -19,15 +19,14 @@ providers:
       url: http://127.0.0.1:14343
   safety:
   - provider_id: meta0
-    provider_type: meta-reference
+    provider_type: inline::llama-guard
     config:
-      llama_guard_shield:
-        model: Llama-Guard-3-1B
-        excluded_categories: []
-        disable_input_check: false
-        disable_output_check: false
-      prompt_guard_shield:
-        model: Prompt-Guard-86M
+      model: Llama-Guard-3-1B
+      excluded_categories: []
+  - provider_id: meta1
+    provider_type: inline::prompt-guard
+    config:
+      model: Prompt-Guard-86M
   memory:
   - provider_id: meta0
     provider_type: meta-reference
diff --git a/distributions/remote-vllm/run.yaml b/distributions/remote-vllm/run.yaml
index 2d0d36370..27d60bd6c 100644
--- a/distributions/remote-vllm/run.yaml
+++ b/distributions/remote-vllm/run.yaml
@@ -19,15 +19,14 @@ providers:
       url: http://127.0.0.1:8000
   safety:
   - provider_id: meta0
-    provider_type: meta-reference
+    provider_type: inline::llama-guard
     config:
-      llama_guard_shield:
-        model: Llama-Guard-3-1B
-        excluded_categories: []
-        disable_input_check: false
-        disable_output_check: false
-      prompt_guard_shield:
-        model: Prompt-Guard-86M
+      model: Llama-Guard-3-1B
+      excluded_categories: []
+  - provider_id: meta1
+    provider_type: inline::prompt-guard
+    config:
+      model: Prompt-Guard-86M
   memory:
   - provider_id: meta0
     provider_type: meta-reference
diff --git a/distributions/tgi/run.yaml b/distributions/tgi/run.yaml
index dc8cb2d2d..dcbb69027 100644
--- a/distributions/tgi/run.yaml
+++ b/distributions/tgi/run.yaml
@@ -19,15 +19,14 @@ providers:
       url: http://127.0.0.1:5009
   safety:
   - provider_id: meta0
-    provider_type: meta-reference
+    provider_type: inline::llama-guard
     config:
-      llama_guard_shield:
-        model: Llama-Guard-3-1B
-        excluded_categories: []
-        disable_input_check: false
-        disable_output_check: false
-      prompt_guard_shield:
-        model: Prompt-Guard-86M
+      model: Llama-Guard-3-1B
+      excluded_categories: []
+  - provider_id: meta1
+    provider_type: inline::prompt-guard
+    config:
+      model: Prompt-Guard-86M
   memory:
   - provider_id: meta0
    provider_type: meta-reference
diff --git a/distributions/together/run.yaml b/distributions/together/run.yaml
index 87fd4dcd7..36ef86056 100644
--- a/distributions/together/run.yaml
+++ b/distributions/together/run.yaml
@@ -20,15 +20,14 @@ providers:
       # api_key:
   safety:
   - provider_id: meta0
-    provider_type: meta-reference
+    provider_type: inline::llama-guard
     config:
-      llama_guard_shield:
-        model: Llama-Guard-3-1B
-        excluded_categories: []
-        disable_input_check: false
-        disable_output_check: false
-      prompt_guard_shield:
-        model: Prompt-Guard-86M
+      model: Llama-Guard-3-1B
+      excluded_categories: []
+  - provider_id: meta1
+    provider_type: inline::prompt-guard
+    config:
+      model: Prompt-Guard-86M
   memory:
   - provider_id: meta0
     provider_type: remote::weaviate
diff --git a/docs/source/distribution_dev/building_distro.md b/docs/source/distribution_dev/building_distro.md
index 314792e41..36c504b1b 100644
--- a/docs/source/distribution_dev/building_distro.md
+++ b/docs/source/distribution_dev/building_distro.md
@@ -36,9 +36,9 @@ the provider types (implementations) you want to use for these APIs.
 Tip: use <TAB> to see options for the providers.
 
 > Enter provider for API inference: meta-reference
-> Enter provider for API safety: meta-reference
+> Enter provider for API safety: inline::llama-guard
 > Enter provider for API agents: meta-reference
-> Enter provider for API memory: meta-reference
+> Enter provider for API memory: inline::faiss
 > Enter provider for API datasetio: meta-reference
 > Enter provider for API scoring: meta-reference
 > Enter provider for API eval: meta-reference
@@ -203,8 +203,8 @@ distribution_spec:
   description: Like local, but use ollama for running LLM inference
   providers:
     inference: remote::ollama
-    memory: meta-reference
-    safety: meta-reference
+    memory: inline::faiss
+    safety: inline::llama-guard
     agents: meta-reference
     telemetry: meta-reference
 image_type: conda
diff --git a/llama_stack/providers/tests/safety/fixtures.py b/llama_stack/providers/tests/safety/fixtures.py
index 5b4c07de5..10a6460cb 100644
--- a/llama_stack/providers/tests/safety/fixtures.py
+++ b/llama_stack/providers/tests/safety/fixtures.py
@@ -11,6 +11,7 @@
 from llama_stack.apis.shields import ShieldType
 from llama_stack.distribution.datatypes import Api, Provider
 from llama_stack.providers.inline.safety.llama_guard import LlamaGuardConfig
+from llama_stack.providers.inline.safety.prompt_guard import PromptGuardConfig
 from llama_stack.providers.remote.safety.bedrock import BedrockSafetyConfig
 
 from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
@@ -44,6 +45,22 @@ def safety_llama_guard(safety_model) -> ProviderFixture:
     )
 
 
+# TODO: this is not tested yet; we would need to configure the run_shield() test
+# and parametrize it with the "prompt" for testing depending on the safety fixture
+# we are using.
+@pytest.fixture(scope="session")
+def safety_prompt_guard() -> ProviderFixture:
+    return ProviderFixture(
+        providers=[
+            Provider(
+                provider_id="inline::prompt-guard",
+                provider_type="inline::prompt-guard",
+                config=PromptGuardConfig().model_dump(),
+            )
+        ],
+    )
+
+
 @pytest.fixture(scope="session")
 def safety_bedrock() -> ProviderFixture:
     return ProviderFixture(
diff --git a/llama_stack/templates/bedrock/build.yaml b/llama_stack/templates/bedrock/build.yaml
index a3ff27949..44cc813ae 100644
--- a/llama_stack/templates/bedrock/build.yaml
+++ b/llama_stack/templates/bedrock/build.yaml
@@ -3,7 +3,7 @@ distribution_spec:
   description: Use Amazon Bedrock APIs.
   providers:
     inference: remote::bedrock
-    memory: meta-reference
-    safety: meta-reference
+    memory: inline::faiss
+    safety: inline::llama-guard
     agents: meta-reference
     telemetry: meta-reference
diff --git a/llama_stack/templates/databricks/build.yaml b/llama_stack/templates/databricks/build.yaml
index f6c8b50a1..aa22f54b2 100644
--- a/llama_stack/templates/databricks/build.yaml
+++ b/llama_stack/templates/databricks/build.yaml
@@ -3,7 +3,7 @@ distribution_spec:
   description: Use Databricks for running LLM inference
   providers:
     inference: remote::databricks
-    memory: meta-reference
-    safety: meta-reference
+    memory: inline::faiss
+    safety: inline::llama-guard
     agents: meta-reference
     telemetry: meta-reference
diff --git a/llama_stack/templates/fireworks/build.yaml b/llama_stack/templates/fireworks/build.yaml
index 5b662c213..833ce4ee2 100644
--- a/llama_stack/templates/fireworks/build.yaml
+++ b/llama_stack/templates/fireworks/build.yaml
@@ -6,6 +6,6 @@ distribution_spec:
     memory:
     - meta-reference
    - remote::weaviate
-    safety: meta-reference
+    safety: inline::llama-guard
     agents: meta-reference
     telemetry: meta-reference
diff --git a/llama_stack/templates/hf-endpoint/build.yaml b/llama_stack/templates/hf-endpoint/build.yaml
index 6c84e5ccf..b06ee2eb0 100644
--- a/llama_stack/templates/hf-endpoint/build.yaml
+++ b/llama_stack/templates/hf-endpoint/build.yaml
@@ -3,7 +3,7 @@ distribution_spec:
   description: "Like local, but use Hugging Face Inference Endpoints for running LLM inference.\nSee https://hf.co/docs/api-endpoints."
   providers:
     inference: remote::hf::endpoint
-    memory: meta-reference
-    safety: meta-reference
+    memory: inline::faiss
+    safety: inline::llama-guard
     agents: meta-reference
     telemetry: meta-reference
diff --git a/llama_stack/templates/hf-serverless/build.yaml b/llama_stack/templates/hf-serverless/build.yaml
index 32561c1fa..62ff2c953 100644
--- a/llama_stack/templates/hf-serverless/build.yaml
+++ b/llama_stack/templates/hf-serverless/build.yaml
@@ -3,7 +3,7 @@ distribution_spec:
   description: "Like local, but use Hugging Face Inference API (serverless) for running LLM inference.\nSee https://hf.co/docs/api-inference."
   providers:
     inference: remote::hf::serverless
-    memory: meta-reference
-    safety: meta-reference
+    memory: inline::faiss
+    safety: inline::llama-guard
     agents: meta-reference
     telemetry: meta-reference
diff --git a/llama_stack/templates/inline-vllm/build.yaml b/llama_stack/templates/inline-vllm/build.yaml
index d0fe93aa3..2e4b34bc6 100644
--- a/llama_stack/templates/inline-vllm/build.yaml
+++ b/llama_stack/templates/inline-vllm/build.yaml
@@ -8,6 +8,6 @@ distribution_spec:
     - meta-reference
     - remote::chromadb
     - remote::pgvector
-    safety: meta-reference
+    safety: inline::llama-guard
     agents: meta-reference
     telemetry: meta-reference
diff --git a/llama_stack/templates/meta-reference-gpu/build.yaml b/llama_stack/templates/meta-reference-gpu/build.yaml
index d0fe93aa3..2e4b34bc6 100644
--- a/llama_stack/templates/meta-reference-gpu/build.yaml
+++ b/llama_stack/templates/meta-reference-gpu/build.yaml
@@ -8,6 +8,6 @@ distribution_spec:
     - meta-reference
     - remote::chromadb
     - remote::pgvector
-    safety: meta-reference
+    safety: inline::llama-guard
     agents: meta-reference
     telemetry: meta-reference
diff --git a/llama_stack/templates/meta-reference-quantized-gpu/build.yaml b/llama_stack/templates/meta-reference-quantized-gpu/build.yaml
index 20500ea5a..8768bd430 100644
--- a/llama_stack/templates/meta-reference-quantized-gpu/build.yaml
+++ b/llama_stack/templates/meta-reference-quantized-gpu/build.yaml
@@ -8,6 +8,6 @@ distribution_spec:
     - meta-reference
     - remote::chromadb
     - remote::pgvector
-    safety: meta-reference
+    safety: inline::llama-guard
     agents: meta-reference
     telemetry: meta-reference
diff --git a/llama_stack/templates/ollama/build.yaml b/llama_stack/templates/ollama/build.yaml
index 06de2fc3c..410ae37cd 100644
--- a/llama_stack/templates/ollama/build.yaml
+++ b/llama_stack/templates/ollama/build.yaml
@@ -7,6 +7,6 @@ distribution_spec:
     - meta-reference
     - remote::chromadb
     - remote::pgvector
-    safety: meta-reference
+    safety: inline::llama-guard
     agents: meta-reference
     telemetry: meta-reference
diff --git a/llama_stack/templates/remote-vllm/build.yaml b/llama_stack/templates/remote-vllm/build.yaml
index ea95992f3..967b64413 100644
--- a/llama_stack/templates/remote-vllm/build.yaml
+++ b/llama_stack/templates/remote-vllm/build.yaml
@@ -7,6 +7,6 @@ distribution_spec:
     - meta-reference
     - remote::chromadb
     - remote::pgvector
-    safety: meta-reference
+    safety: inline::llama-guard
     agents: meta-reference
     telemetry: meta-reference
diff --git a/llama_stack/templates/tgi/build.yaml b/llama_stack/templates/tgi/build.yaml
index c5e618bb6..70c860001 100644
--- a/llama_stack/templates/tgi/build.yaml
+++ b/llama_stack/templates/tgi/build.yaml
@@ -7,6 +7,6 @@ distribution_spec:
     - meta-reference
     - remote::chromadb
     - remote::pgvector
-    safety: meta-reference
+    safety: inline::llama-guard
     agents: meta-reference
     telemetry: meta-reference
diff --git a/llama_stack/templates/together/build.yaml b/llama_stack/templates/together/build.yaml
index 05e59f677..614e31093 100644
--- a/llama_stack/templates/together/build.yaml
+++ b/llama_stack/templates/together/build.yaml
@@ -6,6 +6,6 @@ distribution_spec:
     memory:
     - meta-reference
     - remote::weaviate
-    safety: meta-reference
+    safety: inline::llama-guard
     agents: meta-reference
     telemetry: meta-reference
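
Editor's note (illustrative sketch, not part of the commit): after this change, the safety section of a run.yaml is expected to look roughly as follows, with Llama Guard and Prompt Guard registered as two separate inline providers instead of a single meta-reference provider carrying nested llama_guard_shield / prompt_guard_shield blocks. The provider ids (meta0, meta1) are arbitrary labels, and the model names are the defaults used throughout these templates:

  safety:
  - provider_id: meta0
    provider_type: inline::llama-guard
    config:
      model: Llama-Guard-3-1B
      excluded_categories: []
  - provider_id: meta1
    provider_type: inline::prompt-guard
    config:
      model: Prompt-Guard-86M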