mirror of
				https://github.com/meta-llama/llama-stack.git
				synced 2025-10-25 01:01:13 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			293 lines
		
	
	
	
		
			11 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			293 lines
		
	
	
	
		
			11 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright (c) Meta Platforms, Inc. and affiliates.
 | |
| # All rights reserved.
 | |
| #
 | |
| # This source code is licensed under the terms described in the LICENSE file in
 | |
| # the root directory of this source tree.
 | |
| 
 | |
| import inspect
 | |
| import os
 | |
| import signal
 | |
| import socket
 | |
| import subprocess
 | |
| import tempfile
 | |
| import time
 | |
| from urllib.parse import urlparse
 | |
| 
 | |
| import pytest
 | |
| import requests
 | |
| import yaml
 | |
| from llama_stack_client import LlamaStackClient
 | |
| from openai import OpenAI
 | |
| 
 | |
| from llama_stack import LlamaStackAsLibraryClient
 | |
| from llama_stack.distribution.stack import run_config_from_adhoc_config_spec
 | |
| from llama_stack.env import get_env_or_fail
 | |
| 
 | |
| DEFAULT_PORT = 8321
 | |
| 
 | |
| 
 | |
| def is_port_available(port: int, host: str = "localhost") -> bool:
 | |
|     """Check if a port is available for binding."""
 | |
|     try:
 | |
|         with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
 | |
|             sock.bind((host, port))
 | |
|             return True
 | |
|     except OSError:
 | |
|         return False
 | |
| 
 | |
| 
 | |
| def start_llama_stack_server(config_name: str) -> subprocess.Popen:
 | |
|     """Start a llama stack server with the given config."""
 | |
|     cmd = ["llama", "stack", "run", config_name]
 | |
|     devnull = open(os.devnull, "w")
 | |
|     process = subprocess.Popen(
 | |
|         cmd,
 | |
|         stdout=devnull,  # redirect stdout to devnull to prevent deadlock
 | |
|         stderr=subprocess.PIPE,  # keep stderr to see errors
 | |
|         text=True,
 | |
|         env={**os.environ, "LLAMA_STACK_LOG_FILE": "server.log"},
 | |
|         # Create new process group so we can kill all child processes
 | |
|         preexec_fn=os.setsid,
 | |
|     )
 | |
|     return process
 | |
| 
 | |
| 
 | |
| def wait_for_server_ready(base_url: str, timeout: int = 30, process: subprocess.Popen | None = None) -> bool:
 | |
|     """Wait for the server to be ready by polling the health endpoint."""
 | |
|     health_url = f"{base_url}/v1/health"
 | |
|     start_time = time.time()
 | |
| 
 | |
|     while time.time() - start_time < timeout:
 | |
|         if process and process.poll() is not None:
 | |
|             print(f"Server process terminated with return code: {process.returncode}")
 | |
|             print(f"Server stderr: {process.stderr.read()}")
 | |
|             return False
 | |
| 
 | |
|         try:
 | |
|             response = requests.get(health_url, timeout=5)
 | |
|             if response.status_code == 200:
 | |
|                 return True
 | |
|         except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
 | |
|             pass
 | |
| 
 | |
|         # Print progress every 5 seconds
 | |
|         elapsed = time.time() - start_time
 | |
|         if int(elapsed) % 5 == 0 and elapsed > 0:
 | |
|             print(f"Waiting for server at {base_url}... ({elapsed:.1f}s elapsed)")
 | |
| 
 | |
|         time.sleep(0.5)
 | |
| 
 | |
|     print(f"Server failed to respond within {timeout} seconds")
 | |
|     return False
 | |
| 
 | |
| 
 | |
| @pytest.fixture(scope="session")
 | |
| def provider_data():
 | |
|     # TODO: this needs to be generalized so each provider can have a sample provider data just
 | |
|     # like sample run config on which we can do replace_env_vars()
 | |
|     keymap = {
 | |
|         "TAVILY_SEARCH_API_KEY": "tavily_search_api_key",
 | |
|         "BRAVE_SEARCH_API_KEY": "brave_search_api_key",
 | |
|         "FIREWORKS_API_KEY": "fireworks_api_key",
 | |
|         "GEMINI_API_KEY": "gemini_api_key",
 | |
|         "OPENAI_API_KEY": "openai_api_key",
 | |
|         "TOGETHER_API_KEY": "together_api_key",
 | |
|         "ANTHROPIC_API_KEY": "anthropic_api_key",
 | |
|         "GROQ_API_KEY": "groq_api_key",
 | |
|         "WOLFRAM_ALPHA_API_KEY": "wolfram_alpha_api_key",
 | |
|     }
 | |
