add default toolgroups to all providers (#795)

# What does this PR do?

Add toolgroup defs to all the distribution templates
This commit is contained in:
Dinesh Yeduguru 2025-01-16 16:54:59 -08:00 committed by GitHub
parent e88faa91e2
commit 73215460ba
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 63 additions and 8 deletions

View file

@ -100,6 +100,7 @@ def get_distribution_template() -> DistributionTemplate:
"memory": [memory_provider], "memory": [memory_provider],
}, },
default_models=[inference_model, embedding_model], default_models=[inference_model, embedding_model],
default_tool_groups=default_tool_groups,
), ),
"run-with-safety.yaml": RunConfigSettings( "run-with-safety.yaml": RunConfigSettings(
provider_overrides={ provider_overrides={

View file

@ -104,4 +104,10 @@ memory_banks: []
datasets: [] datasets: []
scoring_fns: [] scoring_fns: []
eval_tasks: [] eval_tasks: []
tool_groups: [] tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::memory
provider_id: memory-runtime
- toolgroup_id: builtin::code_interpreter
provider_id: code-interpreter

View file

@ -105,6 +105,7 @@ def get_distribution_template() -> DistributionTemplate:
"memory": [memory_provider], "memory": [memory_provider],
}, },
default_models=[inference_model, embedding_model], default_models=[inference_model, embedding_model],
default_tool_groups=default_tool_groups,
), ),
"run-with-safety.yaml": RunConfigSettings( "run-with-safety.yaml": RunConfigSettings(
provider_overrides={ provider_overrides={

View file

@ -105,4 +105,10 @@ memory_banks: []
datasets: [] datasets: []
scoring_fns: [] scoring_fns: []
eval_tasks: [] eval_tasks: []
tool_groups: [] tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::memory
provider_id: memory-runtime
- toolgroup_id: builtin::code_interpreter
provider_id: code-interpreter

View file

@ -8,10 +8,9 @@ from pathlib import Path
from llama_models.sku_list import all_registered_models from llama_models.sku_list import all_registered_models
from llama_stack.distribution.datatypes import ModelInput, Provider from llama_stack.distribution.datatypes import ModelInput, Provider, ToolGroupInput
from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
from llama_stack.providers.remote.inference.nvidia.nvidia import _MODEL_ALIASES from llama_stack.providers.remote.inference.nvidia.nvidia import _MODEL_ALIASES
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
@ -50,6 +49,20 @@ def get_distribution_template() -> DistributionTemplate:
) )
for m in _MODEL_ALIASES for m in _MODEL_ALIASES
] ]
default_tool_groups = [
ToolGroupInput(
toolgroup_id="builtin::websearch",
provider_id="tavily-search",
),
ToolGroupInput(
toolgroup_id="builtin::memory",
provider_id="memory-runtime",
),
ToolGroupInput(
toolgroup_id="builtin::code_interpreter",
provider_id="code-interpreter",
),
]
return DistributionTemplate( return DistributionTemplate(
name="nvidia", name="nvidia",
@ -65,6 +78,7 @@ def get_distribution_template() -> DistributionTemplate:
"inference": [inference_provider], "inference": [inference_provider],
}, },
default_models=default_models, default_models=default_models,
default_tool_groups=default_tool_groups,
), ),
}, },
run_config_env_vars={ run_config_env_vars={

View file

@ -137,4 +137,10 @@ memory_banks: []
datasets: [] datasets: []
scoring_fns: [] scoring_fns: []
eval_tasks: [] eval_tasks: []
tool_groups: [] tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::memory
provider_id: memory-runtime
- toolgroup_id: builtin::code_interpreter
provider_id: code-interpreter

View file

@ -101,6 +101,7 @@ def get_distribution_template() -> DistributionTemplate:
"memory": [memory_provider], "memory": [memory_provider],
}, },
default_models=[inference_model, embedding_model], default_models=[inference_model, embedding_model],
default_tool_groups=default_tool_groups,
), ),
"run-with-safety.yaml": RunConfigSettings( "run-with-safety.yaml": RunConfigSettings(
provider_overrides={ provider_overrides={

View file

@ -103,4 +103,10 @@ memory_banks: []
datasets: [] datasets: []
scoring_fns: [] scoring_fns: []
eval_tasks: [] eval_tasks: []
tool_groups: [] tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::memory
provider_id: memory-runtime
- toolgroup_id: builtin::code_interpreter
provider_id: code-interpreter

View file

@ -80,4 +80,10 @@ memory_banks: []
datasets: [] datasets: []
scoring_fns: [] scoring_fns: []
eval_tasks: [] eval_tasks: []
tool_groups: [] tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::memory
provider_id: memory-runtime
- toolgroup_id: builtin::code_interpreter
provider_id: code-interpreter

View file

@ -99,6 +99,7 @@ def get_distribution_template() -> DistributionTemplate:
"memory": [memory_provider], "memory": [memory_provider],
}, },
default_models=[inference_model, embedding_model], default_models=[inference_model, embedding_model],
default_tool_groups=default_tool_groups,
), ),
"run-with-safety.yaml": RunConfigSettings( "run-with-safety.yaml": RunConfigSettings(
provider_overrides={ provider_overrides={

View file

@ -103,4 +103,10 @@ memory_banks: []
datasets: [] datasets: []
scoring_fns: [] scoring_fns: []
eval_tasks: [] eval_tasks: []
tool_groups: [] tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::memory
provider_id: memory-runtime
- toolgroup_id: builtin::code_interpreter
provider_id: code-interpreter

View file

@ -103,6 +103,7 @@ def get_distribution_template() -> DistributionTemplate:
"memory": [memory_provider], "memory": [memory_provider],
}, },
default_models=[inference_model, embedding_model], default_models=[inference_model, embedding_model],
default_tool_groups=default_tool_groups,
), ),
"run-with-safety.yaml": RunConfigSettings( "run-with-safety.yaml": RunConfigSettings(
provider_overrides={ provider_overrides={