Fixes for safety provider added to nvidia distro

This commit is contained in:
Chantal D Gama Rose 2025-02-24 18:52:56 +00:00
parent 0f1a9d06db
commit b9564fb435
5 changed files with 8 additions and 46 deletions

View file

@ -7,7 +7,7 @@ distribution_spec:
vector_io: vector_io:
- inline::faiss - inline::faiss
safety: safety:
- inline::llama-guard - remote::nvidia
agents: agents:
- inline::meta-reference - inline::meta-reference
telemetry: telemetry:
@ -15,16 +15,9 @@ distribution_spec:
eval: eval:
- inline::meta-reference - inline::meta-reference
datasetio: datasetio:
- remote::huggingface
- inline::localfs - inline::localfs
scoring: scoring:
- inline::basic - inline::basic
- inline::llm-as-judge
- inline::braintrust
tool_runtime: tool_runtime:
- remote::brave-search
- remote::tavily-search
- inline::code-interpreter
- inline::rag-runtime - inline::rag-runtime
- remote::model-context-protocol
image_type: conda image_type: conda

View file

@ -79,7 +79,7 @@ def get_distribution_template() -> DistributionTemplate:
return DistributionTemplate( return DistributionTemplate(
name="nvidia", name="nvidia",
distro_type="remote_hosted", distro_type="remote_hosted",
description="Use NVIDIA NIM for running LLM inference", description="Use NVIDIA NIM for running LLM inference and safety",
container_image=None, container_image=None,
template_path=Path(__file__).parent / "doc_template.md", template_path=Path(__file__).parent / "doc_template.md",
providers=providers, providers=providers,
@ -100,7 +100,7 @@ def get_distribution_template() -> DistributionTemplate:
] ]
}, },
default_models=[inference_model, safety_model], default_models=[inference_model, safety_model],
default_shields=[ShieldInput(shield_id="${env.SAFETY_MODEL}")], default_shields=[ShieldInput(shield_id="${env.SAFETY_MODEL}", provider_id="nvidia")],
default_tool_groups=default_tool_groups, default_tool_groups=default_tool_groups,
), ),
}, },

View file

@ -31,8 +31,8 @@ providers:
namespace: null namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/faiss_store.db db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/faiss_store.db
safety: safety:
- provider_id: llama-guard - provider_id: nvidia
provider_type: inline::llama-guard provider_type: remote::nvidia
config: {} config: {}
agents: agents:
- provider_id: meta-reference - provider_id: meta-reference
@ -54,9 +54,6 @@ providers:
provider_type: inline::meta-reference provider_type: inline::meta-reference
config: {} config: {}
datasetio: datasetio:
- provider_id: huggingface
provider_type: remote::huggingface
config: {}
- provider_id: localfs - provider_id: localfs
provider_type: inline::localfs provider_type: inline::localfs
config: {} config: {}
@ -64,33 +61,10 @@ providers:
- provider_id: basic - provider_id: basic
provider_type: inline::basic provider_type: inline::basic
config: {} config: {}
- provider_id: llm-as-judge
provider_type: inline::llm-as-judge
config: {}
- provider_id: braintrust
provider_type: inline::braintrust
config:
openai_api_key: ${env.OPENAI_API_KEY:}
tool_runtime: tool_runtime:
- provider_id: brave-search
provider_type: remote::brave-search
config:
api_key: ${env.BRAVE_SEARCH_API_KEY:}
max_results: 3
- provider_id: tavily-search
provider_type: remote::tavily-search
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:}
max_results: 3
- provider_id: code-interpreter
provider_type: inline::code-interpreter
config: {}
- provider_id: rag-runtime - provider_id: rag-runtime
provider_type: inline::rag-runtime provider_type: inline::rag-runtime
config: {} config: {}
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
config: {}
metadata_store: metadata_store:
type: sqlite type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/registry.db db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/registry.db
@ -105,16 +79,13 @@ models:
model_type: llm model_type: llm
shields: shields:
- shield_id: ${env.SAFETY_MODEL} - shield_id: ${env.SAFETY_MODEL}
provider_id: nvidia
vector_dbs: [] vector_dbs: []
datasets: [] datasets: []
scoring_fns: [] scoring_fns: []
benchmarks: [] benchmarks: []
tool_groups: tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::rag - toolgroup_id: builtin::rag
provider_id: rag-runtime provider_id: rag-runtime
- toolgroup_id: builtin::code_interpreter
provider_id: code-interpreter
server: server:
port: 8321 port: 8321

View file

@ -110,8 +110,7 @@ models:
- metadata: - metadata:
embedding_dimension: 384 embedding_dimension: 384
model_id: all-MiniLM-L6-v2 model_id: all-MiniLM-L6-v2
provider_id: ollama provider_id: sentence-transformers
provider_model_id: all-minilm:latest
model_type: embedding model_type: embedding
shields: shields:
- shield_id: ${env.SAFETY_MODEL} - shield_id: ${env.SAFETY_MODEL}

View file

@ -103,8 +103,7 @@ models:
- metadata: - metadata:
embedding_dimension: 384 embedding_dimension: 384
model_id: all-MiniLM-L6-v2 model_id: all-MiniLM-L6-v2
provider_id: ollama provider_id: sentence-transformers
provider_model_id: all-minilm:latest
model_type: embedding model_type: embedding
shields: [] shields: []
vector_dbs: [] vector_dbs: []