|     provider_data = {}
 | |
|     for key, value in keymap.items():
 | |
|         if os.environ.get(key):
 | |
|             provider_data[value] = os.environ[key]
 | |
|     return provider_data
 | |
| 
 | |
| 
 | |
| @pytest.fixture(scope="session")
 | |
| def inference_provider_type(llama_stack_client):
 | |
|     providers = llama_stack_client.providers.list()
 | |
|     inference_providers = [p for p in providers if p.api == "inference"]
 | |
|     assert len(inference_providers) > 0, "No inference providers found"
 | |
|     return inference_providers[0].provider_type
 | |
| 
 | |
| 
 | |
| @pytest.fixture(scope="session")
 | |
| def client_with_models(
 | |
|     llama_stack_client,
 | |
|     text_model_id,
 | |
|     vision_model_id,
 | |
|     embedding_model_id,
 | |
|     embedding_dimension,
 | |
|     judge_model_id,
 | |
| ):
 | |
|     client = llama_stack_client
 | |
| 
 | |
|     providers = [p for p in client.providers.list() if p.api == "inference"]
 | |
|     assert len(providers) > 0, "No inference providers found"
 | |
|     inference_providers = [p.provider_id for p in providers if p.provider_type != "inline::sentence-transformers"]
 | |
| 
 | |
|     model_ids = {m.identifier for m in client.models.list()}
 | |
|     model_ids.update(m.provider_resource_id for m in client.models.list())
 | |
| 
 | |
|     if text_model_id and text_model_id not in model_ids:
 | |
|         client.models.register(model_id=text_model_id, provider_id=inference_providers[0])
 | |
|     if vision_model_id and vision_model_id not in model_ids:
 | |
|         client.models.register(model_id=vision_model_id, provider_id=inference_providers[0])
 | |
|     if judge_model_id and judge_model_id not in model_ids:
 | |
|         client.models.register(model_id=judge_model_id, provider_id=inference_providers[0])
 | |
| 
 | |
|     if embedding_model_id and embedding_model_id not in model_ids:
 | |
|         # try to find a provider that supports embeddings, if sentence-transformers is not available
 | |
|         selected_provider = None
 | |
|         for p in providers:
 | |
|             if p.provider_type == "inline::sentence-transformers":
 | |
|                 selected_provider = p
 | |
|                 break
 | |
| 
 | |
|         selected_provider = selected_provider or providers[0]
 | |
|         client.models.register(
 | |
|             model_id=embedding_model_id,
 | |
|             provider_id=selected_provider.provider_id,
 | |
|             model_type="embedding",
 | |
|             metadata={"embedding_dimension": embedding_dimension or 384},
 | |
|         )
 | |
|     return client
 | |
| 
 | |
| 
 | |
| @pytest.fixture(scope="session")
 | |
| def available_shields(llama_stack_client):
 | |
|     return [shield.identifier for shield in llama_stack_client.shields.list()]
 | |
| 
 | |
| 
 | |
| @pytest.fixture(scope="session")
 | |
| def model_providers(llama_stack_client):
 | |
|     return {x.provider_id for x in llama_stack_client.providers.list() if x.api == "inference"}
 | |
| 
 | |
| 
 | |
| @pytest.fixture(autouse=True)
 | |
| def skip_if_no_model(request):
 | |
|     model_fixtures = ["text_model_id", "vision_model_id", "embedding_model_id", "judge_model_id"]
 | |
|     test_func = request.node.function
 | |
| 
 | |
|     actual_params = inspect.signature(test_func).parameters.keys()
 | |
|     for fixture in model_fixtures:
 | |
|         # Only check fixtures that are actually in the test function's signature
 | |
|         if fixture in actual_params and fixture in request.fixturenames and not request.getfixturevalue(fixture):
 | |
|             pytest.skip(f"{fixture} empty - skipping test")
 | |
| 
 | |
| 
 | |
| @pytest.fixture(scope="session")
 | |
| def llama_stack_client(request, provider_data):
 | |
|     config = request.config.getoption("--stack-config")
 | |
|     if not config:
 | |
|         config = get_env_or_fail("LLAMA_STACK_CONFIG")
 | |
| 
 | |
|     if not config:
 | |
|         raise ValueError("You must specify either --stack-config or LLAMA_STACK_CONFIG")
 | |
| 
 | |
|     # Handle server:<config_name> format or server:<config_name>:<port>
 | |
|     if config.startswith("server:"):
 | |
|         parts = config.split(":")
 | |
|         config_name = parts[1]
 | |
|         port = int(parts[2]) if len(parts) > 2 else int(os.environ.get("LLAMA_STACK_PORT", DEFAULT_PORT))
 | |
