forked from phoenix-oss/llama-stack-mirror
		
	
		
			
				
	
	
		
			77 lines
		
	
	
	
		
			2.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			77 lines
		
	
	
	
		
			2.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright (c) Meta Platforms, Inc. and affiliates.
 | |
| # All rights reserved.
 | |
| #
 | |
| # This source code is licensed under the terms described in the LICENSE file in
 | |
| # the root directory of this source tree.
 | |
| 
 | |
| import pytest
 | |
| 
 | |
| from .fixtures import INFERENCE_FIXTURES
 | |
| 
 | |
| 
 | |
| def pytest_addoption(parser):
 | |
|     parser.addoption(
 | |
|         "--inference-model",
 | |
|         action="store",
 | |
|         default=None,
 | |
|         help="Specify the inference model to use for testing",
 | |
|     )
 | |
| 
 | |
| 
 | |
| def pytest_configure(config):
 | |
|     for model in ["llama_8b", "llama_3b", "llama_vision"]:
 | |
|         config.addinivalue_line(
 | |
|             "markers", f"{model}: mark test to run only with the given model"
 | |
|         )
 | |
| 
 | |
|     for fixture_name in INFERENCE_FIXTURES:
 | |
|         config.addinivalue_line(
 | |
|             "markers",
 | |
|             f"{fixture_name}: marks tests as {fixture_name} specific",
 | |
|         )
 | |
| 
 | |
| 
 | |
| MODEL_PARAMS = [
 | |
|     pytest.param(
 | |
|         "meta-llama/Llama-3.1-8B-Instruct", marks=pytest.mark.llama_8b, id="llama_8b"
 | |
|     ),
 | |
|     pytest.param(
 | |
|         "meta-llama/Llama-3.2-3B-Instruct", marks=pytest.mark.llama_3b, id="llama_3b"
 | |
|     ),
 | |
| ]
 | |
| 
 | |
| VISION_MODEL_PARAMS = [
 | |
|     pytest.param(
 | |
|         "Llama3.2-11B-Vision-Instruct",
 | |
|         marks=pytest.mark.llama_vision,
 | |
|         id="llama_vision",
 | |
|     ),
 | |
| ]
 | |
| 
 | |
| 
 | |
| def pytest_generate_tests(metafunc):
 | |
|     if "inference_model" in metafunc.fixturenames:
 | |
|         model = metafunc.config.getoption("--inference-model")
 | |
|         if model:
 | |
|             params = [pytest.param(model, id="")]
 | |
|         else:
 | |
|             cls_name = metafunc.cls.__name__
 | |
|             if "Vision" in cls_name:
 | |
|                 params = VISION_MODEL_PARAMS
 | |
|             else:
 | |
|                 params = MODEL_PARAMS
 | |
| 
 | |
|         metafunc.parametrize(
 | |
|             "inference_model",
 | |
|             params,
 | |
|             indirect=True,
 | |
|         )
 | |
|     if "inference_stack" in metafunc.fixturenames:
 | |
|         metafunc.parametrize(
 | |
|             "inference_stack",
 | |
|             [
 | |
|                 pytest.param(fixture_name, marks=getattr(pytest.mark, fixture_name))
 | |
|                 for fixture_name in INFERENCE_FIXTURES
 | |
|             ],
 | |
|             indirect=True,
 | |
|         )
 |