Merge branch 'meta-llama:main' into add-unit-tests-and-fix-cli

Courtney Pacheco 2025-03-31 21:17:48 -04:00 committed by GitHub
commit 696bcf6051
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
459 changed files with 39114 additions and 10751 deletions


@@ -165,7 +165,10 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
request.model = MODEL
request.tool_config.tool_prompt_format = ToolPromptFormat.json
prompt = await chat_completion_request_to_prompt(request, request.model)
self.assertIn('{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}', prompt)
self.assertIn(
'{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}',
prompt,
)
async def test_user_provided_system_message(self):
content = "Hello !"


@@ -25,19 +25,21 @@ from llama_stack.models.llama.llama3.prompt_templates import (
class PromptTemplateTests(unittest.TestCase):
def check_generator_output(self, generator, expected_text):
example = generator.data_examples()[0]
pt = generator.gen(example)
text = pt.render()
# print(text) # debugging
assert text == expected_text, f"Expected:\n{expected_text}\nActual:\n{text}"
def check_generator_output(self, generator):
for example in generator.data_examples():
pt = generator.gen(example)
text = pt.render()
# print(text) # debugging
if not example:
continue
for tool in example:
assert tool.tool_name in text
def test_system_default(self):
generator = SystemDefaultGenerator()
today = datetime.now().strftime("%d %B %Y")
expected_text = f"Cutting Knowledge Date: December 2023\nToday Date: {today}"
self.check_generator_output(generator, expected_text)
assert expected_text.strip("\n") == generator.gen(generator.data_examples()[0]).render()
def test_system_builtin_only(self):
generator = BuiltinToolGenerator()
@@ -47,143 +49,24 @@ class PromptTemplateTests(unittest.TestCase):
Tools: brave_search, wolfram_alpha
"""
)
self.check_generator_output(generator, expected_text.strip("\n"))
assert expected_text.strip("\n") == generator.gen(generator.data_examples()[0]).render()
def test_system_custom_only(self):
self.maxDiff = None
generator = JsonCustomToolGenerator()
expected_text = textwrap.dedent(
"""
Answer the user's question by making use of the following functions if needed.
If none of the function can be used, please say so.
Here is a list of functions in JSON format:
{
"type": "function",
"function": {
"name": "trending_songs",
"description": "Returns the trending songs on a Music site",
"parameters": {
"type": "object",
"properties": [
{
"n": {
"type": "object",
"description": "The number of songs to return"
}
},
{
"genre": {
"type": "object",
"description": "The genre of the songs to return"
}
}
],
"required": ["n"]
}
}
}
Return function calls in JSON format.
"""
)
self.check_generator_output(generator, expected_text.strip("\n"))
self.check_generator_output(generator)
def test_system_custom_function_tag(self):
self.maxDiff = None
generator = FunctionTagCustomToolGenerator()
expected_text = textwrap.dedent(
"""
You have access to the following functions:
Use the function 'trending_songs' to 'Returns the trending songs on a Music site':
{"name": "trending_songs", "description": "Returns the trending songs on a Music site", "parameters": {"genre": {"description": "The genre of the songs to return", "param_type": "str", "required": false}, "n": {"description": "The number of songs to return", "param_type": "int", "required": true}}}
Think very carefully before calling functions.
If you choose to call a function ONLY reply in the following format with no prefix or suffix:
<function=example_function_name>{"example_name": "example_value"}</function>
Reminder:
- If looking for real time information use relevant functions before falling back to brave_search
- Function calls MUST follow the specified format, start with <function= and end with </function>
- Required parameters MUST be specified
- Only call one function at a time
- Put the entire function call reply on one line
"""
)
self.check_generator_output(generator, expected_text.strip("\n"))
self.check_generator_output(generator)
def test_llama_3_2_system_zero_shot(self):
generator = PythonListCustomToolGenerator()
expected_text = textwrap.dedent(
"""
You are a helpful assistant. You have access to functions, but you should only use them if they are required.
You are an expert in composing functions. You are given a question and a set of possible functions.
Based on the question, you may or may not need to make one function/tool call to achieve the purpose.
If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
You SHOULD NOT include any other text in the response.
Here is a list of functions in JSON format that you can invoke.
[
{
"name": "get_weather",
"description": "Get weather info for places",
"parameters": {
"type": "dict",
"required": ["city"],
"properties": {
"city": {
"type": "string",
"description": "The name of the city to get the weather for"
},
"metric": {
"type": "string",
"description": "The metric for weather. Options are: celsius, fahrenheit",
"default": "celsius"
}
}
}
}
]
"""
)
self.check_generator_output(generator, expected_text.strip("\n"))
self.check_generator_output(generator)
def test_llama_3_2_provided_system_prompt(self):
generator = PythonListCustomToolGenerator()
expected_text = textwrap.dedent(
"""
Overriding message.
If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
You SHOULD NOT include any other text in the response.
Here is a list of functions in JSON format that you can invoke.
[
{
"name": "get_weather",
"description": "Get weather info for places",
"parameters": {
"type": "dict",
"required": ["city"],
"properties": {
"city": {
"type": "string",
"description": "The name of the city to get the weather for"
},
"metric": {
"type": "string",
"description": "The metric for weather. Options are: celsius, fahrenheit",
"default": "celsius"
}
}
}
}
]"""
)
user_system_prompt = textwrap.dedent(
"""
Overriding message.
@@ -195,4 +78,5 @@ class PromptTemplateTests(unittest.TestCase):
pt = generator.gen(example, user_system_prompt)
text = pt.render()
assert text == expected_text, f"Expected:\n{expected_text}\nActual:\n{text}"
assert "Overriding message." in text
assert '"name": "get_weather"' in text


@@ -0,0 +1,175 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import shutil
import tempfile
import uuid
from datetime import datetime
from unittest.mock import patch
import pytest
from llama_stack.apis.agents import Turn
from llama_stack.apis.inference import CompletionMessage, StopReason
from llama_stack.distribution.datatypes import AccessAttributes
from llama_stack.providers.inline.agents.meta_reference.persistence import AgentPersistence, AgentSessionInfo
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from llama_stack.providers.utils.kvstore.sqlite import SqliteKVStoreImpl
@pytest.fixture
async def test_setup():
temp_dir = tempfile.mkdtemp()
db_path = os.path.join(temp_dir, "test_persistence_access_control.db")
kvstore_config = SqliteKVStoreConfig(db_path=db_path)
kvstore = SqliteKVStoreImpl(kvstore_config)
await kvstore.initialize()
agent_persistence = AgentPersistence(agent_id="test_agent", kvstore=kvstore)
yield agent_persistence
shutil.rmtree(temp_dir)
@pytest.mark.asyncio
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_auth_attributes")
async def test_session_creation_with_access_attributes(mock_get_auth_attributes, test_setup):
agent_persistence = test_setup
# Set creator's attributes for the session
creator_attributes = {"roles": ["researcher"], "teams": ["ai-team"]}
mock_get_auth_attributes.return_value = creator_attributes
# Create a session
session_id = await agent_persistence.create_session("Test Session")
# Get the session and verify access attributes were set
session_info = await agent_persistence.get_session_info(session_id)
assert session_info is not None
assert session_info.access_attributes is not None
assert session_info.access_attributes.roles == ["researcher"]
assert session_info.access_attributes.teams == ["ai-team"]
@pytest.mark.asyncio
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_auth_attributes")
async def test_session_access_control(mock_get_auth_attributes, test_setup):
agent_persistence = test_setup
# Create a session with specific access attributes
session_id = str(uuid.uuid4())
session_info = AgentSessionInfo(
session_id=session_id,
session_name="Restricted Session",
started_at=datetime.now(),
access_attributes=AccessAttributes(roles=["admin"], teams=["security-team"]),
)
await agent_persistence.kvstore.set(
key=f"session:{agent_persistence.agent_id}:{session_id}",
value=session_info.model_dump_json(),
)
# User with matching attributes can access
mock_get_auth_attributes.return_value = {"roles": ["admin", "user"], "teams": ["security-team", "other-team"]}
retrieved_session = await agent_persistence.get_session_info(session_id)
assert retrieved_session is not None
assert retrieved_session.session_id == session_id
# User without matching attributes cannot access
mock_get_auth_attributes.return_value = {"roles": ["user"], "teams": ["other-team"]}
retrieved_session = await agent_persistence.get_session_info(session_id)
assert retrieved_session is None
@pytest.mark.asyncio
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_auth_attributes")
async def test_turn_access_control(mock_get_auth_attributes, test_setup):
agent_persistence = test_setup
# Create a session with restricted access
session_id = str(uuid.uuid4())
session_info = AgentSessionInfo(
session_id=session_id,
session_name="Restricted Session",
started_at=datetime.now(),
access_attributes=AccessAttributes(roles=["admin"]),
)
await agent_persistence.kvstore.set(
key=f"session:{agent_persistence.agent_id}:{session_id}",
value=session_info.model_dump_json(),
)
# Create a turn for this session
turn_id = str(uuid.uuid4())
turn = Turn(
session_id=session_id,
turn_id=turn_id,
steps=[],
started_at=datetime.now(),
input_messages=[],
output_message=CompletionMessage(
content="Hello",
stop_reason=StopReason.end_of_turn,
),
)
# Admin can add turn
mock_get_auth_attributes.return_value = {"roles": ["admin"]}
await agent_persistence.add_turn_to_session(session_id, turn)
# Admin can get turn
retrieved_turn = await agent_persistence.get_session_turn(session_id, turn_id)
assert retrieved_turn is not None
assert retrieved_turn.turn_id == turn_id
# Regular user cannot get turn
mock_get_auth_attributes.return_value = {"roles": ["user"]}
with pytest.raises(ValueError):
await agent_persistence.get_session_turn(session_id, turn_id)
# Regular user cannot get turns for session
with pytest.raises(ValueError):
await agent_persistence.get_session_turns(session_id)
@pytest.mark.asyncio
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_auth_attributes")
async def test_tool_call_and_infer_iters_access_control(mock_get_auth_attributes, test_setup):
agent_persistence = test_setup
# Create a session with restricted access
session_id = str(uuid.uuid4())
session_info = AgentSessionInfo(
session_id=session_id,
session_name="Restricted Session",
started_at=datetime.now(),
access_attributes=AccessAttributes(roles=["admin"]),
)
await agent_persistence.kvstore.set(
key=f"session:{agent_persistence.agent_id}:{session_id}",
value=session_info.model_dump_json(),
)
turn_id = str(uuid.uuid4())
# Admin user can set inference iterations
mock_get_auth_attributes.return_value = {"roles": ["admin"]}
await agent_persistence.set_num_infer_iters_in_turn(session_id, turn_id, 5)
# Admin user can get inference iterations
infer_iters = await agent_persistence.get_num_infer_iters_in_turn(session_id, turn_id)
assert infer_iters == 5
# Regular user cannot get inference iterations
mock_get_auth_attributes.return_value = {"roles": ["user"]}
infer_iters = await agent_persistence.get_num_infer_iters_in_turn(session_id, turn_id)
assert infer_iters is None
# Regular user cannot set inference iterations (should raise ValueError)
with pytest.raises(ValueError):
await agent_persistence.set_num_infer_iters_in_turn(session_id, turn_id, 10)


@@ -187,8 +187,8 @@ def test_chat_completion_doesnt_block_event_loop(caplog):
loop.set_debug(True)
caplog.set_level(logging.WARNING)
# Log when event loop is blocked for more than 100ms
loop.slow_callback_duration = 0.1
# Log when event loop is blocked for more than 500ms
loop.slow_callback_duration = 0.5
# Sleep for 500ms in our delayed http response
sleep_time = 0.5


@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.


@@ -0,0 +1,46 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
mock_session = MagicMock()
mock_session.closed = False
mock_session.close = AsyncMock()
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
mock_session.__aexit__ = AsyncMock()
@pytest.fixture(scope="session", autouse=True)
def patch_aiohttp_session():
with patch("aiohttp.ClientSession", return_value=mock_session):
yield
@pytest.fixture
def event_loop():
"""Create and provide a new event loop for each test."""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
yield loop
loop.close()
@pytest.fixture
def run_async():
"""Fixture to run async functions in tests."""
def _run_async(coro):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
return loop.run_until_complete(coro)
finally:
loop.close()
return _run_async
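For illustration, a minimal sketch of how a test might consume the run_async fixture above (not part of this commit); the test name and the _double coroutine are hypothetical and exist only to show the fixture driving a coroutine to completion.

import asyncio


# Hypothetical coroutine used only for this illustration.
async def _double(x: int) -> int:
    await asyncio.sleep(0)  # stand-in for real async work
    return x * 2


def test_run_async_fixture(run_async):
    # run_async runs the coroutine on a fresh event loop and returns its result,
    # so a synchronous test body can exercise async code.
    assert run_async(_double(21)) == 42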


@@ -0,0 +1,272 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import unittest
import warnings
from unittest.mock import patch
import pytest
from llama_stack_client.types.algorithm_config_param import LoraFinetuningConfig
from llama_stack_client.types.post_training_supervised_fine_tune_params import (
TrainingConfig,
TrainingConfigDataConfig,
TrainingConfigEfficiencyConfig,
TrainingConfigOptimizerConfig,
)
from llama_stack.providers.remote.post_training.nvidia.post_training import (
NvidiaPostTrainingAdapter,
NvidiaPostTrainingConfig,
)
class TestNvidiaParameters(unittest.TestCase):
def setUp(self):
os.environ["NVIDIA_BASE_URL"] = "http://nemo.test"
os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test"
config = NvidiaPostTrainingConfig(
base_url=os.environ["NVIDIA_BASE_URL"], customizer_url=os.environ["NVIDIA_CUSTOMIZER_URL"], api_key=None
)
self.adapter = NvidiaPostTrainingAdapter(config)
self.make_request_patcher = patch(
"llama_stack.providers.remote.post_training.nvidia.post_training.NvidiaPostTrainingAdapter._make_request"
)
self.mock_make_request = self.make_request_patcher.start()
self.mock_make_request.return_value = {
"id": "job-123",
"status": "created",
"created_at": "2025-03-04T13:07:47.543605",
"updated_at": "2025-03-04T13:07:47.543605",
}
def tearDown(self):
self.make_request_patcher.stop()
def _assert_request_params(self, expected_json):
"""Helper method to verify parameters in the request JSON."""
call_args = self.mock_make_request.call_args
actual_json = call_args[1]["json"]
for key, value in expected_json.items():
if isinstance(value, dict):
for nested_key, nested_value in value.items():
assert actual_json[key][nested_key] == nested_value
else:
assert actual_json[key] == value
@pytest.fixture(autouse=True)
def inject_fixtures(self, run_async):
self.run_async = run_async
def test_customizer_parameters_passed(self):
"""Test scenario 1: When an optional parameter is passed and value is correctly set."""
custom_adapter_dim = 32 # Different from default of 8
algorithm_config = LoraFinetuningConfig(
type="LoRA",
adapter_dim=custom_adapter_dim,
adapter_dropout=0.2,
apply_lora_to_mlp=True,
apply_lora_to_output=True,
alpha=16,
rank=16,
lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)
data_config = TrainingConfigDataConfig(dataset_id="test-dataset", batch_size=16)
optimizer_config = TrainingConfigOptimizerConfig(lr=0.0002)
training_config = TrainingConfig(
n_epochs=3,
data_config=data_config,
optimizer_config=optimizer_config,
)
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
self.run_async(
self.adapter.supervised_fine_tune(
job_uuid="test-job",
model="meta-llama/Llama-3.1-8B-Instruct",
checkpoint_dir="",
algorithm_config=algorithm_config,
training_config=training_config,
logger_config={},
hyperparam_search_config={},
)
)
warning_texts = [str(warning.message) for warning in w]
fields = [
"apply_lora_to_output",
"lora_attn_modules",
"apply_lora_to_mlp",
]
for field in fields:
assert any(field in text for text in warning_texts)
self._assert_request_params(
{
"hyperparameters": {
"lora": {"adapter_dim": custom_adapter_dim, "adapter_dropout": 0.2, "alpha": 16},
"epochs": 3,
"learning_rate": 0.0002,
"batch_size": 16,
}
}
)
def test_required_parameters_passed(self):
"""Test scenario 2: When required parameters are passed."""
required_model = "meta-llama/Llama-3.1-8B-Instruct"
required_dataset_id = "required-dataset"
required_job_uuid = "required-job"
algorithm_config = LoraFinetuningConfig(
type="LoRA",
adapter_dim=16,
adapter_dropout=0.1,
apply_lora_to_mlp=True,
apply_lora_to_output=True,
alpha=16,
rank=16,
lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)
data_config = TrainingConfigDataConfig(
dataset_id=required_dataset_id, # Required parameter
batch_size=8,
)
optimizer_config = TrainingConfigOptimizerConfig(lr=0.0001)
training_config = TrainingConfig(
n_epochs=1,
data_config=data_config,
optimizer_config=optimizer_config,
)
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
self.run_async(
self.adapter.supervised_fine_tune(
job_uuid=required_job_uuid, # Required parameter
model=required_model, # Required parameter
checkpoint_dir="",
algorithm_config=algorithm_config,
training_config=training_config,
logger_config={},
hyperparam_search_config={},
)
)
warning_texts = [str(warning.message) for warning in w]
fields = [
"rank",
"apply_lora_to_output",
"lora_attn_modules",
"apply_lora_to_mlp",
]
for field in fields:
assert any(field in text for text in warning_texts)
self.mock_make_request.assert_called_once()
call_args = self.mock_make_request.call_args
assert call_args[1]["json"]["config"] == "meta/llama-3.1-8b-instruct"
assert call_args[1]["json"]["dataset"]["name"] == required_dataset_id
def test_unsupported_parameters_warning(self):
"""Test that warnings are raised for unsupported parameters."""
data_config = TrainingConfigDataConfig(
dataset_id="test-dataset",
batch_size=8,
# Unsupported parameters
shuffle=True,
data_format="instruct",
validation_dataset_id="val-dataset",
)
optimizer_config = TrainingConfigOptimizerConfig(
lr=0.0001,
weight_decay=0.01,
# Unsupported parameters
optimizer_type="adam",
num_warmup_steps=100,
)
efficiency_config = TrainingConfigEfficiencyConfig(
enable_activation_checkpointing=True # Unsupported parameter
)
training_config = TrainingConfig(
n_epochs=1,
data_config=data_config,
optimizer_config=optimizer_config,
# Unsupported parameters
efficiency_config=efficiency_config,
max_steps_per_epoch=1000,
gradient_accumulation_steps=4,
max_validation_steps=100,
dtype="bf16",
)
# Capture warnings
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
self.run_async(
self.adapter.supervised_fine_tune(
job_uuid="test-job",
model="meta-llama/Llama-3.1-8B-Instruct",
checkpoint_dir="test-dir", # Unsupported parameter
algorithm_config=LoraFinetuningConfig(
type="LoRA",
adapter_dim=16,
adapter_dropout=0.1,
apply_lora_to_mlp=True,
apply_lora_to_output=True,
alpha=16,
rank=16,
lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
),
training_config=training_config,
logger_config={"test": "value"}, # Unsupported parameter
hyperparam_search_config={"test": "value"}, # Unsupported parameter
)
)
assert len(w) >= 4
warning_texts = [str(warning.message) for warning in w]
fields = [
"checkpoint_dir",
"hyperparam_search_config",
"logger_config",
"TrainingConfig",
"DataConfig",
"OptimizerConfig",
"max_steps_per_epoch",
"gradient_accumulation_steps",
"max_validation_steps",
"dtype",
# required unsupported parameters
"rank",
"apply_lora_to_output",
"lora_attn_modules",
"apply_lora_to_mlp",
]
for field in fields:
assert any(field in text for text in warning_texts)
if __name__ == "__main__":
unittest.main()


@@ -0,0 +1,295 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import unittest
import warnings
from unittest.mock import patch
import pytest
from llama_stack_client.types.algorithm_config_param import LoraFinetuningConfig, QatFinetuningConfig
from llama_stack_client.types.post_training_supervised_fine_tune_params import (
TrainingConfig,
TrainingConfigDataConfig,
TrainingConfigOptimizerConfig,
)
from llama_stack.providers.remote.post_training.nvidia.post_training import (
ListNvidiaPostTrainingJobs,
NvidiaPostTrainingAdapter,
NvidiaPostTrainingConfig,
NvidiaPostTrainingJob,
NvidiaPostTrainingJobStatusResponse,
)
class TestNvidiaPostTraining(unittest.TestCase):
def setUp(self):
os.environ["NVIDIA_BASE_URL"] = "http://nemo.test" # needed for llm inference
os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test" # needed for nemo customizer
config = NvidiaPostTrainingConfig(
base_url=os.environ["NVIDIA_BASE_URL"], customizer_url=os.environ["NVIDIA_CUSTOMIZER_URL"], api_key=None
)
self.adapter = NvidiaPostTrainingAdapter(config)
self.make_request_patcher = patch(
"llama_stack.providers.remote.post_training.nvidia.post_training.NvidiaPostTrainingAdapter._make_request"
)
self.mock_make_request = self.make_request_patcher.start()
def tearDown(self):
self.make_request_patcher.stop()
@pytest.fixture(autouse=True)
def inject_fixtures(self, run_async):
self.run_async = run_async
def _assert_request(self, mock_call, expected_method, expected_path, expected_params=None, expected_json=None):
"""Helper method to verify request details in mock calls."""
call_args = mock_call.call_args
if expected_method and expected_path:
if isinstance(call_args[0], tuple) and len(call_args[0]) == 2:
assert call_args[0] == (expected_method, expected_path)
else:
assert call_args[1]["method"] == expected_method
assert call_args[1]["path"] == expected_path
if expected_params:
assert call_args[1]["params"] == expected_params
if expected_json:
for key, value in expected_json.items():
assert call_args[1]["json"][key] == value
def test_supervised_fine_tune(self):
"""Test the supervised fine-tuning API call."""
self.mock_make_request.return_value = {
"id": "cust-JGTaMbJMdqjJU8WbQdN9Q2",
"created_at": "2024-12-09T04:06:28.542884",
"updated_at": "2024-12-09T04:06:28.542884",
"config": {
"schema_version": "1.0",
"id": "af783f5b-d985-4e5b-bbb7-f9eec39cc0b1",
"created_at": "2024-12-09T04:06:28.542657",
"updated_at": "2024-12-09T04:06:28.569837",
"custom_fields": {},
"name": "meta-llama/Llama-3.1-8B-Instruct",
"base_model": "meta-llama/Llama-3.1-8B-Instruct",
"model_path": "llama-3_1-8b-instruct",
"training_types": [],
"finetuning_types": ["lora"],
"precision": "bf16",
"num_gpus": 4,
"num_nodes": 1,
"micro_batch_size": 1,
"tensor_parallel_size": 1,
"max_seq_length": 4096,
},
"dataset": {
"schema_version": "1.0",
"id": "dataset-XU4pvGzr5tvawnbVxeJMTb",
"created_at": "2024-12-09T04:06:28.542657",
"updated_at": "2024-12-09T04:06:28.542660",
"custom_fields": {},
"name": "sample-basic-test",
"version_id": "main",
"version_tags": [],
},
"hyperparameters": {
"finetuning_type": "lora",
"training_type": "sft",
"batch_size": 16,
"epochs": 2,
"learning_rate": 0.0001,
"lora": {"adapter_dim": 16, "adapter_dropout": 0.1},
},
"output_model": "default/job-1234",
"status": "created",
"project": "default",
"custom_fields": {},
"ownership": {"created_by": "me", "access_policies": {}},
}
algorithm_config = LoraFinetuningConfig(
type="LoRA",
adapter_dim=16,
adapter_dropout=0.1,
apply_lora_to_mlp=True,
apply_lora_to_output=True,
alpha=16,
rank=16,
lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)
data_config = TrainingConfigDataConfig(dataset_id="sample-basic-test", batch_size=16)
optimizer_config = TrainingConfigOptimizerConfig(
lr=0.0001,
)
training_config = TrainingConfig(
n_epochs=2,
data_config=data_config,
optimizer_config=optimizer_config,
)
with warnings.catch_warnings(record=True):
warnings.simplefilter("always")
training_job = self.run_async(
self.adapter.supervised_fine_tune(
job_uuid="1234",
model="meta-llama/Llama-3.1-8B-Instruct",
checkpoint_dir="",
algorithm_config=algorithm_config,
training_config=training_config,
logger_config={},
hyperparam_search_config={},
)
)
# check the output is a PostTrainingJob
assert isinstance(training_job, NvidiaPostTrainingJob)
assert training_job.job_uuid == "cust-JGTaMbJMdqjJU8WbQdN9Q2"
self.mock_make_request.assert_called_once()
self._assert_request(
self.mock_make_request,
"POST",
"/v1/customization/jobs",
expected_json={
"config": "meta/llama-3.1-8b-instruct",
"dataset": {"name": "sample-basic-test", "namespace": "default"},
"hyperparameters": {
"training_type": "sft",
"finetuning_type": "lora",
"epochs": 2,
"batch_size": 16,
"learning_rate": 0.0001,
"lora": {"alpha": 16, "adapter_dim": 16, "adapter_dropout": 0.1},
},
},
)
def test_supervised_fine_tune_with_qat(self):
algorithm_config = QatFinetuningConfig(type="QAT", quantizer_name="quantizer_name", group_size=1)
data_config = TrainingConfigDataConfig(dataset_id="sample-basic-test", batch_size=16)
optimizer_config = TrainingConfigOptimizerConfig(
lr=0.0001,
)
training_config = TrainingConfig(
n_epochs=2,
data_config=data_config,
optimizer_config=optimizer_config,
)
# This will raise NotImplementedError since QAT is not supported
with self.assertRaises(NotImplementedError):
self.run_async(
self.adapter.supervised_fine_tune(
job_uuid="1234",
model="meta-llama/Llama-3.1-8B-Instruct",
checkpoint_dir="",
algorithm_config=algorithm_config,
training_config=training_config,
logger_config={},
hyperparam_search_config={},
)
)
def test_get_training_job_status(self):
self.mock_make_request.return_value = {
"created_at": "2024-12-09T04:06:28.580220",
"updated_at": "2024-12-09T04:21:19.852832",
"status": "completed",
"steps_completed": 1210,
"epochs_completed": 2,
"percentage_done": 100.0,
"best_epoch": 2,
"train_loss": 1.718016266822815,
"val_loss": 1.8661999702453613,
}
job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
status = self.run_async(self.adapter.get_training_job_status(job_uuid=job_id))
assert isinstance(status, NvidiaPostTrainingJobStatusResponse)
assert status.status.value == "completed"
assert status.steps_completed == 1210
assert status.epochs_completed == 2
assert status.percentage_done == 100.0
assert status.best_epoch == 2
assert status.train_loss == 1.718016266822815
assert status.val_loss == 1.8661999702453613
self.mock_make_request.assert_called_once()
self._assert_request(
self.mock_make_request, "GET", f"/v1/customization/jobs/{job_id}/status", expected_params={"job_id": job_id}
)
def test_get_training_jobs(self):
job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
self.mock_make_request.return_value = {
"data": [
{
"id": job_id,
"created_at": "2024-12-09T04:06:28.542884",
"updated_at": "2024-12-09T04:21:19.852832",
"config": {
"name": "meta-llama/Llama-3.1-8B-Instruct",
"base_model": "meta-llama/Llama-3.1-8B-Instruct",
},
"dataset": {"name": "default/sample-basic-test"},
"hyperparameters": {
"finetuning_type": "lora",
"training_type": "sft",
"batch_size": 16,
"epochs": 2,
"learning_rate": 0.0001,
"lora": {"adapter_dim": 16, "adapter_dropout": 0.1},
},
"output_model": "default/job-1234",
"status": "completed",
"project": "default",
}
]
}
jobs = self.run_async(self.adapter.get_training_jobs())
assert isinstance(jobs, ListNvidiaPostTrainingJobs)
assert len(jobs.data) == 1
job = jobs.data[0]
assert job.job_uuid == job_id
assert job.status.value == "completed"
self.mock_make_request.assert_called_once()
self._assert_request(
self.mock_make_request,
"GET",
"/v1/customization/jobs",
expected_params={"page": 1, "page_size": 10, "sort": "created_at"},
)
def test_cancel_training_job(self):
self.mock_make_request.return_value = {} # Empty response for successful cancellation
job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
result = self.run_async(self.adapter.cancel_training_job(job_uuid=job_id))
assert result is None
self.mock_make_request.assert_called_once()
self._assert_request(
self.mock_make_request,
"POST",
f"/v1/customization/jobs/{job_id}/cancel",
expected_params={"job_id": job_id},
)
if __name__ == "__main__":
unittest.main()


@@ -0,0 +1,43 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from pydantic import BaseModel
from llama_stack.distribution.distribution import get_provider_registry, providable_apis
from llama_stack.distribution.utils.dynamic import instantiate_class_type
class TestProviderConfigurations:
"""Test suite for testing provider configurations across all API types."""
@pytest.mark.parametrize("api", providable_apis())
def test_api_providers(self, api):
provider_registry = get_provider_registry()
providers = provider_registry.get(api, {})
failures = []
for provider_type, provider_spec in providers.items():
try:
self._verify_provider_config(provider_type, provider_spec)
except Exception as e:
failures.append(f"Failed to verify {provider_type} config: {str(e)}")
if failures:
pytest.fail("\n".join(failures))
def _verify_provider_config(self, provider_type, provider_spec):
"""Helper method to verify a single provider configuration."""
# Get the config class
config_class_name = provider_spec.config_class
config_type = instantiate_class_type(config_class_name)
assert issubclass(config_type, BaseModel), f"{config_class_name} is not a subclass of BaseModel"
assert hasattr(config_type, "sample_run_config"), f"{config_class_name} does not have sample_run_config method"
sample_config = config_type.sample_run_config(__distro_dir__="foobarbaz")
assert isinstance(sample_config, dict), f"{config_class_name}.sample_run_config() did not return a dict"


@@ -0,0 +1,42 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import random
import numpy as np
import pytest
from llama_stack.apis.vector_io import Chunk
EMBEDDING_DIMENSION = 384
@pytest.fixture
def vector_db_id() -> str:
return f"test-vector-db-{random.randint(1, 100)}"
@pytest.fixture(scope="session")
def embedding_dimension() -> int:
return EMBEDDING_DIMENSION
@pytest.fixture(scope="session")
def sample_chunks():
"""Generates chunks that force multiple batches for a single document to expose ID conflicts."""
n, k = 10, 3
sample = [
Chunk(content=f"Sentence {i} from document {j}", metadata={"document_id": f"document-{j}"})
for j in range(k)
for i in range(n)
]
return sample
@pytest.fixture(scope="session")
def sample_embeddings(sample_chunks):
np.random.seed(42)
return np.array([np.random.rand(EMBEDDING_DIMENSION).astype(np.float32) for _ in sample_chunks])
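For illustration, a short sketch (not part of this commit) of what the sample_chunks fixture above yields: with n=10 and k=3 it produces 30 chunks grouped by document, which is what lets a small batch size split a single document across batches; the test name below is hypothetical.

def test_sample_chunks_shape(sample_chunks):
    # 10 sentences for each of 3 documents -> 30 chunks, ordered document by document.
    assert len(sample_chunks) == 30
    assert sample_chunks[0].metadata["document_id"] == "document-0"
    assert sample_chunks[10].metadata["document_id"] == "document-1"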


@@ -0,0 +1,135 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import os
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import pytest_asyncio
from llama_stack.apis.inference import EmbeddingsResponse, Inference
from llama_stack.apis.vector_io import (
QueryChunksResponse,
VectorDB,
VectorDBStore,
)
from llama_stack.providers.inline.vector_io.qdrant.config import (
QdrantVectorIOConfig as InlineQdrantVectorIOConfig,
)
from llama_stack.providers.remote.vector_io.qdrant.qdrant import (
QdrantVectorIOAdapter,
)
# This test is a unit test for the QdrantVectorIOAdapter class. This should only contain
# tests which are specific to this class. More general (API-level) tests should be placed in
# tests/integration/vector_io/
#
# How to run this test:
#
# pytest tests/unit/providers/vector_io/test_qdrant.py \
# -v -s --tb=short --disable-warnings --asyncio-mode=auto
@pytest.fixture
def qdrant_config(tmp_path) -> InlineQdrantVectorIOConfig:
return InlineQdrantVectorIOConfig(path=os.path.join(tmp_path, "qdrant.db"))
@pytest.fixture(scope="session")
def loop():
return asyncio.new_event_loop()
@pytest.fixture
def mock_vector_db(vector_db_id) -> MagicMock:
mock_vector_db = MagicMock(spec=VectorDB)
mock_vector_db.embedding_model = "embedding_model"
mock_vector_db.identifier = vector_db_id
return mock_vector_db
@pytest.fixture
def mock_vector_db_store(mock_vector_db) -> MagicMock:
mock_store = MagicMock(spec=VectorDBStore)
mock_store.get_vector_db = AsyncMock(return_value=mock_vector_db)
return mock_store
@pytest.fixture
def mock_api_service(sample_embeddings):
mock_api_service = MagicMock(spec=Inference)
mock_api_service.embeddings = AsyncMock(return_value=EmbeddingsResponse(embeddings=sample_embeddings))
return mock_api_service
@pytest_asyncio.fixture
async def qdrant_adapter(qdrant_config, mock_vector_db_store, mock_api_service, loop) -> QdrantVectorIOAdapter:
adapter = QdrantVectorIOAdapter(config=qdrant_config, inference_api=mock_api_service)
adapter.vector_db_store = mock_vector_db_store
await adapter.initialize()
yield adapter
await adapter.shutdown()
__QUERY = "Sample query"
@pytest.mark.asyncio
@pytest.mark.parametrize("max_query_chunks, expected_chunks", [(2, 2), (100, 30)])
async def test_qdrant_adapter_returns_expected_chunks(
qdrant_adapter: QdrantVectorIOAdapter,
vector_db_id,
sample_chunks,
sample_embeddings,
max_query_chunks,
expected_chunks,
) -> None:
assert qdrant_adapter is not None
await qdrant_adapter.insert_chunks(vector_db_id, sample_chunks)
index = await qdrant_adapter._get_and_cache_vector_db_index(vector_db_id=vector_db_id)
assert index is not None
response = await qdrant_adapter.query_chunks(
query=__QUERY,
vector_db_id=vector_db_id,
params={"max_chunks": max_query_chunks},
)
assert isinstance(response, QueryChunksResponse)
assert len(response.chunks) == expected_chunks
# To bypass the attempt to convert a Mock to JSON
def _prepare_for_json(value: Any) -> str:
return str(value)
@patch("llama_stack.providers.utils.telemetry.trace_protocol._prepare_for_json", new=_prepare_for_json)
@pytest.mark.asyncio
async def test_qdrant_register_and_unregister_vector_db(
qdrant_adapter: QdrantVectorIOAdapter,
mock_vector_db,
sample_chunks,
) -> None:
# Initially, no collections
vector_db_id = mock_vector_db.identifier
assert len((await qdrant_adapter.client.get_collections()).collections) == 0
# Register does not create a collection
assert not (await qdrant_adapter.client.collection_exists(vector_db_id))
await qdrant_adapter.register_vector_db(mock_vector_db)
assert not (await qdrant_adapter.client.collection_exists(vector_db_id))
# First insert creates the collection
await qdrant_adapter.insert_chunks(vector_db_id, sample_chunks)
assert await qdrant_adapter.client.collection_exists(vector_db_id)
# Unregister deletes the collection
await qdrant_adapter.unregister_vector_db(vector_db_id)
assert not (await qdrant_adapter.client.collection_exists(vector_db_id))
assert len((await qdrant_adapter.client.get_collections()).collections) == 0


@@ -5,17 +5,16 @@
# the root directory of this source tree.
import asyncio
import sqlite3
import numpy as np
import pytest
import pytest_asyncio
import sqlite_vec
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import (
SQLiteVecIndex,
SQLiteVecVectorIOAdapter,
_create_sqlite_connection,
generate_chunk_id,
)
@@ -29,8 +28,6 @@ from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import (
# -v -s --tb=short --disable-warnings --asyncio-mode=auto
SQLITE_VEC_PROVIDER = "sqlite_vec"
EMBEDDING_DIMENSION = 384
EMBEDDING_MODEL = "all-MiniLM-L6-v2"
@pytest.fixture(scope="session")
@@ -38,74 +35,53 @@ def loop():
return asyncio.new_event_loop()
@pytest.fixture(scope="session", autouse=True)
def sqlite_connection(loop):
conn = sqlite3.connect(":memory:")
try:
conn.enable_load_extension(True)
sqlite_vec.load(conn)
yield conn
finally:
conn.close()
@pytest_asyncio.fixture(scope="session", autouse=True)
async def sqlite_vec_index(sqlite_connection):
return await SQLiteVecIndex.create(dimension=EMBEDDING_DIMENSION, connection=sqlite_connection, bank_id="test_bank")
@pytest.fixture(scope="session")
def sample_chunks():
"""Generates chunks that force multiple batches for a single document to expose ID conflicts."""
n, k = 10, 3
sample = [
Chunk(content=f"Sentence {i} from document {j}", metadata={"document_id": f"document-{j}"})
for j in range(k)
for i in range(n)
]
return sample
@pytest.fixture(scope="session")
def sample_embeddings(sample_chunks):
np.random.seed(42)
return np.array([np.random.rand(EMBEDDING_DIMENSION).astype(np.float32) for _ in sample_chunks])
async def sqlite_vec_index(embedding_dimension, tmp_path_factory):
temp_dir = tmp_path_factory.getbasetemp()
db_path = str(temp_dir / "test_sqlite.db")
index = await SQLiteVecIndex.create(dimension=embedding_dimension, db_path=db_path, bank_id="test_bank")
yield index
await index.delete()
@pytest.mark.asyncio
async def test_add_chunks(sqlite_vec_index, sample_chunks, sample_embeddings):
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings, batch_size=2)
cur = sqlite_vec_index.connection.cursor()
connection = _create_sqlite_connection(sqlite_vec_index.db_path)
cur = connection.cursor()
cur.execute(f"SELECT COUNT(*) FROM {sqlite_vec_index.metadata_table}")
count = cur.fetchone()[0]
assert count == len(sample_chunks)
cur.close()
connection.close()
@pytest.mark.asyncio
async def test_query_chunks(sqlite_vec_index, sample_chunks, sample_embeddings):
async def test_query_chunks(sqlite_vec_index, sample_chunks, sample_embeddings, embedding_dimension):
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
query_embedding = np.random.rand(EMBEDDING_DIMENSION).astype(np.float32)
query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
response = await sqlite_vec_index.query(query_embedding, k=2, score_threshold=0.0)
assert isinstance(response, QueryChunksResponse)
assert len(response.chunks) == 2
@pytest.mark.asyncio
async def test_chunk_id_conflict(sqlite_vec_index, sample_chunks):
async def test_chunk_id_conflict(sqlite_vec_index, sample_chunks, embedding_dimension):
"""Test that chunk IDs do not conflict across batches when inserting chunks."""
# Reduce batch size to force multiple batches for same document
# since there are 10 chunks per document and batch size is 2
batch_size = 2
sample_embeddings = np.random.rand(len(sample_chunks), EMBEDDING_DIMENSION).astype(np.float32)
sample_embeddings = np.random.rand(len(sample_chunks), embedding_dimension).astype(np.float32)
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings, batch_size=batch_size)
cur = sqlite_vec_index.connection.cursor()
connection = _create_sqlite_connection(sqlite_vec_index.db_path)
cur = connection.cursor()
# Retrieve all chunk IDs to check for duplicates
cur.execute(f"SELECT id FROM {sqlite_vec_index.metadata_table}")
chunk_ids = [row[0] for row in cur.fetchall()]
cur.close()
connection.close()
# Ensure all chunk IDs are unique
assert len(chunk_ids) == len(set(chunk_ids)), "Duplicate chunk IDs detected across batches!"


@@ -12,6 +12,7 @@ import pytest_asyncio
from llama_stack.apis.inference import Model
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.distribution.store.registry import (
KEY_FORMAT,
CachedDiskDistributionRegistry,
DiskDistributionRegistry,
)
@@ -197,3 +198,72 @@ async def test_get_all_objects(config):
assert stored_vector_db.embedding_model == original_vector_db.embedding_model
assert stored_vector_db.provider_id == original_vector_db.provider_id
assert stored_vector_db.embedding_dimension == original_vector_db.embedding_dimension
@pytest.mark.asyncio
async def test_parse_registry_values_error_handling(config):
kvstore = await kvstore_impl(config)
valid_db = VectorDB(
identifier="valid_vector_db",
embedding_model="all-MiniLM-L6-v2",
embedding_dimension=384,
provider_resource_id="valid_vector_db",
provider_id="test-provider",
)
await kvstore.set(KEY_FORMAT.format(type="vector_db", identifier="valid_vector_db"), valid_db.model_dump_json())
await kvstore.set(KEY_FORMAT.format(type="vector_db", identifier="corrupted_json"), "{not valid json")
await kvstore.set(
KEY_FORMAT.format(type="vector_db", identifier="missing_fields"),
'{"type": "vector_db", "identifier": "missing_fields"}',
)
test_registry = DiskDistributionRegistry(kvstore)
await test_registry.initialize()
# Get all objects, which should only return the valid one
all_objects = await test_registry.get_all()
# Should have filtered out the invalid entries
assert len(all_objects) == 1
assert all_objects[0].identifier == "valid_vector_db"
# Check that the get method also handles errors correctly
invalid_obj = await test_registry.get("vector_db", "corrupted_json")
assert invalid_obj is None
invalid_obj = await test_registry.get("vector_db", "missing_fields")
assert invalid_obj is None
@pytest.mark.asyncio
async def test_cached_registry_error_handling(config):
kvstore = await kvstore_impl(config)
valid_db = VectorDB(
identifier="valid_cached_db",
embedding_model="all-MiniLM-L6-v2",
embedding_dimension=384,
provider_resource_id="valid_cached_db",
provider_id="test-provider",
)
await kvstore.set(KEY_FORMAT.format(type="vector_db", identifier="valid_cached_db"), valid_db.model_dump_json())
await kvstore.set(
KEY_FORMAT.format(type="vector_db", identifier="invalid_cached_db"),
'{"type": "vector_db", "identifier": "invalid_cached_db", "embedding_model": 12345}', # Should be string
)
cached_registry = CachedDiskDistributionRegistry(kvstore)
await cached_registry.initialize()
all_objects = await cached_registry.get_all()
assert len(all_objects) == 1
assert all_objects[0].identifier == "valid_cached_db"
invalid_obj = await cached_registry.get("vector_db", "invalid_cached_db")
assert invalid_obj is None


@@ -0,0 +1,151 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import shutil
import tempfile
import pytest
from llama_stack.apis.models import ModelType
from llama_stack.distribution.datatypes import ModelWithACL
from llama_stack.distribution.server.auth import AccessAttributes
from llama_stack.distribution.store.registry import CachedDiskDistributionRegistry
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from llama_stack.providers.utils.kvstore.sqlite import SqliteKVStoreImpl
@pytest.fixture(scope="function")
async def kvstore():
temp_dir = tempfile.mkdtemp()
db_path = os.path.join(temp_dir, "test_registry_acl.db")
kvstore_config = SqliteKVStoreConfig(db_path=db_path)
kvstore = SqliteKVStoreImpl(kvstore_config)
await kvstore.initialize()
yield kvstore
shutil.rmtree(temp_dir)
@pytest.fixture(scope="function")
async def registry(kvstore):
registry = CachedDiskDistributionRegistry(kvstore)
await registry.initialize()
return registry
@pytest.mark.asyncio
async def test_registry_cache_with_acl(registry):
model = ModelWithACL(
identifier="model-acl",
provider_id="test-provider",
provider_resource_id="model-acl-resource",
model_type=ModelType.llm,
access_attributes=AccessAttributes(roles=["admin"], teams=["ai-team"]),
)
success = await registry.register(model)
assert success
cached_model = registry.get_cached("model", "model-acl")
assert cached_model is not None
assert cached_model.identifier == "model-acl"
assert cached_model.access_attributes.roles == ["admin"]
assert cached_model.access_attributes.teams == ["ai-team"]
fetched_model = await registry.get("model", "model-acl")
assert fetched_model is not None
assert fetched_model.identifier == "model-acl"
assert fetched_model.access_attributes.roles == ["admin"]
model.access_attributes = AccessAttributes(roles=["admin", "user"], projects=["project-x"])
await registry.update(model)
updated_cached = registry.get_cached("model", "model-acl")
assert updated_cached is not None
assert updated_cached.access_attributes.roles == ["admin", "user"]
assert updated_cached.access_attributes.projects == ["project-x"]
assert updated_cached.access_attributes.teams is None
new_registry = CachedDiskDistributionRegistry(registry.kvstore)
await new_registry.initialize()
new_model = await new_registry.get("model", "model-acl")
assert new_model is not None
assert new_model.identifier == "model-acl"
assert new_model.access_attributes.roles == ["admin", "user"]
assert new_model.access_attributes.projects == ["project-x"]
assert new_model.access_attributes.teams is None
@pytest.mark.asyncio
async def test_registry_empty_acl(registry):
model = ModelWithACL(
identifier="model-empty-acl",
provider_id="test-provider",
provider_resource_id="model-resource",
model_type=ModelType.llm,
access_attributes=AccessAttributes(),
)
await registry.register(model)
cached_model = registry.get_cached("model", "model-empty-acl")
assert cached_model is not None
assert cached_model.access_attributes is not None
assert cached_model.access_attributes.roles is None
assert cached_model.access_attributes.teams is None
assert cached_model.access_attributes.projects is None
assert cached_model.access_attributes.namespaces is None
all_models = await registry.get_all()
assert len(all_models) == 1
model = ModelWithACL(
identifier="model-no-acl",
provider_id="test-provider",
provider_resource_id="model-resource-2",
model_type=ModelType.llm,
)
await registry.register(model)
cached_model = registry.get_cached("model", "model-no-acl")
assert cached_model is not None
assert cached_model.access_attributes is None
all_models = await registry.get_all()
assert len(all_models) == 2
@pytest.mark.asyncio
async def test_registry_serialization(registry):
attributes = AccessAttributes(
roles=["admin", "researcher"],
teams=["ai-team", "ml-team"],
projects=["project-a", "project-b"],
namespaces=["prod", "staging"],
)
model = ModelWithACL(
identifier="model-serialize",
provider_id="test-provider",
provider_resource_id="model-resource",
model_type=ModelType.llm,
access_attributes=attributes,
)
await registry.register(model)
new_registry = CachedDiskDistributionRegistry(registry.kvstore)
await new_registry.initialize()
loaded_model = await new_registry.get("model", "model-serialize")
assert loaded_model is not None
assert loaded_model.access_attributes.roles == ["admin", "researcher"]
assert loaded_model.access_attributes.teams == ["ai-team", "ml-team"]
assert loaded_model.access_attributes.projects == ["project-a", "project-b"]
assert loaded_model.access_attributes.namespaces == ["prod", "staging"]


@@ -0,0 +1,240 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import shutil
import tempfile
from unittest.mock import MagicMock, Mock, patch
import pytest
from llama_stack.apis.datatypes import Api
from llama_stack.apis.models import ModelType
from llama_stack.distribution.datatypes import AccessAttributes, ModelWithACL
from llama_stack.distribution.routers.routing_tables import ModelsRoutingTable
from llama_stack.distribution.store.registry import CachedDiskDistributionRegistry
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from llama_stack.providers.utils.kvstore.sqlite import SqliteKVStoreImpl
class AsyncMock(MagicMock):
async def __call__(self, *args, **kwargs):
return super(AsyncMock, self).__call__(*args, **kwargs)
def _return_model(model):
return model
@pytest.fixture
async def test_setup():
temp_dir = tempfile.mkdtemp()
db_path = os.path.join(temp_dir, "test_access_control.db")
kvstore_config = SqliteKVStoreConfig(db_path=db_path)
kvstore = SqliteKVStoreImpl(kvstore_config)
await kvstore.initialize()
registry = CachedDiskDistributionRegistry(kvstore)
await registry.initialize()
mock_inference = Mock()
mock_inference.__provider_spec__ = MagicMock()
mock_inference.__provider_spec__.api = Api.inference
mock_inference.register_model = AsyncMock(side_effect=_return_model)
routing_table = ModelsRoutingTable(
impls_by_provider_id={"test_provider": mock_inference},
dist_registry=registry,
)
yield registry, routing_table
shutil.rmtree(temp_dir)
@pytest.mark.asyncio
@patch("llama_stack.distribution.routers.routing_tables.get_auth_attributes")
async def test_access_control_with_cache(mock_get_auth_attributes, test_setup):
registry, routing_table = test_setup
model_public = ModelWithACL(
identifier="model-public",
provider_id="test_provider",
provider_resource_id="model-public",
model_type=ModelType.llm,
)
model_admin_only = ModelWithACL(
identifier="model-admin",
provider_id="test_provider",
provider_resource_id="model-admin",
model_type=ModelType.llm,
access_attributes=AccessAttributes(roles=["admin"]),
)
model_data_scientist = ModelWithACL(
identifier="model-data-scientist",
provider_id="test_provider",
provider_resource_id="model-data-scientist",
model_type=ModelType.llm,
access_attributes=AccessAttributes(roles=["data-scientist", "researcher"], teams=["ml-team"]),
)
await registry.register(model_public)
await registry.register(model_admin_only)
await registry.register(model_data_scientist)
mock_get_auth_attributes.return_value = {"roles": ["admin"], "teams": ["management"]}
all_models = await routing_table.list_models()
assert len(all_models.data) == 2
model = await routing_table.get_model("model-public")
assert model.identifier == "model-public"
model = await routing_table.get_model("model-admin")
assert model.identifier == "model-admin"
with pytest.raises(ValueError):
await routing_table.get_model("model-data-scientist")
mock_get_auth_attributes.return_value = {"roles": ["data-scientist"], "teams": ["other-team"]}
all_models = await routing_table.list_models()
assert len(all_models.data) == 1
assert all_models.data[0].identifier == "model-public"
model = await routing_table.get_model("model-public")
assert model.identifier == "model-public"
with pytest.raises(ValueError):
await routing_table.get_model("model-admin")
with pytest.raises(ValueError):
await routing_table.get_model("model-data-scientist")
mock_get_auth_attributes.return_value = {"roles": ["data-scientist"], "teams": ["ml-team"]}
all_models = await routing_table.list_models()
assert len(all_models.data) == 2
model_ids = [m.identifier for m in all_models.data]
assert "model-public" in model_ids
assert "model-data-scientist" in model_ids
assert "model-admin" not in model_ids
model = await routing_table.get_model("model-public")
assert model.identifier == "model-public"
model = await routing_table.get_model("model-data-scientist")
assert model.identifier == "model-data-scientist"
with pytest.raises(ValueError):
await routing_table.get_model("model-admin")
@pytest.mark.asyncio
@patch("llama_stack.distribution.routers.routing_tables.get_auth_attributes")
async def test_access_control_and_updates(mock_get_auth_attributes, test_setup):
registry, routing_table = test_setup
model_public = ModelWithACL(
identifier="model-updates",
provider_id="test_provider",
provider_resource_id="model-updates",
model_type=ModelType.llm,
)
await registry.register(model_public)
mock_get_auth_attributes.return_value = {
"roles": ["user"],
}
model = await routing_table.get_model("model-updates")
assert model.identifier == "model-updates"
model_public.access_attributes = AccessAttributes(roles=["admin"])
await registry.update(model_public)
mock_get_auth_attributes.return_value = {
"roles": ["user"],
}
with pytest.raises(ValueError):
await routing_table.get_model("model-updates")
mock_get_auth_attributes.return_value = {
"roles": ["admin"],
}
model = await routing_table.get_model("model-updates")
assert model.identifier == "model-updates"
@pytest.mark.asyncio
@patch("llama_stack.distribution.routers.routing_tables.get_auth_attributes")
async def test_access_control_empty_attributes(mock_get_auth_attributes, test_setup):
registry, routing_table = test_setup
model = ModelWithACL(
identifier="model-empty-attrs",
provider_id="test_provider",
provider_resource_id="model-empty-attrs",
model_type=ModelType.llm,
access_attributes=AccessAttributes(),
)
await registry.register(model)
mock_get_auth_attributes.return_value = {
"roles": [],
}
result = await routing_table.get_model("model-empty-attrs")
assert result.identifier == "model-empty-attrs"
all_models = await routing_table.list_models()
model_ids = [m.identifier for m in all_models.data]
assert "model-empty-attrs" in model_ids
@pytest.mark.asyncio
@patch("llama_stack.distribution.routers.routing_tables.get_auth_attributes")
async def test_no_user_attributes(mock_get_auth_attributes, test_setup):
registry, routing_table = test_setup
model_public = ModelWithACL(
identifier="model-public-2",
provider_id="test_provider",
provider_resource_id="model-public-2",
model_type=ModelType.llm,
)
model_restricted = ModelWithACL(
identifier="model-restricted",
provider_id="test_provider",
provider_resource_id="model-restricted",
model_type=ModelType.llm,
access_attributes=AccessAttributes(roles=["admin"]),
)
await registry.register(model_public)
await registry.register(model_restricted)
mock_get_auth_attributes.return_value = None
model = await routing_table.get_model("model-public-2")
assert model.identifier == "model-public-2"
with pytest.raises(ValueError):
await routing_table.get_model("model-restricted")
all_models = await routing_table.list_models()
assert len(all_models.data) == 1
assert all_models.data[0].identifier == "model-public-2"
@pytest.mark.asyncio
@patch("llama_stack.distribution.routers.routing_tables.get_auth_attributes")
async def test_automatic_access_attributes(mock_get_auth_attributes, test_setup):
"""Test that newly created resources inherit access attributes from their creator."""
registry, routing_table = test_setup
# Set creator's attributes
creator_attributes = {"roles": ["data-scientist"], "teams": ["ml-team"], "projects": ["llama-3"]}
mock_get_auth_attributes.return_value = creator_attributes
# Create model without explicit access attributes
model = ModelWithACL(
identifier="auto-access-model",
provider_id="test_provider",
provider_resource_id="auto-access-model",
model_type=ModelType.llm,
)
await routing_table.register_object(model)
# Verify the model got creator's attributes
registered_model = await routing_table.get_model("auto-access-model")
assert registered_model.access_attributes is not None
assert registered_model.access_attributes.roles == ["data-scientist"]
assert registered_model.access_attributes.teams == ["ml-team"]
assert registered_model.access_attributes.projects == ["llama-3"]
# Verify another user without matching attributes can't access it
mock_get_auth_attributes.return_value = {"roles": ["engineer"], "teams": ["infra-team"]}
with pytest.raises(ValueError):
await routing_table.get_model("auto-access-model")
# But a user with matching attributes can
mock_get_auth_attributes.return_value = {
"roles": ["data-scientist", "engineer"],
"teams": ["ml-team", "platform-team"],
"projects": ["llama-3"],
}
model = await routing_table.get_model("auto-access-model")
assert model.identifier == "auto-access-model"


@@ -0,0 +1,206 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from unittest.mock import AsyncMock, patch
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from llama_stack.distribution.server.auth import AuthenticationMiddleware
class MockResponse:
def __init__(self, status_code, json_data):
self.status_code = status_code
self._json_data = json_data
def json(self):
return self._json_data
@pytest.fixture
def mock_auth_endpoint():
return "http://mock-auth-service/validate"
@pytest.fixture
def valid_api_key():
return "valid_api_key_12345"
@pytest.fixture
def invalid_api_key():
return "invalid_api_key_67890"
@pytest.fixture
def app(mock_auth_endpoint):
app = FastAPI()
app.add_middleware(AuthenticationMiddleware, auth_endpoint=mock_auth_endpoint)
@app.get("/test")
def test_endpoint():
return {"message": "Authentication successful"}
return app
@pytest.fixture
def client(app):
return TestClient(app)
@pytest.fixture
def mock_scope():
return {
"type": "http",
"path": "/models/list",
"headers": [
(b"content-type", b"application/json"),
(b"authorization", b"Bearer test-api-key"),
(b"user-agent", b"test-user-agent"),
],
"query_string": b"limit=100&offset=0",
}
@pytest.fixture
def mock_middleware(mock_auth_endpoint):
mock_app = AsyncMock()
return AuthenticationMiddleware(mock_app, mock_auth_endpoint), mock_app
async def mock_post_success(*args, **kwargs):
return MockResponse(200, {"message": "Authentication successful"})
async def mock_post_failure(*args, **kwargs):
return MockResponse(401, {"message": "Authentication failed"})
async def mock_post_exception(*args, **kwargs):
raise Exception("Connection error")
def test_missing_auth_header(client):
response = client.get("/test")
assert response.status_code == 401
assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
def test_invalid_auth_header_format(client):
response = client.get("/test", headers={"Authorization": "InvalidFormat token123"})
assert response.status_code == 401
assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
@patch("httpx.AsyncClient.post", new=mock_post_success)
def test_valid_authentication(client, valid_api_key):
response = client.get("/test", headers={"Authorization": f"Bearer {valid_api_key}"})
assert response.status_code == 200
assert response.json() == {"message": "Authentication successful"}
@patch("httpx.AsyncClient.post", new=mock_post_failure)
def test_invalid_authentication(client, invalid_api_key):
response = client.get("/test", headers={"Authorization": f"Bearer {invalid_api_key}"})
assert response.status_code == 401
assert "Authentication failed" in response.json()["error"]["message"]
@patch("httpx.AsyncClient.post", new=mock_post_exception)
def test_auth_service_error(client, valid_api_key):
response = client.get("/test", headers={"Authorization": f"Bearer {valid_api_key}"})
assert response.status_code == 401
assert "Authentication service error" in response.json()["error"]["message"]
def test_auth_request_payload(client, valid_api_key, mock_auth_endpoint):
with patch("httpx.AsyncClient.post") as mock_post:
mock_response = MockResponse(200, {"message": "Authentication successful"})
mock_post.return_value = mock_response
client.get(
"/test?param1=value1&param2=value2",
headers={
"Authorization": f"Bearer {valid_api_key}",
"User-Agent": "TestClient",
"Content-Type": "application/json",
},
)
# Check that the auth endpoint was called with the correct payload
call_args = mock_post.call_args
assert call_args is not None
url, kwargs = call_args[0][0], call_args[1]
assert url == mock_auth_endpoint
payload = kwargs["json"]
assert payload["api_key"] == valid_api_key
assert payload["request"]["path"] == "/test"
assert "authorization" not in payload["request"]["headers"]
assert "param1" in payload["request"]["params"]
assert "param2" in payload["request"]["params"]
@pytest.mark.asyncio
async def test_auth_middleware_with_access_attributes(mock_middleware, mock_scope):
middleware, mock_app = mock_middleware
mock_receive = AsyncMock()
mock_send = AsyncMock()
with patch("httpx.AsyncClient") as mock_client:
mock_client_instance = AsyncMock()
mock_client.return_value.__aenter__.return_value = mock_client_instance
mock_client_instance.post.return_value = MockResponse(
200,
{
"access_attributes": {
"roles": ["admin", "user"],
"teams": ["ml-team"],
"projects": ["project-x", "project-y"],
}
},
)
await middleware(mock_scope, mock_receive, mock_send)
assert "user_attributes" in mock_scope
assert mock_scope["user_attributes"]["roles"] == ["admin", "user"]
assert mock_scope["user_attributes"]["teams"] == ["ml-team"]
assert mock_scope["user_attributes"]["projects"] == ["project-x", "project-y"]
mock_app.assert_called_once_with(mock_scope, mock_receive, mock_send)
@pytest.mark.asyncio
async def test_auth_middleware_no_attributes(mock_middleware, mock_scope):
"""Test middleware behavior with no access attributes"""
middleware, mock_app = mock_middleware
mock_receive = AsyncMock()
mock_send = AsyncMock()
with patch("httpx.AsyncClient") as mock_client:
mock_client_instance = AsyncMock()
mock_client.return_value.__aenter__.return_value = mock_client_instance
mock_client_instance.post.return_value = MockResponse(
200,
{
"message": "Authentication successful"
# No access_attributes
},
)
await middleware(mock_scope, mock_receive, mock_send)
assert "user_attributes" in mock_scope
attributes = mock_scope["user_attributes"]
assert "namespaces" in attributes
assert attributes["namespaces"] == ["test-api-key"]