|         base_url = f"http://localhost:{port}"
 | |
| 
 | |
|         # Check if port is available
 | |
|         if is_port_available(port):
 | |
|             print(f"Starting llama stack server with config '{config_name}' on port {port}...")
 | |
| 
 | |
|             # Start server
 | |
|             server_process = start_llama_stack_server(config_name)
 | |
| 
 | |
|             # Wait for server to be ready
 | |
|             if not wait_for_server_ready(base_url, timeout=120, process=server_process):
 | |
|                 print("Server failed to start within timeout")
 | |
|                 server_process.terminate()
 | |
|                 raise RuntimeError(
 | |
|                     f"Server failed to start within timeout. Check that config '{config_name}' exists and is valid. "
 | |
|                     f"See server.log for details."
 | |
|                 )
 | |
| 
 | |
|             print(f"Server is ready at {base_url}")
 | |
| 
 | |
|             # Store process for potential cleanup (pytest will handle termination at session end)
 | |
|             request.session._llama_stack_server_process = server_process
 | |
|         else:
 | |
|             print(f"Port {port} is already in use, assuming server is already running...")
 | |
| 
 | |
|         return LlamaStackClient(
 | |
|             base_url=base_url,
 | |
|             provider_data=provider_data,
 | |
|             timeout=int(os.environ.get("LLAMA_STACK_CLIENT_TIMEOUT", "30")),
 | |
|         )
 | |
| 
 | |
|     # check if this looks like a URL using proper URL parsing
 | |
|     try:
 | |
|         parsed_url = urlparse(config)
 | |
|         if parsed_url.scheme and parsed_url.netloc:
 | |
|             return LlamaStackClient(
 | |
|                 base_url=config,
 | |
|                 provider_data=provider_data,
 | |
|             )
 | |
|     except Exception:
 | |
|         # If URL parsing fails, treat as non-URL config
 | |
|         pass
 | |
| 
 | |
|     if "=" in config:
 | |
|         run_config = run_config_from_adhoc_config_spec(config)
 | |
|         run_config_file = tempfile.NamedTemporaryFile(delete=False, suffix=".yaml")
 | |
|         with open(run_config_file.name, "w") as f:
 | |
|             yaml.dump(run_config.model_dump(), f)
 | |
|         config = run_config_file.name
 | |
| 
 | |
|     client = LlamaStackAsLibraryClient(
 | |
|         config,
 | |
|         provider_data=provider_data,
 | |
|         skip_logger_removal=True,
 | |
|     )
 | |
|     if not client.initialize():
 | |
|         raise RuntimeError("Initialization failed")
 | |
| 
 | |
|     return client
 | |
| 
 | |
| 
 | |
| @pytest.fixture(scope="session")
 | |
| def openai_client(client_with_models):
 | |
|     base_url = f"{client_with_models.base_url}/v1/openai/v1"
 | |
|     return OpenAI(base_url=base_url, api_key="fake")
 | |
| 
 | |
| 
 | |
| @pytest.fixture(params=["openai_client", "llama_stack_client"])
 | |
| def compat_client(request):
 | |
|     return request.getfixturevalue(request.param)
 | |
| 
 | |
| 
 | |
| @pytest.fixture(scope="session", autouse=True)
 | |
| def cleanup_server_process(request):
 | |
|     """Cleanup server process at the end of the test session."""
 | |
|     yield  # Run tests
 | |
| 
 | |
|     if hasattr(request.session, "_llama_stack_server_process"):
 | |
|         server_process = request.session._llama_stack_server_process
 | |
|         if server_process:
 | |
|             if server_process.poll() is None:
 | |
|                 print("Terminating llama stack server process...")
 | |
|             else:
 | |
|                 print(f"Server process already terminated with return code: {server_process.returncode}")
 | |
|                 return
 | |
|             try:
 | |
|                 print(f"Terminating process {server_process.pid} and its group...")
 | |
|                 # Kill the entire process group
 | |
|                 os.killpg(os.getpgid(server_process.pid), signal.SIGTERM)
 | |
|                 server_process.wait(timeout=10)
 | |
|                 print("Server process and children terminated gracefully")
 | |
|             except subprocess.TimeoutExpired:
 | |
|                 print("Server process did not terminate gracefully, killing it")
 | |
|                 # Force kill the entire process group
 | |
|                 os.killpg(os.getpgid(server_process.pid), signal.SIGKILL)
 | |
|                 server_process.wait()
 | |
|                 print("Server process and children killed")
 | |
|             except Exception as e:
 | |
|                 print(f"Error during server cleanup: {e}")
 | |
|         else:
 | |
|             print("Server process not found - won't be able to cleanup")
 |