Update provider_type -> inline::llama-guard in templates, update run.yaml

Ashwin Bharambe 2024-11-11 09:12:17 -08:00
parent 15ffceb533
commit 4971113f92
24 changed files with 121 additions and 98 deletions

@@ -19,15 +19,14 @@ providers:
       url: http://127.0.0.1:80
   safety:
   - provider_id: meta0
-    provider_type: meta-reference
+    provider_type: inline::llama-guard
     config:
-      llama_guard_shield:
-        model: Llama-Guard-3-1B
-        excluded_categories: []
-        disable_input_check: false
-        disable_output_check: false
-      prompt_guard_shield:
-        model: Prompt-Guard-86M
+      model: Llama-Guard-3-1B
+      excluded_categories: []
+  - provider_id: meta1
+    provider_type: inline::prompt-guard
+    config:
+      model: Prompt-Guard-86M
   memory:
   - provider_id: meta0
     provider_type: meta-reference
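
Every run.yaml template in this commit gets the same treatment: the single meta-reference safety provider, which nested both shields under llama_guard_shield / prompt_guard_shield keys, is split into two providers with flat configs, one per shield. Reassembled from the hunk above, the before/after shape is:

```yaml
# Before: one meta-reference provider carrying both shields
safety:
- provider_id: meta0
  provider_type: meta-reference
  config:
    llama_guard_shield:
      model: Llama-Guard-3-1B
      excluded_categories: []
      disable_input_check: false
      disable_output_check: false
    prompt_guard_shield:
      model: Prompt-Guard-86M

# After: one provider per shield, each with a flat config
safety:
- provider_id: meta0
  provider_type: inline::llama-guard
  config:
    model: Llama-Guard-3-1B
    excluded_categories: []
- provider_id: meta1
  provider_type: inline::prompt-guard
  config:
    model: Prompt-Guard-86M
```

Note that the old disable_input_check / disable_output_check flags have no counterpart in the new config.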

@@ -19,16 +19,16 @@ providers:
       url: https://api.fireworks.ai/inference
       # api_key: <ENTER_YOUR_API_KEY>
   safety:
+  safety:
   - provider_id: meta0
-    provider_type: meta-reference
+    provider_type: inline::llama-guard
     config:
-      llama_guard_shield:
-        model: Llama-Guard-3-1B
-        excluded_categories: []
-        disable_input_check: false
-        disable_output_check: false
-      prompt_guard_shield:
-        model: Prompt-Guard-86M
+      model: Llama-Guard-3-1B
+      excluded_categories: []
+  - provider_id: meta1
+    provider_type: inline::prompt-guard
+    config:
+      model: Prompt-Guard-86M
   memory:
   - provider_id: meta0
     provider_type: meta-reference

@@ -21,7 +21,7 @@ providers:
       gpu_memory_utilization: 0.4
       enforce_eager: true
       max_tokens: 4096
-  - provider_id: vllm-safety
+  - provider_id: vllm-inference-safety
     provider_type: inline::vllm
     config:
       model: Llama-Guard-3-1B
@@ -31,14 +31,15 @@ providers:
       max_tokens: 4096
   safety:
   - provider_id: meta0
-    provider_type: meta-reference
+    provider_type: inline::llama-guard
     config:
-      llama_guard_shield:
-        model: Llama-Guard-3-1B
-        excluded_categories: []
+      model: Llama-Guard-3-1B
+      excluded_categories: []
     # Uncomment to use prompt guard
-    # prompt_guard_shield:
-    #   model: Prompt-Guard-86M
+    # - provider_id: meta1
+    #   provider_type: inline::prompt-guard
+    #   config:
+    #     model: Prompt-Guard-86M
   memory:
   - provider_id: meta0
     provider_type: meta-reference
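
Two things change in the vLLM template: the second inline::vllm inference provider is renamed to vllm-inference-safety, making explicit that it exists to serve Llama-Guard-3-1B for the shield rather than to handle user-facing inference, and the optional prompt-guard entry is rewritten in the new provider format while staying commented out. Enabling it means uncommenting that block under safety:, i.e. (a sketch of the commented lines above, uncommented):

```yaml
- provider_id: meta1
  provider_type: inline::prompt-guard
  config:
    model: Prompt-Guard-86M
```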

@@ -13,7 +13,7 @@ apis:
 - safety
 providers:
   inference:
-  - provider_id: meta-reference-inference
+  - provider_id: inference0
     provider_type: meta-reference
     config:
       model: Llama3.2-3B-Instruct
@@ -21,7 +21,7 @@ providers:
       torch_seed: null
       max_seq_len: 4096
       max_batch_size: 1
-  - provider_id: meta-reference-safety
+  - provider_id: inference1
     provider_type: meta-reference
     config:
       model: Llama-Guard-3-1B
@@ -31,11 +31,14 @@ providers:
       max_batch_size: 1
   safety:
   - provider_id: meta0
-    provider_type: meta-reference
+    provider_type: inline::llama-guard
     config:
-      llama_guard_shield:
-        model: Llama-Guard-3-1B
-        excluded_categories: []
+      model: Llama-Guard-3-1B
+      excluded_categories: []
+  - provider_id: meta1
+    provider_type: inline::prompt-guard
+    config:
+      model: Prompt-Guard-86M
   # Uncomment to use prompt guard
   # prompt_guard_shield:
   # model: Prompt-Guard-86M
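
The provider_id renames here (inference0, inference1) separate naming from role: inference1 is an ordinary meta-reference inference provider that serves the guard model, and the inline::llama-guard safety provider presumably resolves its model through the inference API, which would be why the template keeps that second inference provider around. Condensed from the hunks above, the pairing looks like:

```yaml
inference:
- provider_id: inference1            # serves the guard model
  provider_type: meta-reference
  config:
    model: Llama-Guard-3-1B
safety:
- provider_id: meta0                 # runs shield checks against that model
  provider_type: inline::llama-guard
  config:
    model: Llama-Guard-3-1B
    excluded_categories: []
```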

@@ -22,17 +22,25 @@ providers:
       torch_seed: null
       max_seq_len: 2048
       max_batch_size: 1
+  - provider_id: meta1
+    provider_type: meta-reference-quantized
+    config:
+      # not a quantized model !
+      model: Llama-Guard-3-1B
+      quantization: null
+      torch_seed: null
+      max_seq_len: 2048
+      max_batch_size: 1
   safety:
   - provider_id: meta0
-    provider_type: meta-reference
+    provider_type: inline::llama-guard
     config:
-      llama_guard_shield:
-        model: Llama-Guard-3-1B
-        excluded_categories: []
-        disable_input_check: false
-        disable_output_check: false
-      prompt_guard_shield:
-        model: Prompt-Guard-86M
+      model: Llama-Guard-3-1B
+      excluded_categories: []
+  - provider_id: meta1
+    provider_type: inline::prompt-guard
+    config:
+      model: Prompt-Guard-86M
   memory:
   - provider_id: meta0
     provider_type: meta-reference

@@ -19,15 +19,14 @@ providers:
       url: http://127.0.0.1:14343
   safety:
   - provider_id: meta0
-    provider_type: meta-reference
+    provider_type: inline::llama-guard
     config:
-      llama_guard_shield:
-        model: Llama-Guard-3-1B
-        excluded_categories: []
-        disable_input_check: false
-        disable_output_check: false
-      prompt_guard_shield:
-        model: Prompt-Guard-86M
+      model: Llama-Guard-3-1B
+      excluded_categories: []
+  - provider_id: meta1
+    provider_type: inline::prompt-guard
+    config:
+      model: Prompt-Guard-86M
   memory:
   - provider_id: meta0
     provider_type: meta-reference

@@ -19,15 +19,14 @@ providers:
       url: http://127.0.0.1:14343
   safety:
   - provider_id: meta0
-    provider_type: meta-reference
+    provider_type: inline::llama-guard
     config:
-      llama_guard_shield:
-        model: Llama-Guard-3-1B
-        excluded_categories: []
-        disable_input_check: false
-        disable_output_check: false
-      prompt_guard_shield:
-        model: Prompt-Guard-86M
+      model: Llama-Guard-3-1B
+      excluded_categories: []
+  - provider_id: meta1
+    provider_type: inline::prompt-guard
+    config:
+      model: Prompt-Guard-86M
   memory:
   - provider_id: meta0
     provider_type: meta-reference

@@ -19,15 +19,14 @@ providers:
       url: http://127.0.0.1:8000
   safety:
   - provider_id: meta0
-    provider_type: meta-reference
+    provider_type: inline::llama-guard
     config:
-      llama_guard_shield:
-        model: Llama-Guard-3-1B
-        excluded_categories: []
-        disable_input_check: false
-        disable_output_check: false
-      prompt_guard_shield:
-        model: Prompt-Guard-86M
+      model: Llama-Guard-3-1B
+      excluded_categories: []
+  - provider_id: meta1
+    provider_type: inline::prompt-guard
+    config:
+      model: Prompt-Guard-86M
   memory:
   - provider_id: meta0
     provider_type: meta-reference

@@ -19,15 +19,14 @@ providers:
       url: http://127.0.0.1:5009
   safety:
   - provider_id: meta0
-    provider_type: meta-reference
+    provider_type: inline::llama-guard
     config:
-      llama_guard_shield:
-        model: Llama-Guard-3-1B
-        excluded_categories: []
-        disable_input_check: false
-        disable_output_check: false
-      prompt_guard_shield:
-        model: Prompt-Guard-86M
+      model: Llama-Guard-3-1B
+      excluded_categories: []
+  - provider_id: meta1
+    provider_type: inline::prompt-guard
+    config:
+      model: Prompt-Guard-86M
   memory:
   - provider_id: meta0
     provider_type: meta-reference

@@ -20,15 +20,14 @@ providers:
       # api_key: <ENTER_YOUR_API_KEY>
   safety:
   - provider_id: meta0
-    provider_type: meta-reference
+    provider_type: inline::llama-guard
     config:
-      llama_guard_shield:
-        model: Llama-Guard-3-1B
-        excluded_categories: []
-        disable_input_check: false
-        disable_output_check: false
-      prompt_guard_shield:
-        model: Prompt-Guard-86M
+      model: Llama-Guard-3-1B
+      excluded_categories: []
+  - provider_id: meta1
+    provider_type: inline::prompt-guard
+    config:
+      model: Prompt-Guard-86M
   memory:
   - provider_id: meta0
     provider_type: remote::weaviate

@@ -36,9 +36,9 @@ the provider types (implementations) you want to use for these APIs.
 Tip: use <TAB> to see options for the providers.
 > Enter provider for API inference: meta-reference
-> Enter provider for API safety: meta-reference
+> Enter provider for API safety: inline::llama-guard
 > Enter provider for API agents: meta-reference
-> Enter provider for API memory: meta-reference
+> Enter provider for API memory: inline::faiss
 > Enter provider for API datasetio: meta-reference
 > Enter provider for API scoring: meta-reference
 > Enter provider for API eval: meta-reference
@@ -203,8 +203,8 @@ distribution_spec:
   description: Like local, but use ollama for running LLM inference
   providers:
     inference: remote::ollama
-    memory: meta-reference
-    safety: meta-reference
+    memory: inline::faiss
+    safety: inline::llama-guard
     agents: meta-reference
     telemetry: meta-reference
 image_type: conda
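
Answering the build prompts this way yields a distribution spec like the ollama one in the second hunk; a minimal sketch of the full build config it corresponds to (the name field is illustrative, everything else is taken from the hunk):

```yaml
name: ollama
distribution_spec:
  description: Like local, but use ollama for running LLM inference
  providers:
    inference: remote::ollama
    memory: inline::faiss
    safety: inline::llama-guard
    agents: meta-reference
    telemetry: meta-reference
image_type: conda
```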

@@ -11,6 +11,7 @@ from llama_stack.apis.shields import ShieldType
 from llama_stack.distribution.datatypes import Api, Provider
 from llama_stack.providers.inline.safety.llama_guard import LlamaGuardConfig
+from llama_stack.providers.inline.safety.prompt_guard import PromptGuardConfig
 from llama_stack.providers.remote.safety.bedrock import BedrockSafetyConfig
 from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
@@ -44,6 +45,22 @@ def safety_llama_guard(safety_model) -> ProviderFixture:
     )
+# TODO: this is not tested yet; we would need to configure the run_shield() test
+# and parametrize it with the "prompt" for testing depending on the safety fixture
+# we are using.
+@pytest.fixture(scope="session")
+def safety_prompt_guard() -> ProviderFixture:
+    return ProviderFixture(
+        providers=[
+            Provider(
+                provider_id="inline::prompt-guard",
+                provider_type="inline::prompt-guard",
+                config=PromptGuardConfig().model_dump(),
+            )
+        ],
+    )
+
+
 @pytest.fixture(scope="session")
 def safety_bedrock() -> ProviderFixture:
     return ProviderFixture(
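
The new fixture registers the same provider the run.yaml templates now declare. Assuming PromptGuardConfig's defaults serialize to the model name used throughout this commit (the diff does not show the config class itself), the YAML equivalent of the fixture's Provider entry would be:

```yaml
safety:
- provider_id: inline::prompt-guard
  provider_type: inline::prompt-guard
  config:
    model: Prompt-Guard-86M
```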

@@ -3,7 +3,7 @@ distribution_spec:
   description: Use Amazon Bedrock APIs.
   providers:
     inference: remote::bedrock
-    memory: meta-reference
-    safety: meta-reference
+    memory: inline::faiss
+    safety: inline::llama-guard
     agents: meta-reference
     telemetry: meta-reference
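
The remaining build specs below repeat this one-line substitution. Collected across all templates touched by this commit, the provider type names in use are:

```yaml
# Provider types appearing in this commit's templates, by API
inference: [meta-reference, meta-reference-quantized, inline::vllm,
            remote::ollama, remote::bedrock, remote::databricks,
            remote::hf::endpoint, remote::hf::serverless]
safety:    [inline::llama-guard, inline::prompt-guard]
memory:    [inline::faiss, meta-reference, remote::chromadb,
            remote::pgvector, remote::weaviate]
agents:    [meta-reference]
telemetry: [meta-reference]
```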

@@ -3,7 +3,7 @@ distribution_spec:
   description: Use Databricks for running LLM inference
   providers:
     inference: remote::databricks
-    memory: meta-reference
-    safety: meta-reference
+    memory: inline::faiss
+    safety: inline::llama-guard
     agents: meta-reference
     telemetry: meta-reference

@@ -6,6 +6,6 @@ distribution_spec:
     memory:
     - meta-reference
     - remote::weaviate
-    safety: meta-reference
+    safety: inline::llama-guard
     agents: meta-reference
     telemetry: meta-reference

@@ -3,7 +3,7 @@ distribution_spec:
   description: "Like local, but use Hugging Face Inference Endpoints for running LLM inference.\nSee https://hf.co/docs/api-endpoints."
   providers:
     inference: remote::hf::endpoint
-    memory: meta-reference
-    safety: meta-reference
+    memory: inline::faiss
+    safety: inline::llama-guard
     agents: meta-reference
     telemetry: meta-reference

@@ -3,7 +3,7 @@ distribution_spec:
   description: "Like local, but use Hugging Face Inference API (serverless) for running LLM inference.\nSee https://hf.co/docs/api-inference."
   providers:
     inference: remote::hf::serverless
-    memory: meta-reference
-    safety: meta-reference
+    memory: inline::faiss
+    safety: inline::llama-guard
     agents: meta-reference
     telemetry: meta-reference

@@ -8,6 +8,6 @@ distribution_spec:
     - meta-reference
     - remote::chromadb
     - remote::pgvector
-    safety: meta-reference
+    safety: inline::llama-guard
     agents: meta-reference
     telemetry: meta-reference

@@ -8,6 +8,6 @@ distribution_spec:
     - meta-reference
     - remote::chromadb
     - remote::pgvector
-    safety: meta-reference
+    safety: inline::llama-guard
     agents: meta-reference
     telemetry: meta-reference

@@ -8,6 +8,6 @@ distribution_spec:
     - meta-reference
     - remote::chromadb
     - remote::pgvector
-    safety: meta-reference
+    safety: inline::llama-guard
     agents: meta-reference
     telemetry: meta-reference

@@ -7,6 +7,6 @@ distribution_spec:
     - meta-reference
     - remote::chromadb
     - remote::pgvector
-    safety: meta-reference
+    safety: inline::llama-guard
     agents: meta-reference
     telemetry: meta-reference

@@ -7,6 +7,6 @@ distribution_spec:
     - meta-reference
     - remote::chromadb
     - remote::pgvector
-    safety: meta-reference
+    safety: inline::llama-guard
     agents: meta-reference
     telemetry: meta-reference

@@ -7,6 +7,6 @@ distribution_spec:
     - meta-reference
     - remote::chromadb
     - remote::pgvector
-    safety: meta-reference
+    safety: inline::llama-guard
     agents: meta-reference
     telemetry: meta-reference

@@ -6,6 +6,6 @@ distribution_spec:
     memory:
     - meta-reference
     - remote::weaviate
-    safety: meta-reference
+    safety: inline::llama-guard
     agents: meta-reference
     telemetry: meta-reference