Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-19 03:10:03 +00:00)
fix: Safety in starter (#2731)
- fireworks and together no longer serve the Llama-Guard-3-8B model, so safety needs to default to ollama.
- The existing safety-shields logic was incorrect: the shield_id was the provider id, which produced duplicates.
- Shields now follow logic similar to models.

Note: this may seem a bit over-engineered, but it can now be extended to other providers and fits the overall mechanism of using env vars to manage the starter distribution.

### How to test

```
ENABLE_OLLAMA=ollama ENABLE_FIREWORKS=fireworks SAFETY_MODEL=llama-guard3:1b pytest -s -v tests/integration/ --stack-config starter -k 'not(supervised_fine_tune or builtin_tool_code or safety_with_image or code_interpreter_for or rag_and_code or truncation or register_and_unregister)' --text-model fireworks/meta-llama/Llama-3.3-70B-Instruct --vision-model fireworks/meta-llama/Llama-4-Scout-17B-16E-Instruct --safety-shield llama-guard3:1b --embedding-model all-MiniLM-L6-v2
```

### Related but not obvious in this PR

In the llama-stack-ops repo, we run tests before publishing packages and docker containers. The actions in that repo were using the fireworks / together distros (which no longer exist), so they need to be updated to run with `starter` and use `ollama` specifically for safety.
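For intuition, here is a minimal sketch of the env-var-keyed shield registration described above. The `ShieldEntry` dataclass and `starter_shields` helper are hypothetical illustrations, not the actual llama-stack implementation:

```python
import os
from dataclasses import dataclass


@dataclass
class ShieldEntry:
    shield_id: str    # the safety model id, e.g. "llama-guard3:1b"
    provider_id: str  # the provider serving it, e.g. "ollama"


def starter_shields() -> list[ShieldEntry]:
    """Build one shield entry per enabled safety provider, keyed by the
    model id rather than the provider name, so duplicate providers can no
    longer produce duplicate shield_ids."""
    entries = []
    safety_model = os.environ.get("SAFETY_MODEL", "llama-guard3:1b")
    # ENABLE_OLLAMA=ollama switches the provider on, mirroring how models
    # are gated by env vars in the starter distribution.
    if os.environ.get("ENABLE_OLLAMA"):
        entries.append(ShieldEntry(shield_id=safety_model, provider_id="ollama"))
    return entries
```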
This commit is contained in:
parent 6ad22c209f
commit 6b8a8c1be9

9 changed files with 104 additions and 195 deletions
```diff
@@ -77,6 +77,24 @@ def agent_config(llama_stack_client, text_model_id):
     return agent_config
 
 
+@pytest.fixture(scope="session")
+def agent_config_without_safety(text_model_id):
+    agent_config = dict(
+        model=text_model_id,
+        instructions="You are a helpful assistant",
+        sampling_params={
+            "strategy": {
+                "type": "top_p",
+                "temperature": 0.0001,
+                "top_p": 0.9,
+            },
+        },
+        tools=[],
+        enable_session_persistence=False,
+    )
+    return agent_config
+
+
 def test_agent_simple(llama_stack_client, agent_config):
     agent = Agent(llama_stack_client, **agent_config)
     session_id = agent.create_session(f"test-session-{uuid4()}")
@@ -491,7 +509,7 @@ def test_rag_agent(llama_stack_client, agent_config, rag_tool_name):
         assert expected_kw in response.output_message.content.lower()
 
 
-def test_rag_agent_with_attachments(llama_stack_client, agent_config):
+def test_rag_agent_with_attachments(llama_stack_client, agent_config_without_safety):
     urls = ["llama3.rst", "lora_finetune.rst"]
     documents = [
         # passign as url
@@ -514,14 +532,8 @@ def test_rag_agent_with_attachments(llama_stack_client, agent_config):
             metadata={},
         ),
     ]
-    rag_agent = Agent(llama_stack_client, **agent_config)
+    rag_agent = Agent(llama_stack_client, **agent_config_without_safety)
     session_id = rag_agent.create_session(f"test-session-{uuid4()}")
-    user_prompts = [
-        (
-            "Instead of the standard multi-head attention, what attention type does Llama3-8B use?",
-            "grouped",
-        ),
-    ]
     user_prompts = [
         (
             "I am attaching some documentation for Torchtune. Help me answer questions I will ask next.",
@@ -549,82 +561,6 @@ def test_rag_agent_with_attachments(llama_stack_client, agent_config):
     assert "lora" in response.output_message.content.lower()
 
 
-@pytest.mark.skip(reason="Code interpreter is currently disabled in the Stack")
-def test_rag_and_code_agent(llama_stack_client, agent_config):
-    if "llama-4" in agent_config["model"].lower():
-        pytest.xfail("Not working for llama4")
-
-    documents = []
-    documents.append(
-        Document(
-            document_id="nba_wiki",
-            content="The NBA was created on August 3, 1949, with the merger of the Basketball Association of America (BAA) and the National Basketball League (NBL).",
-            metadata={},
-        )
-    )
-    documents.append(
-        Document(
-            document_id="perplexity_wiki",
-            content="""Perplexity the company was founded in 2022 by Aravind Srinivas, Andy Konwinski, Denis Yarats and Johnny Ho, engineers with backgrounds in back-end systems, artificial intelligence (AI) and machine learning:
-
-    Srinivas, the CEO, worked at OpenAI as an AI researcher.
-    Konwinski was among the founding team at Databricks.
-    Yarats, the CTO, was an AI research scientist at Meta.
-    Ho, the CSO, worked as an engineer at Quora, then as a quantitative trader on Wall Street.[5]""",
-            metadata={},
-        )
-    )
-    vector_db_id = f"test-vector-db-{uuid4()}"
-    llama_stack_client.vector_dbs.register(
-        vector_db_id=vector_db_id,
-        embedding_model="all-MiniLM-L6-v2",
-        embedding_dimension=384,
-    )
-    llama_stack_client.tool_runtime.rag_tool.insert(
-        documents=documents,
-        vector_db_id=vector_db_id,
-        chunk_size_in_tokens=128,
-    )
-    agent_config = {
-        **agent_config,
-        "tools": [
-            dict(
-                name="builtin::rag/knowledge_search",
-                args={"vector_db_ids": [vector_db_id]},
-            ),
-            "builtin::code_interpreter",
-        ],
-    }
-    agent = Agent(llama_stack_client, **agent_config)
-    user_prompts = [
-        (
-            "when was Perplexity the company founded?",
-            [],
-            "knowledge_search",
-            "2022",
-        ),
-        (
-            "when was the nba created?",
-            [],
-            "knowledge_search",
-            "1949",
-        ),
-    ]
-
-    for prompt, docs, tool_name, expected_kw in user_prompts:
-        session_id = agent.create_session(f"test-session-{uuid4()}")
-        response = agent.create_turn(
-            messages=[{"role": "user", "content": prompt}],
-            session_id=session_id,
-            documents=docs,
-            stream=False,
-        )
-        tool_execution_step = next(step for step in response.steps if step.step_type == "tool_execution")
-        assert tool_execution_step.tool_calls[0].tool_name == tool_name, f"Failed on {prompt}"
-        if expected_kw:
-            assert expected_kw in response.output_message.content.lower()
-
-
 @pytest.mark.parametrize(
     "client_tools",
     [(get_boiling_point, False), (get_boiling_point_with_metadata, True)],
```
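As a quick sanity check after this change, one could list the registered shields against a running `starter` stack and confirm the identifiers are model ids rather than provider names. The base URL and port here are assumptions about a local deployment:

```python
from llama_stack_client import LlamaStackClient

# Assumes a `starter` stack is running locally; adjust base_url as needed.
client = LlamaStackClient(base_url="http://localhost:8321")
for shield in client.shields.list():
    # Expect identifiers like "llama-guard3:1b", not bare provider names.
    print(shield.identifier, shield.provider_id)
```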