Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-07-15 17:44:01 +00:00
fix: Safety in starter (#2731)
- fireworks and together no longer support the Llama-Guard 3 8B model, so we need to default to ollama
- the current safety-shields logic was incorrect, since the shield_id was the provider id (which had duplicates)
- followed logic similar to how models are handled

Note: this may seem a bit over-engineered, but it can now be extended to other providers and fits into the overall mechanism of how env vars are used to manage starter.

### How to test

```
ENABLE_OLLAMA=ollama ENABLE_FIREWORKS=fireworks SAFETY_MODEL=llama-guard3:1b pytest -s -v tests/integration/ --stack-config starter -k 'not(supervised_fine_tune or builtin_tool_code or safety_with_image or code_interpreter_for or rag_and_code or truncation or register_and_unregister)' --text-model fireworks/meta-llama/Llama-3.3-70B-Instruct --vision-model fireworks/meta-llama/Llama-4-Scout-17B-16E-Instruct --safety-shield llama-guard3:1b --embedding-model all-MiniLM-L6-v2
```

### Related but not obvious in this PR

In the llama-stack-ops repo, we run tests before publishing packages and docker containers. The actions in that repo were using the fireworks / together distros (which are non-existent), so they need to be updated to run with `starter` and use `ollama` specifically for safety.
Commit 6b8a8c1be9 (parent 6ad22c209f)
9 changed files with 104 additions and 195 deletions
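Throughout the diff below, models and shields are declared with `${env.VAR:=default}` placeholders instead of concrete ids. As a rough mental model of the substitution (a minimal sketch, not the actual llama-stack resolver): the environment variable wins when set, the inline default applies otherwise, and `__disabled__` is the sentinel the starter distro treats as "drop this entry":

```python
import os
import re

# Minimal sketch of ${env.VAR:=default} resolution as used in the run
# configs below; the real resolver lives in llama-stack, this is only
# an illustration of the semantics.
_PLACEHOLDER = re.compile(r"\$\{env\.(\w+):=([^}]*)\}")

def resolve(template: str) -> str:
    # Use the environment variable when set, otherwise the inline default.
    return _PLACEHOLDER.sub(lambda m: os.environ.get(m.group(1), m.group(2)), template)

os.environ["ENABLE_OLLAMA"] = "ollama"
os.environ["SAFETY_MODEL"] = "llama-guard3:1b"
print(resolve("${env.ENABLE_OLLAMA:=__disabled__}/${env.SAFETY_MODEL:=__disabled__}"))
# -> ollama/llama-guard3:1b; with the variables unset this would resolve
# to __disabled__/__disabled__ and the shield entry would be skipped.
```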
.github/workflows/integration-tests.yml (vendored, 2 lines changed)
```diff
@@ -89,7 +89,7 @@ jobs:
           -k "not(builtin_tool or safety_with_image or code_interpreter or test_rag)" \
           --text-model="ollama/llama3.2:3b-instruct-fp16" \
           --embedding-model=all-MiniLM-L6-v2 \
-          --safety-shield=ollama \
+          --safety-shield=$SAFETY_MODEL \
           --color=yes \
           --capture=tee-sys | tee pytest-${{ matrix.test-type }}.log
```
Ollama model entries (the provider's models.py):

```diff
@@ -12,6 +12,19 @@ from llama_stack.providers.utils.inference.model_registry import (
     build_model_entry,
 )
 
+SAFETY_MODELS_ENTRIES = [
+    # The Llama Guard models don't have their full fp16 versions
+    # so we are going to alias their default version to the canonical SKU
+    build_hf_repo_model_entry(
+        "llama-guard3:8b",
+        CoreModelId.llama_guard_3_8b.value,
+    ),
+    build_hf_repo_model_entry(
+        "llama-guard3:1b",
+        CoreModelId.llama_guard_3_1b.value,
+    ),
+]
+
 MODEL_ENTRIES = [
     build_hf_repo_model_entry(
         "llama3.1:8b-instruct-fp16",
@@ -73,16 +86,6 @@ MODEL_ENTRIES = [
         "llama3.3:70b",
         CoreModelId.llama3_3_70b_instruct.value,
     ),
-    # The Llama Guard models don't have their full fp16 versions
-    # so we are going to alias their default version to the canonical SKU
-    build_hf_repo_model_entry(
-        "llama-guard3:8b",
-        CoreModelId.llama_guard_3_8b.value,
-    ),
-    build_hf_repo_model_entry(
-        "llama-guard3:1b",
-        CoreModelId.llama_guard_3_1b.value,
-    ),
     ProviderModelEntry(
         provider_model_id="all-minilm:l6-v2",
         aliases=["all-minilm"],
@@ -100,4 +103,4 @@ MODEL_ENTRIES = [
             "context_length": 8192,
         },
     ),
-]
+] + SAFETY_MODELS_ENTRIES
```
nvidia distribution template:

```diff
@@ -68,7 +68,7 @@ def get_distribution_template() -> DistributionTemplate:
         ),
     ]
 
-    default_models = get_model_registry(available_models)
+    default_models, _ = get_model_registry(available_models)
     return DistributionTemplate(
         name="nvidia",
         distro_type="self_hosted",
```
A distribution template that appends hard-coded groq models:

```diff
@@ -146,7 +146,8 @@ def get_distribution_template() -> DistributionTemplate:
         ),
     ]
 
-    default_models = get_model_registry(available_models) + [
+    models, _ = get_model_registry(available_models)
+    default_models = models + [
         ModelInput(
             model_id="meta-llama/Llama-3.3-70B-Instruct",
             provider_id="groq",
```
starter distribution run.yaml:

```diff
@@ -1171,24 +1171,8 @@ models:
   provider_id: ${env.ENABLE_SENTENCE_TRANSFORMERS:=sentence-transformers}
   model_type: embedding
 shields:
-- shield_id: ${env.ENABLE_OLLAMA:=__disabled__}
-  provider_id: llama-guard
-  provider_shield_id: ${env.ENABLE_OLLAMA:=__disabled__}/${env.SAFETY_MODEL:=llama-guard3:1b}
-- shield_id: ${env.ENABLE_FIREWORKS:=__disabled__}
-  provider_id: llama-guard
-  provider_shield_id: ${env.ENABLE_FIREWORKS:=__disabled__}/${env.SAFETY_MODEL:=accounts/fireworks/models/llama-guard-3-8b}
-- shield_id: ${env.ENABLE_FIREWORKS:=__disabled__}
-  provider_id: llama-guard
-  provider_shield_id: ${env.ENABLE_FIREWORKS:=__disabled__}/${env.SAFETY_MODEL:=accounts/fireworks/models/llama-guard-3-11b-vision}
-- shield_id: ${env.ENABLE_TOGETHER:=__disabled__}
-  provider_id: llama-guard
-  provider_shield_id: ${env.ENABLE_TOGETHER:=__disabled__}/${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-8B}
-- shield_id: ${env.ENABLE_TOGETHER:=__disabled__}
-  provider_id: llama-guard
-  provider_shield_id: ${env.ENABLE_TOGETHER:=__disabled__}/${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-11B-Vision-Turbo}
-- shield_id: ${env.ENABLE_SAMBANOVA:=__disabled__}
-  provider_id: llama-guard
-  provider_shield_id: ${env.ENABLE_SAMBANOVA:=__disabled__}/${env.SAFETY_MODEL:=sambanova/Meta-Llama-Guard-3-8B}
+- shield_id: ${env.SAFETY_MODEL:=__disabled__}
+  provider_shield_id: ${env.ENABLE_OLLAMA:=__disabled__}/${env.SAFETY_MODEL:=__disabled__}
 vector_dbs: []
 datasets: []
 scoring_fns: []
```
starter distribution template (starter.py):

```diff
@@ -12,7 +12,6 @@ from llama_stack.distribution.datatypes import (
     ModelInput,
     Provider,
     ProviderSpec,
-    ShieldInput,
     ToolGroupInput,
 )
 from llama_stack.distribution.utils.dynamic import instantiate_class_type
@@ -32,75 +31,39 @@ from llama_stack.providers.registry.inference import available_providers
 from llama_stack.providers.remote.inference.anthropic.models import (
     MODEL_ENTRIES as ANTHROPIC_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.anthropic.models import (
-    SAFETY_MODELS_ENTRIES as ANTHROPIC_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.bedrock.models import (
     MODEL_ENTRIES as BEDROCK_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.bedrock.models import (
-    SAFETY_MODELS_ENTRIES as BEDROCK_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.cerebras.models import (
     MODEL_ENTRIES as CEREBRAS_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.cerebras.models import (
-    SAFETY_MODELS_ENTRIES as CEREBRAS_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.databricks.databricks import (
     MODEL_ENTRIES as DATABRICKS_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.databricks.databricks import (
-    SAFETY_MODELS_ENTRIES as DATABRICKS_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.fireworks.models import (
     MODEL_ENTRIES as FIREWORKS_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.fireworks.models import (
-    SAFETY_MODELS_ENTRIES as FIREWORKS_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.gemini.models import (
     MODEL_ENTRIES as GEMINI_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.gemini.models import (
-    SAFETY_MODELS_ENTRIES as GEMINI_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.groq.models import (
     MODEL_ENTRIES as GROQ_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.groq.models import (
-    SAFETY_MODELS_ENTRIES as GROQ_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.nvidia.models import (
     MODEL_ENTRIES as NVIDIA_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.nvidia.models import (
-    SAFETY_MODELS_ENTRIES as NVIDIA_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.openai.models import (
     MODEL_ENTRIES as OPENAI_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.openai.models import (
-    SAFETY_MODELS_ENTRIES as OPENAI_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.runpod.runpod import (
     MODEL_ENTRIES as RUNPOD_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.runpod.runpod import (
-    SAFETY_MODELS_ENTRIES as RUNPOD_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.sambanova.models import (
     MODEL_ENTRIES as SAMBANOVA_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.sambanova.models import (
-    SAFETY_MODELS_ENTRIES as SAMBANOVA_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.together.models import (
     MODEL_ENTRIES as TOGETHER_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.together.models import (
-    SAFETY_MODELS_ENTRIES as TOGETHER_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOConfig
 from llama_stack.providers.remote.vector_io.pgvector.config import (
     PGVectorVectorIOConfig,
@@ -111,6 +74,7 @@ from llama_stack.templates.template import (
     DistributionTemplate,
     RunConfigSettings,
     get_model_registry,
+    get_shield_registry,
 )
 
 
@@ -164,28 +128,13 @@ def _get_model_entries_for_provider(provider_type: str) -> list[ProviderModelEntry]:
 def _get_model_safety_entries_for_provider(provider_type: str) -> list[ProviderModelEntry]:
     """Get model entries for a specific provider type."""
     safety_model_entries_map = {
-        "openai": OPENAI_SAFETY_MODELS_ENTRIES,
-        "fireworks": FIREWORKS_SAFETY_MODELS_ENTRIES,
-        "together": TOGETHER_SAFETY_MODELS_ENTRIES,
-        "anthropic": ANTHROPIC_SAFETY_MODELS_ENTRIES,
-        "gemini": GEMINI_SAFETY_MODELS_ENTRIES,
-        "groq": GROQ_SAFETY_MODELS_ENTRIES,
-        "sambanova": SAMBANOVA_SAFETY_MODELS_ENTRIES,
-        "cerebras": CEREBRAS_SAFETY_MODELS_ENTRIES,
-        "bedrock": BEDROCK_SAFETY_MODELS_ENTRIES,
-        "databricks": DATABRICKS_SAFETY_MODELS_ENTRIES,
-        "nvidia": NVIDIA_SAFETY_MODELS_ENTRIES,
-        "runpod": RUNPOD_SAFETY_MODELS_ENTRIES,
-    }
-
-    # Special handling for providers with dynamic model entries
-    if provider_type == "ollama":
-        return [
+        "ollama": [
             ProviderModelEntry(
-                provider_model_id="llama-guard3:1b",
+                provider_model_id="${env.SAFETY_MODEL:=__disabled__}",
                 model_type=ModelType.llm,
             ),
-        ]
+        ],
+    }
 
     return safety_model_entries_map.get(provider_type, [])
```
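Worth noting: the surviving ollama entry stores the literal `${env.SAFETY_MODEL:=__disabled__}` placeholder as its `provider_model_id`, so the concrete guard model is picked when the run config is resolved at deploy time, not when the template is generated. A tiny self-contained sketch of the effect (illustrative names, not llama-stack code):

```python
import os

# The template ships a placeholder; the deployment chooses the model.
entry_id = "${env.SAFETY_MODEL:=__disabled__}"
resolved = os.environ.get("SAFETY_MODEL", "__disabled__")  # what the placeholder stands for
print(f"{entry_id} -> {resolved}")
# With SAFETY_MODEL=llama-guard3:1b exported this prints
# "${env.SAFETY_MODEL:=__disabled__} -> llama-guard3:1b";
# with it unset, the shield entry stays disabled.
```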
```diff
@@ -246,28 +195,20 @@ def get_remote_inference_providers() -> tuple[list[Provider], dict[str, list[ProviderModelEntry]]]:
 
 
 # build a list of shields for all possible providers
-def get_shields_for_providers(providers: list[Provider]) -> list[ShieldInput]:
-    shields = []
+def get_safety_models_for_providers(providers: list[Provider]) -> dict[str, list[ProviderModelEntry]]:
+    available_models = {}
     for provider in providers:
         provider_type = provider.provider_type.split("::")[1]
         safety_model_entries = _get_model_safety_entries_for_provider(provider_type)
         if len(safety_model_entries) == 0:
             continue
-        if provider.provider_id:
-            shield_id = provider.provider_id
-        else:
-            raise ValueError(f"Provider {provider.provider_type} has no provider_id")
-        for safety_model_entry in safety_model_entries:
-            print(f"provider.provider_id: {provider.provider_id}")
-            print(f"safety_model_entry.provider_model_id: {safety_model_entry.provider_model_id}")
-            shields.append(
-                ShieldInput(
-                    provider_id="llama-guard",
-                    shield_id=shield_id,
-                    provider_shield_id=f"{provider.provider_id}/${{env.SAFETY_MODEL:={safety_model_entry.provider_model_id}}}",
-                )
-            )
-    return shields
+
+        env_var = f"ENABLE_{provider_type.upper().replace('-', '_').replace('::', '_')}"
+        provider_id = f"${{env.{env_var}:=__disabled__}}"
+
+        available_models[provider_id] = safety_model_entries
+
+    return available_models
```
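The env-var name is derived mechanically from the provider type, so a provider's shields are gated by the same `ENABLE_*` switch as its models. A quick check of the naming expression used in `get_safety_models_for_providers` above (provider types taken from this diff):

```python
# Mirrors the env_var expression in get_safety_models_for_providers.
for provider_type in ["ollama", "fireworks", "sambanova"]:
    env_var = f"ENABLE_{provider_type.upper().replace('-', '_').replace('::', '_')}"
    print(f"{provider_type} -> ${{env.{env_var}:=__disabled__}}")
# ollama    -> ${env.ENABLE_OLLAMA:=__disabled__}
# fireworks -> ${env.ENABLE_FIREWORKS:=__disabled__}
# sambanova -> ${env.ENABLE_SAMBANOVA:=__disabled__}
```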
```diff
@@ -307,8 +248,6 @@ def get_distribution_template() -> DistributionTemplate:
         ),
     ]
 
-    shields = get_shields_for_providers(remote_inference_providers)
-
     providers = {
         "inference": ([p.provider_type for p in remote_inference_providers] + ["inline::sentence-transformers"]),
         "vector_io": ([p.provider_type for p in vector_io_providers]),
@@ -361,7 +300,10 @@ def get_distribution_template() -> DistributionTemplate:
         },
     )
 
-    default_models = get_model_registry(available_models)
+    default_models, ids_conflict_in_models = get_model_registry(available_models)
+
+    available_safety_models = get_safety_models_for_providers(remote_inference_providers)
+    shields = get_shield_registry(available_safety_models, ids_conflict_in_models)
 
     return DistributionTemplate(
         name=name,
```
llama_stack/templates/template.py:

```diff
@@ -37,7 +37,7 @@ from llama_stack.providers.utils.sqlstore.sqlstore import get_pip_packages as ge
 
 def get_model_registry(
     available_models: dict[str, list[ProviderModelEntry]],
-) -> list[ModelInput]:
+) -> tuple[list[ModelInput], bool]:
     models = []
 
     # check for conflicts in model ids
@@ -74,7 +74,50 @@ def get_model_registry(
                 metadata=entry.metadata,
             )
         )
-    return models
+    return models, ids_conflict
+
+
+def get_shield_registry(
+    available_safety_models: dict[str, list[ProviderModelEntry]],
+    ids_conflict_in_models: bool,
+) -> list[ShieldInput]:
+    shields = []
+
+    # check for conflicts in shield ids
+    all_ids = set()
+    ids_conflict = False
+
+    for _, entries in available_safety_models.items():
+        for entry in entries:
+            ids = [entry.provider_model_id] + entry.aliases
+            for model_id in ids:
+                if model_id in all_ids:
+                    ids_conflict = True
+                    rich.print(
+                        f"[yellow]Shield id {model_id} conflicts; all shield ids will be prefixed with provider id[/yellow]"
+                    )
+                    break
+            all_ids.update(ids)
+            if ids_conflict:
+                break
+        if ids_conflict:
+            break
+
+    for provider_id, entries in available_safety_models.items():
+        for entry in entries:
+            ids = [entry.provider_model_id] + entry.aliases
+            for model_id in ids:
+                identifier = f"{provider_id}/{model_id}" if ids_conflict and provider_id not in model_id else model_id
+                shields.append(
+                    ShieldInput(
+                        shield_id=identifier,
+                        provider_shield_id=f"{provider_id}/{entry.provider_model_id}"
+                        if ids_conflict_in_models
+                        else entry.provider_model_id,
+                    )
+                )
+
+    return shields
+
+
 class DefaultModel(BaseModel):
```
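The conflict rule in `get_shield_registry` mirrors the one `get_model_registry` applies to models: ids stay bare unless the same id shows up under more than one provider, in which case everything gets provider-prefixed. A self-contained sketch of that rule using plain strings instead of `ProviderModelEntry`/`ShieldInput` (a simplification for illustration, not the function above):

```python
def shield_ids(available: dict[str, list[str]]) -> list[str]:
    # Conflict when any model id is exposed by more than one provider.
    all_ids = [m for models in available.values() for m in models]
    conflict = len(all_ids) != len(set(all_ids))
    return [
        f"{provider}/{m}" if conflict and provider not in m else m
        for provider, models in available.items()
        for m in models
    ]

# Single provider: the bare model id is the shield id.
print(shield_ids({"ollama": ["llama-guard3:1b"]}))  # ['llama-guard3:1b']

# Two providers exposing the same id: both get provider-prefixed.
print(shield_ids({"ollama": ["llama-guard3:8b"], "fireworks": ["llama-guard3:8b"]}))
# ['ollama/llama-guard3:8b', 'fireworks/llama-guard3:8b']
```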
watsonx distribution template:

```diff
@@ -69,7 +69,7 @@ def get_distribution_template() -> DistributionTemplate:
         },
     )
 
-    default_models = get_model_registry(available_models)
+    default_models, _ = get_model_registry(available_models)
     return DistributionTemplate(
         name="watsonx",
         distro_type="remote_hosted",
```
Agent integration tests:

```diff
@@ -77,6 +77,24 @@ def agent_config(llama_stack_client, text_model_id):
     return agent_config
 
 
+@pytest.fixture(scope="session")
+def agent_config_without_safety(text_model_id):
+    agent_config = dict(
+        model=text_model_id,
+        instructions="You are a helpful assistant",
+        sampling_params={
+            "strategy": {
+                "type": "top_p",
+                "temperature": 0.0001,
+                "top_p": 0.9,
+            },
+        },
+        tools=[],
+        enable_session_persistence=False,
+    )
+    return agent_config
+
+
 def test_agent_simple(llama_stack_client, agent_config):
     agent = Agent(llama_stack_client, **agent_config)
     session_id = agent.create_session(f"test-session-{uuid4()}")
@@ -491,7 +509,7 @@ def test_rag_agent(llama_stack_client, agent_config, rag_tool_name):
     assert expected_kw in response.output_message.content.lower()
 
 
-def test_rag_agent_with_attachments(llama_stack_client, agent_config):
+def test_rag_agent_with_attachments(llama_stack_client, agent_config_without_safety):
     urls = ["llama3.rst", "lora_finetune.rst"]
     documents = [
         # passign as url
@@ -514,14 +532,8 @@ def test_rag_agent_with_attachments(llama_stack_client, agent_config):
             metadata={},
         ),
     ]
-    rag_agent = Agent(llama_stack_client, **agent_config)
+    rag_agent = Agent(llama_stack_client, **agent_config_without_safety)
     session_id = rag_agent.create_session(f"test-session-{uuid4()}")
-    user_prompts = [
-        (
-            "Instead of the standard multi-head attention, what attention type does Llama3-8B use?",
-            "grouped",
-        ),
-    ]
     user_prompts = [
         (
             "I am attaching some documentation for Torchtune. Help me answer questions I will ask next.",
@@ -549,82 +561,6 @@ def test_rag_agent_with_attachments(llama_stack_client, agent_config):
     assert "lora" in response.output_message.content.lower()
 
 
-@pytest.mark.skip(reason="Code interpreter is currently disabled in the Stack")
-def test_rag_and_code_agent(llama_stack_client, agent_config):
-    if "llama-4" in agent_config["model"].lower():
-        pytest.xfail("Not working for llama4")
-
-    documents = []
-    documents.append(
-        Document(
-            document_id="nba_wiki",
-            content="The NBA was created on August 3, 1949, with the merger of the Basketball Association of America (BAA) and the National Basketball League (NBL).",
-            metadata={},
-        )
-    )
-    documents.append(
-        Document(
-            document_id="perplexity_wiki",
-            content="""Perplexity the company was founded in 2022 by Aravind Srinivas, Andy Konwinski, Denis Yarats and Johnny Ho, engineers with backgrounds in back-end systems, artificial intelligence (AI) and machine learning:
-
-    Srinivas, the CEO, worked at OpenAI as an AI researcher.
-    Konwinski was among the founding team at Databricks.
-    Yarats, the CTO, was an AI research scientist at Meta.
-    Ho, the CSO, worked as an engineer at Quora, then as a quantitative trader on Wall Street.[5]""",
-            metadata={},
-        )
-    )
-    vector_db_id = f"test-vector-db-{uuid4()}"
-    llama_stack_client.vector_dbs.register(
-        vector_db_id=vector_db_id,
-        embedding_model="all-MiniLM-L6-v2",
-        embedding_dimension=384,
-    )
-    llama_stack_client.tool_runtime.rag_tool.insert(
-        documents=documents,
-        vector_db_id=vector_db_id,
-        chunk_size_in_tokens=128,
-    )
-    agent_config = {
-        **agent_config,
-        "tools": [
-            dict(
-                name="builtin::rag/knowledge_search",
-                args={"vector_db_ids": [vector_db_id]},
-            ),
-            "builtin::code_interpreter",
-        ],
-    }
-    agent = Agent(llama_stack_client, **agent_config)
-    user_prompts = [
-        (
-            "when was Perplexity the company founded?",
-            [],
-            "knowledge_search",
-            "2022",
-        ),
-        (
-            "when was the nba created?",
-            [],
-            "knowledge_search",
-            "1949",
-        ),
-    ]
-
-    for prompt, docs, tool_name, expected_kw in user_prompts:
-        session_id = agent.create_session(f"test-session-{uuid4()}")
-        response = agent.create_turn(
-            messages=[{"role": "user", "content": prompt}],
-            session_id=session_id,
-            documents=docs,
-            stream=False,
-        )
-        tool_execution_step = next(step for step in response.steps if step.step_type == "tool_execution")
-        assert tool_execution_step.tool_calls[0].tool_name == tool_name, f"Failed on {prompt}"
-        if expected_kw:
-            assert expected_kw in response.output_message.content.lower()
-
-
 @pytest.mark.parametrize(
     "client_tools",
     [(get_boiling_point, False), (get_boiling_point_with_metadata, True)],
```