mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-16 18:08:09 +00:00
fix: Safety in starter (#2731)
- fireworks, together do not support Llama-guard 3 8b model anymore - Need to default to ollama - current safety shields logic was not correct since the shield_id was the provider ( which had duplicates ) - Followed similar logic to models Note: Seems a bit over-engineered but this can now be extended to other providers and fits in the overall mechanism of how env_vars are used to manage starter. ### How to test ``` ENABLE_OLLAMA=ollama ENABLE_FIREWORKS=fireworks SAFETY_MODEL=llama-guard3:1b pytest -s -v tests/integration/ --stack-config starter -k 'not(supervised_fine_tune or builtin_tool_code or safety_with_image or code_interpreter_for or rag_and_code or truncation or register_and_unregister)' --text-model fireworks/meta-llama/Llama-3.3-70B-Instruct --vision-model fireworks/meta-llama/Llama-4-Scout-17B-16E-Instruct --safety-shield llama-guard3:1b --embedding-model all-MiniLM-L6-v2 ``` ### Related but not obvious in this PR In the llama-stack-ops repo, we run tests before publishing packages and docker containers. The actions in that repo were using the fireworks / together distros ( which are non-existent ) So need to update that to run with `starter` and use `ollama` specifically for safety.
This commit is contained in:
parent
6ad22c209f
commit
6b8a8c1be9
9 changed files with 104 additions and 195 deletions
|
@ -37,7 +37,7 @@ from llama_stack.providers.utils.sqlstore.sqlstore import get_pip_packages as ge
|
|||
|
||||
def get_model_registry(
|
||||
available_models: dict[str, list[ProviderModelEntry]],
|
||||
) -> list[ModelInput]:
|
||||
) -> tuple[list[ModelInput], bool]:
|
||||
models = []
|
||||
|
||||
# check for conflicts in model ids
|
||||
|
@ -74,7 +74,50 @@ def get_model_registry(
|
|||
metadata=entry.metadata,
|
||||
)
|
||||
)
|
||||
return models
|
||||
return models, ids_conflict
|
||||
|
||||
|
||||
def get_shield_registry(
|
||||
available_safety_models: dict[str, list[ProviderModelEntry]],
|
||||
ids_conflict_in_models: bool,
|
||||
) -> list[ShieldInput]:
|
||||
shields = []
|
||||
|
||||
# check for conflicts in shield ids
|
||||
all_ids = set()
|
||||
ids_conflict = False
|
||||
|
||||
for _, entries in available_safety_models.items():
|
||||
for entry in entries:
|
||||
ids = [entry.provider_model_id] + entry.aliases
|
||||
for model_id in ids:
|
||||
if model_id in all_ids:
|
||||
ids_conflict = True
|
||||
rich.print(
|
||||
f"[yellow]Shield id {model_id} conflicts; all shield ids will be prefixed with provider id[/yellow]"
|
||||
)
|
||||
break
|
||||
all_ids.update(ids)
|
||||
if ids_conflict:
|
||||
break
|
||||
if ids_conflict:
|
||||
break
|
||||
|
||||
for provider_id, entries in available_safety_models.items():
|
||||
for entry in entries:
|
||||
ids = [entry.provider_model_id] + entry.aliases
|
||||
for model_id in ids:
|
||||
identifier = f"{provider_id}/{model_id}" if ids_conflict and provider_id not in model_id else model_id
|
||||
shields.append(
|
||||
ShieldInput(
|
||||
shield_id=identifier,
|
||||
provider_shield_id=f"{provider_id}/{entry.provider_model_id}"
|
||||
if ids_conflict_in_models
|
||||
else entry.provider_model_id,
|
||||
)
|
||||
)
|
||||
|
||||
return shields
|
||||
|
||||
|
||||
class DefaultModel(BaseModel):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue