diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml
index c46100c38..1a8d6734f 100644
--- a/.github/workflows/integration-tests.yml
+++ b/.github/workflows/integration-tests.yml
@@ -89,7 +89,7 @@ jobs:
           -k "not(builtin_tool or safety_with_image or code_interpreter or test_rag)" \
           --text-model="ollama/llama3.2:3b-instruct-fp16" \
           --embedding-model=all-MiniLM-L6-v2 \
-          --safety-shield=ollama \
+          --safety-shield=$SAFETY_MODEL \
           --color=yes \
           --capture=tee-sys | tee pytest-${{ matrix.test-type }}.log
diff --git a/llama_stack/providers/remote/inference/ollama/models.py b/llama_stack/providers/remote/inference/ollama/models.py
index 64ddb23d9..7c0a19a1a 100644
--- a/llama_stack/providers/remote/inference/ollama/models.py
+++ b/llama_stack/providers/remote/inference/ollama/models.py
@@ -12,6 +12,19 @@ from llama_stack.providers.utils.inference.model_registry import (
     build_model_entry,
 )
 
+SAFETY_MODELS_ENTRIES = [
+    # The Llama Guard models don't have their full fp16 versions
+    # so we are going to alias their default version to the canonical SKU
+    build_hf_repo_model_entry(
+        "llama-guard3:8b",
+        CoreModelId.llama_guard_3_8b.value,
+    ),
+    build_hf_repo_model_entry(
+        "llama-guard3:1b",
+        CoreModelId.llama_guard_3_1b.value,
+    ),
+]
+
 MODEL_ENTRIES = [
     build_hf_repo_model_entry(
         "llama3.1:8b-instruct-fp16",
@@ -73,16 +86,6 @@ MODEL_ENTRIES = [
         "llama3.3:70b",
         CoreModelId.llama3_3_70b_instruct.value,
     ),
-    # The Llama Guard models don't have their full fp16 versions
-    # so we are going to alias their default version to the canonical SKU
-    build_hf_repo_model_entry(
-        "llama-guard3:8b",
-        CoreModelId.llama_guard_3_8b.value,
-    ),
-    build_hf_repo_model_entry(
-        "llama-guard3:1b",
-        CoreModelId.llama_guard_3_1b.value,
-    ),
     ProviderModelEntry(
         provider_model_id="all-minilm:l6-v2",
         aliases=["all-minilm"],
@@ -100,4 +103,4 @@ MODEL_ENTRIES = [
             "context_length": 8192,
         },
     ),
-]
+] + SAFETY_MODELS_ENTRIES
diff --git a/llama_stack/templates/nvidia/nvidia.py b/llama_stack/templates/nvidia/nvidia.py
index 4eccfb25c..e5c13aa74 100644
--- a/llama_stack/templates/nvidia/nvidia.py
+++ b/llama_stack/templates/nvidia/nvidia.py
@@ -68,7 +68,7 @@ def get_distribution_template() -> DistributionTemplate:
         ),
     ]
 
-    default_models = get_model_registry(available_models)
+    default_models, _ = get_model_registry(available_models)
     return DistributionTemplate(
         name="nvidia",
         distro_type="self_hosted",
diff --git a/llama_stack/templates/open-benchmark/open_benchmark.py b/llama_stack/templates/open-benchmark/open_benchmark.py
index 942905dae..56ee9c47d 100644
--- a/llama_stack/templates/open-benchmark/open_benchmark.py
+++ b/llama_stack/templates/open-benchmark/open_benchmark.py
@@ -146,7 +146,8 @@ def get_distribution_template() -> DistributionTemplate:
         ),
     ]
 
-    default_models = get_model_registry(available_models) + [
+    models, _ = get_model_registry(available_models)
+    default_models = models + [
         ModelInput(
             model_id="meta-llama/Llama-3.3-70B-Instruct",
             provider_id="groq",
diff --git a/llama_stack/templates/starter/run.yaml b/llama_stack/templates/starter/run.yaml
index 888a2c3bf..ad449cb1b 100644
--- a/llama_stack/templates/starter/run.yaml
+++ b/llama_stack/templates/starter/run.yaml
@@ -1171,24 +1171,8 @@ models:
   provider_id: ${env.ENABLE_SENTENCE_TRANSFORMERS:=sentence-transformers}
   model_type: embedding
 shields:
-- shield_id: ${env.ENABLE_OLLAMA:=__disabled__}
-  provider_id: llama-guard
-  provider_shield_id: ${env.ENABLE_OLLAMA:=__disabled__}/${env.SAFETY_MODEL:=llama-guard3:1b}
-- shield_id: ${env.ENABLE_FIREWORKS:=__disabled__}
-  provider_id: llama-guard
-  provider_shield_id: ${env.ENABLE_FIREWORKS:=__disabled__}/${env.SAFETY_MODEL:=accounts/fireworks/models/llama-guard-3-8b}
-- shield_id: ${env.ENABLE_FIREWORKS:=__disabled__}
-  provider_id: llama-guard
-  provider_shield_id: ${env.ENABLE_FIREWORKS:=__disabled__}/${env.SAFETY_MODEL:=accounts/fireworks/models/llama-guard-3-11b-vision}
-- shield_id: ${env.ENABLE_TOGETHER:=__disabled__}
-  provider_id: llama-guard
-  provider_shield_id: ${env.ENABLE_TOGETHER:=__disabled__}/${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-8B}
-- shield_id: ${env.ENABLE_TOGETHER:=__disabled__}
-  provider_id: llama-guard
-  provider_shield_id: ${env.ENABLE_TOGETHER:=__disabled__}/${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-11B-Vision-Turbo}
-- shield_id: ${env.ENABLE_SAMBANOVA:=__disabled__}
-  provider_id: llama-guard
-  provider_shield_id: ${env.ENABLE_SAMBANOVA:=__disabled__}/${env.SAFETY_MODEL:=sambanova/Meta-Llama-Guard-3-8B}
+- shield_id: ${env.SAFETY_MODEL:=__disabled__}
+  provider_shield_id: ${env.ENABLE_OLLAMA:=__disabled__}/${env.SAFETY_MODEL:=__disabled__}
 vector_dbs: []
 datasets: []
 scoring_fns: []
diff --git a/llama_stack/templates/starter/starter.py b/llama_stack/templates/starter/starter.py
index 6b8aa8974..c0ac44183 100644
--- a/llama_stack/templates/starter/starter.py
+++ b/llama_stack/templates/starter/starter.py
@@ -12,7 +12,6 @@ from llama_stack.distribution.datatypes import (
     ModelInput,
     Provider,
     ProviderSpec,
-    ShieldInput,
     ToolGroupInput,
 )
 from llama_stack.distribution.utils.dynamic import instantiate_class_type
@@ -32,75 +31,39 @@ from llama_stack.providers.registry.inference import available_providers
 from llama_stack.providers.remote.inference.anthropic.models import (
     MODEL_ENTRIES as ANTHROPIC_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.anthropic.models import (
-    SAFETY_MODELS_ENTRIES as ANTHROPIC_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.bedrock.models import (
     MODEL_ENTRIES as BEDROCK_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.bedrock.models import (
-    SAFETY_MODELS_ENTRIES as BEDROCK_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.cerebras.models import (
     MODEL_ENTRIES as CEREBRAS_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.cerebras.models import (
-    SAFETY_MODELS_ENTRIES as CEREBRAS_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.databricks.databricks import (
     MODEL_ENTRIES as DATABRICKS_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.databricks.databricks import (
-    SAFETY_MODELS_ENTRIES as DATABRICKS_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.fireworks.models import (
     MODEL_ENTRIES as FIREWORKS_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.fireworks.models import (
-    SAFETY_MODELS_ENTRIES as FIREWORKS_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.gemini.models import (
     MODEL_ENTRIES as GEMINI_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.gemini.models import (
-    SAFETY_MODELS_ENTRIES as GEMINI_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.groq.models import (
     MODEL_ENTRIES as GROQ_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.groq.models import (
-    SAFETY_MODELS_ENTRIES as GROQ_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.nvidia.models import (
     MODEL_ENTRIES as NVIDIA_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.nvidia.models import (
-    SAFETY_MODELS_ENTRIES as NVIDIA_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.openai.models import (
     MODEL_ENTRIES as OPENAI_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.openai.models import (
-    SAFETY_MODELS_ENTRIES as OPENAI_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.runpod.runpod import (
     MODEL_ENTRIES as RUNPOD_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.runpod.runpod import (
-    SAFETY_MODELS_ENTRIES as RUNPOD_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.sambanova.models import (
     MODEL_ENTRIES as SAMBANOVA_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.sambanova.models import (
-    SAFETY_MODELS_ENTRIES as SAMBANOVA_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.together.models import (
     MODEL_ENTRIES as TOGETHER_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.together.models import (
-    SAFETY_MODELS_ENTRIES as TOGETHER_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOConfig
 from llama_stack.providers.remote.vector_io.pgvector.config import (
     PGVectorVectorIOConfig,
@@ -111,6 +74,7 @@ from llama_stack.templates.template import (
     DistributionTemplate,
     RunConfigSettings,
     get_model_registry,
+    get_shield_registry,
 )
 
 
@@ -164,28 +128,13 @@ def _get_model_entries_for_provider(provider_type: str) -> list[ProviderModelEnt
 def _get_model_safety_entries_for_provider(provider_type: str) -> list[ProviderModelEntry]:
     """Get model entries for a specific provider type."""
     safety_model_entries_map = {
-        "openai": OPENAI_SAFETY_MODELS_ENTRIES,
-        "fireworks": FIREWORKS_SAFETY_MODELS_ENTRIES,
-        "together": TOGETHER_SAFETY_MODELS_ENTRIES,
-        "anthropic": ANTHROPIC_SAFETY_MODELS_ENTRIES,
-        "gemini": GEMINI_SAFETY_MODELS_ENTRIES,
-        "groq": GROQ_SAFETY_MODELS_ENTRIES,
-        "sambanova": SAMBANOVA_SAFETY_MODELS_ENTRIES,
-        "cerebras": CEREBRAS_SAFETY_MODELS_ENTRIES,
-        "bedrock": BEDROCK_SAFETY_MODELS_ENTRIES,
-        "databricks": DATABRICKS_SAFETY_MODELS_ENTRIES,
-        "nvidia": NVIDIA_SAFETY_MODELS_ENTRIES,
-        "runpod": RUNPOD_SAFETY_MODELS_ENTRIES,
-    }
-
-    # Special handling for providers with dynamic model entries
-    if provider_type == "ollama":
-        return [
+        "ollama": [
             ProviderModelEntry(
-                provider_model_id="llama-guard3:1b",
+                provider_model_id="${env.SAFETY_MODEL:=__disabled__}",
                 model_type=ModelType.llm,
             ),
-        ]
+        ],
+    }
     return safety_model_entries_map.get(provider_type, [])
 
 
@@ -246,28 +195,20 @@ def get_remote_inference_providers() -> tuple[list[Provider], dict[str, list[Pro
 
 
 # build a list of shields for all possible providers
-def get_shields_for_providers(providers: list[Provider]) -> list[ShieldInput]:
-    shields = []
+def get_safety_models_for_providers(providers: list[Provider]) -> dict[str, list[ProviderModelEntry]]:
+    available_models = {}
     for provider in providers:
         provider_type = provider.provider_type.split("::")[1]
         safety_model_entries = _get_model_safety_entries_for_provider(provider_type)
         if len(safety_model_entries) == 0:
             continue
-        if provider.provider_id:
-            shield_id = provider.provider_id
-        else:
-            raise ValueError(f"Provider {provider.provider_type} has no provider_id")
-        for safety_model_entry in safety_model_entries:
-            print(f"provider.provider_id: {provider.provider_id}")
-            print(f"safety_model_entry.provider_model_id: {safety_model_entry.provider_model_id}")
-            shields.append(
-                ShieldInput(
-                    provider_id="llama-guard",
-                    shield_id=shield_id,
-                    provider_shield_id=f"{provider.provider_id}/${{env.SAFETY_MODEL:={safety_model_entry.provider_model_id}}}",
-                )
-            )
-    return shields
+
+        env_var = f"ENABLE_{provider_type.upper().replace('-', '_').replace('::', '_')}"
+        provider_id = f"${{env.{env_var}:=__disabled__}}"
+
+        available_models[provider_id] = safety_model_entries
+
+    return available_models
 
 
 def get_distribution_template() -> DistributionTemplate:
@@ -307,8 +248,6 @@ def get_distribution_template() -> DistributionTemplate:
         ),
     ]
 
-    shields = get_shields_for_providers(remote_inference_providers)
-
     providers = {
         "inference": ([p.provider_type for p in remote_inference_providers] + ["inline::sentence-transformers"]),
         "vector_io": ([p.provider_type for p in vector_io_providers]),
@@ -361,7 +300,10 @@ def get_distribution_template() -> DistributionTemplate:
         },
     )
 
-    default_models = get_model_registry(available_models)
+    default_models, ids_conflict_in_models = get_model_registry(available_models)
+
+    available_safety_models = get_safety_models_for_providers(remote_inference_providers)
+    shields = get_shield_registry(available_safety_models, ids_conflict_in_models)
 
     return DistributionTemplate(
         name=name,
diff --git a/llama_stack/templates/template.py b/llama_stack/templates/template.py
index dceb13c8b..fb2528873 100644
--- a/llama_stack/templates/template.py
+++ b/llama_stack/templates/template.py
@@ -37,7 +37,7 @@ from llama_stack.providers.utils.sqlstore.sqlstore import get_pip_packages as ge
 
 def get_model_registry(
     available_models: dict[str, list[ProviderModelEntry]],
-) -> list[ModelInput]:
+) -> tuple[list[ModelInput], bool]:
     models = []
 
     # check for conflicts in model ids
@@ -74,7 +74,50 @@ def get_model_registry(
                 metadata=entry.metadata,
             )
         )
-    return models
+    return models, ids_conflict
+
+
+def get_shield_registry(
+    available_safety_models: dict[str, list[ProviderModelEntry]],
+    ids_conflict_in_models: bool,
+) -> list[ShieldInput]:
+    shields = []
+
+    # check for conflicts in shield ids
+    all_ids = set()
+    ids_conflict = False
+
+    for _, entries in available_safety_models.items():
+        for entry in entries:
+            ids = [entry.provider_model_id] + entry.aliases
+            for model_id in ids:
+                if model_id in all_ids:
+                    ids_conflict = True
+                    rich.print(
+                        f"[yellow]Shield id {model_id} conflicts; all shield ids will be prefixed with provider id[/yellow]"
+                    )
+                    break
+            all_ids.update(ids)
+            if ids_conflict:
+                break
+        if ids_conflict:
+            break
+
+    for provider_id, entries in available_safety_models.items():
+        for entry in entries:
+            ids = [entry.provider_model_id] + entry.aliases
+            for model_id in ids:
+                identifier = f"{provider_id}/{model_id}" if ids_conflict and provider_id not in model_id else model_id
+                shields.append(
+                    ShieldInput(
+                        shield_id=identifier,
+                        provider_shield_id=f"{provider_id}/{entry.provider_model_id}"
+                        if ids_conflict_in_models
+                        else entry.provider_model_id,
+                    )
+                )
+
+    return shields
 
 
 class DefaultModel(BaseModel):
diff --git a/llama_stack/templates/watsonx/watsonx.py b/llama_stack/templates/watsonx/watsonx.py
index 7fa3a55e5..ea185f05d 100644
--- a/llama_stack/templates/watsonx/watsonx.py
+++ b/llama_stack/templates/watsonx/watsonx.py
@@ -69,7 +69,7 @@ def get_distribution_template() -> DistributionTemplate:
         },
     )
 
-    default_models = get_model_registry(available_models)
+    default_models, _ = get_model_registry(available_models)
     return DistributionTemplate(
         name="watsonx",
         distro_type="remote_hosted",
diff --git a/tests/integration/agents/test_agents.py b/tests/integration/agents/test_agents.py
index 66c9ab829..05549cf18 100644
--- a/tests/integration/agents/test_agents.py
+++ b/tests/integration/agents/test_agents.py
@@ -77,6 +77,24 @@ def agent_config(llama_stack_client, text_model_id):
     return agent_config
 
 
+@pytest.fixture(scope="session")
+def agent_config_without_safety(text_model_id):
+    agent_config = dict(
+        model=text_model_id,
+        instructions="You are a helpful assistant",
+        sampling_params={
+            "strategy": {
+                "type": "top_p",
+                "temperature": 0.0001,
+                "top_p": 0.9,
+            },
+        },
+        tools=[],
+        enable_session_persistence=False,
+    )
+    return agent_config
+
+
 def test_agent_simple(llama_stack_client, agent_config):
     agent = Agent(llama_stack_client, **agent_config)
     session_id = agent.create_session(f"test-session-{uuid4()}")
@@ -491,7 +509,7 @@ def test_rag_agent(llama_stack_client, agent_config, rag_tool_name):
             assert expected_kw in response.output_message.content.lower()
 
 
-def test_rag_agent_with_attachments(llama_stack_client, agent_config):
+def test_rag_agent_with_attachments(llama_stack_client, agent_config_without_safety):
     urls = ["llama3.rst", "lora_finetune.rst"]
     documents = [
         # passign as url
@@ -514,14 +532,8 @@ def test_rag_agent_with_attachments(llama_stack_client, agent_config):
             metadata={},
         ),
     ]
-    rag_agent = Agent(llama_stack_client, **agent_config)
+    rag_agent = Agent(llama_stack_client, **agent_config_without_safety)
     session_id = rag_agent.create_session(f"test-session-{uuid4()}")
-    user_prompts = [
-        (
-            "Instead of the standard multi-head attention, what attention type does Llama3-8B use?",
-            "grouped",
-        ),
-    ]
     user_prompts = [
         (
             "I am attaching some documentation for Torchtune. Help me answer questions I will ask next.",
@@ -549,82 +561,6 @@ def test_rag_agent_with_attachments(llama_stack_client, agent_config):
     assert "lora" in response.output_message.content.lower()
 
 
-@pytest.mark.skip(reason="Code interpreter is currently disabled in the Stack")
-def test_rag_and_code_agent(llama_stack_client, agent_config):
-    if "llama-4" in agent_config["model"].lower():
-        pytest.xfail("Not working for llama4")
-
-    documents = []
-    documents.append(
-        Document(
-            document_id="nba_wiki",
-            content="The NBA was created on August 3, 1949, with the merger of the Basketball Association of America (BAA) and the National Basketball League (NBL).",
-            metadata={},
-        )
-    )
-    documents.append(
-        Document(
-            document_id="perplexity_wiki",
-            content="""Perplexity the company was founded in 2022 by Aravind Srinivas, Andy Konwinski, Denis Yarats and Johnny Ho, engineers with backgrounds in back-end systems, artificial intelligence (AI) and machine learning:
-
-    Srinivas, the CEO, worked at OpenAI as an AI researcher.
-    Konwinski was among the founding team at Databricks.
-    Yarats, the CTO, was an AI research scientist at Meta.
-    Ho, the CSO, worked as an engineer at Quora, then as a quantitative trader on Wall Street.[5]""",
-            metadata={},
-        )
-    )
-    vector_db_id = f"test-vector-db-{uuid4()}"
-    llama_stack_client.vector_dbs.register(
-        vector_db_id=vector_db_id,
-        embedding_model="all-MiniLM-L6-v2",
-        embedding_dimension=384,
-    )
-    llama_stack_client.tool_runtime.rag_tool.insert(
-        documents=documents,
-        vector_db_id=vector_db_id,
-        chunk_size_in_tokens=128,
-    )
-    agent_config = {
-        **agent_config,
-        "tools": [
-            dict(
-                name="builtin::rag/knowledge_search",
-                args={"vector_db_ids": [vector_db_id]},
-            ),
-            "builtin::code_interpreter",
-        ],
-    }
-    agent = Agent(llama_stack_client, **agent_config)
-    user_prompts = [
-        (
-            "when was Perplexity the company founded?",
-            [],
-            "knowledge_search",
-            "2022",
-        ),
-        (
-            "when was the nba created?",
-            [],
-            "knowledge_search",
-            "1949",
-        ),
-    ]
-
-    for prompt, docs, tool_name, expected_kw in user_prompts:
-        session_id = agent.create_session(f"test-session-{uuid4()}")
-        response = agent.create_turn(
-            messages=[{"role": "user", "content": prompt}],
-            session_id=session_id,
-            documents=docs,
-            stream=False,
-        )
-        tool_execution_step = next(step for step in response.steps if step.step_type == "tool_execution")
-        assert tool_execution_step.tool_calls[0].tool_name == tool_name, f"Failed on {prompt}"
-        if expected_kw:
-            assert expected_kw in response.output_message.content.lower()
-
-
 @pytest.mark.parametrize(
     "client_tools",
     [(get_boiling_point, False), (get_boiling_point_with_metadata, True)],
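
A minimal usage sketch of the reworked template helpers, assuming only what the hunks above show (get_model_registry now returning a tuple, get_shield_registry taking the conflict flag, and the ${env.ENABLE_*} provider-id convention); the concrete provider and model entries below are illustrative placeholders, not values from this patch.

    from llama_stack.providers.utils.inference.model_registry import ProviderModelEntry
    from llama_stack.templates.template import get_model_registry, get_shield_registry

    # Hypothetical provider-id -> entry maps; real templates build these from each
    # provider's MODEL_ENTRIES / safety entries as starter.py does above.
    available_models = {
        "${env.ENABLE_OLLAMA:=__disabled__}": [
            ProviderModelEntry(provider_model_id="llama3.2:3b-instruct-fp16"),
        ],
    }
    available_safety_models = {
        "${env.ENABLE_OLLAMA:=__disabled__}": [
            ProviderModelEntry(provider_model_id="${env.SAFETY_MODEL:=__disabled__}"),
        ],
    }

    # get_model_registry now also reports whether model ids collide across providers...
    default_models, ids_conflict_in_models = get_model_registry(available_models)
    # ...and get_shield_registry reuses that flag so provider_shield_id is prefixed
    # consistently with how the model ids were registered.
    shields = get_shield_registry(available_safety_models, ids_conflict_in_models)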