# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import inspect
import os
import socket
import subprocess
import tempfile
import time
from urllib.parse import urlparse

import pytest
import requests
import yaml
from llama_stack_client import LlamaStackClient
from openai import OpenAI

from llama_stack import LlamaStackAsLibraryClient
from llama_stack.distribution.stack import run_config_from_adhoc_config_spec
from llama_stack.env import get_env_or_fail

DEFAULT_PORT = 8321


def is_port_available(port: int, host: str = "localhost") -> bool:
    """Check if a port is available for binding.

    Returns True when a TCP socket can be bound to ``(host, port)`` and False
    when the bind raises ``OSError``.
    """
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.bind((host, port))
            return True
    except OSError:
        return False


def start_llama_stack_server(config_name: str, port: int | None = None) -> subprocess.Popen:
    """Start a llama stack server with the given config.

    Args:
        config_name: Config passed to ``llama stack run``.
        port: Optional port for the server to listen on. When given it is
            forwarded as ``--port`` so the server listens where the caller
            will poll; when omitted the command line is left untouched.

    Returns:
        The ``Popen`` handle of the spawned server. Its stdout/stderr are
        discarded; ``LLAMA_STACK_LOG_FILE`` is set to ``server.log`` in the
        child environment.
    """
    cmd = ["llama", "stack", "run", config_name]
    if port is not None:
        cmd += ["--port", str(port)]

    # subprocess.DEVNULL avoids both the pipe-buffer deadlock and leaking an
    # open file object for os.devnull in this process.
    return subprocess.Popen(
        cmd,
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
        text=True,
        env={**os.environ, "LLAMA_STACK_LOG_FILE": "server.log"},
    )


def wait_for_server_ready(base_url: str, timeout: int = 30, process: subprocess.Popen | None = None) -> bool:
    """Wait for the server to be ready by polling the health endpoint.

    Args:
        base_url: Base URL of the server; ``/v1/health`` is appended.
        timeout: Maximum number of seconds to keep polling.
        process: Optional server process; if it exits while we wait, polling
            stops immediately.

    Returns:
        True once the health endpoint answers with HTTP 200, False if the
        process exits or the timeout elapses first.
    """
    health_url = f"{base_url}/v1/health"
    # Monotonic clock: elapsed time must not be affected by wall-clock changes.
    start_time = time.monotonic()
    last_reported_interval = 0

    while time.monotonic() - start_time < timeout:
        if process is not None and process.poll() is not None:
            print(f"Server process terminated with return code: {process.returncode}")
            return False

        try:
            response = requests.get(health_url, timeout=5)
        except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
            # Server is not accepting connections yet; keep polling.
            pass
        else:
            if response.status_code == 200:
                return True

        # Print progress once per 5-second interval (not on every poll that
        # happens to land inside the same second).
        elapsed = time.monotonic() - start_time
        interval = int(elapsed) // 5
        if interval > last_reported_interval:
            last_reported_interval = interval
            print(f"Waiting for server at {base_url}... ({elapsed:.1f}s elapsed)")

        time.sleep(0.5)

    print(f"Server failed to respond within {timeout} seconds")
    return False


@pytest.fixture(scope="session")
def provider_data():
    """Collect provider API keys from the environment.

    Returns a dict mapping provider-data field names to the value of the
    corresponding environment variable, for every variable that is set to a
    non-empty value.
    """
    # TODO: this needs to be generalized so each provider can have a sample provider data just
    # like sample run config on which we can do replace_env_vars()
    keymap = {
        "TAVILY_SEARCH_API_KEY": "tavily_search_api_key",
        "BRAVE_SEARCH_API_KEY": "brave_search_api_key",
        "FIREWORKS_API_KEY": "fireworks_api_key",
        "GEMINI_API_KEY": "gemini_api_key",
        "OPENAI_API_KEY": "openai_api_key",
        "TOGETHER_API_KEY": "together_api_key",
        "ANTHROPIC_API_KEY": "anthropic_api_key",
        "GROQ_API_KEY": "groq_api_key",
        "WOLFRAM_ALPHA_API_KEY": "wolfram_alpha_api_key",
    }
    return {field: os.environ[env_var] for env_var, field in keymap.items() if os.environ.get(env_var)}


@pytest.fixture(scope="session")
def inference_provider_type(llama_stack_client):
    """Return the provider type of the first inference provider listed by the client."""
    providers = llama_stack_client.providers.list()
    inference_providers = [p for p in providers if p.api == "inference"]
    assert len(inference_providers) > 0, "No inference providers found"
    return inference_providers[0].provider_type


@pytest.fixture(scope="session")
def client_with_models(
    llama_stack_client,
    text_model_id,
    vision_model_id,
    embedding_model_id,
    embedding_dimension,
    judge_model_id,
):
    """Return the client after registering any requested models it does not already list.

    Text, vision and judge models are registered with the first inference
    provider that is not ``inline::sentence-transformers``. The embedding
    model is registered with the ``inline::sentence-transformers`` provider
    when one exists, otherwise with the first inference provider.
    """
    client = llama_stack_client

    providers = [p for p in client.providers.list() if p.api == "inference"]
    assert len(providers) > 0, "No inference providers found"
    inference_providers = [p.provider_id for p in providers if p.provider_type != "inline::sentence-transformers"]

    # List the models once; a model counts as known under either its
    # identifier or its provider resource id.
    models = list(client.models.list())
    model_ids = {m.identifier for m in models}
    model_ids.update(m.provider_resource_id for m in models)

    for model_id in (text_model_id, vision_model_id, judge_model_id):
        if model_id and model_id not in model_ids:
            client.models.register(model_id=model_id, provider_id=inference_providers[0])

    if embedding_model_id and embedding_model_id not in model_ids:
        # try to find a provider that supports embeddings, if sentence-transformers is not available
        selected_provider = next(
            (p for p in providers if p.provider_type == "inline::sentence-transformers"),
            providers[0],
        )
        client.models.register(
            model_id=embedding_model_id,
            provider_id=selected_provider.provider_id,
            model_type="embedding",
            metadata={"embedding_dimension": embedding_dimension or 384},
        )
    return client


@pytest.fixture(scope="session")
def available_shields(llama_stack_client):
    """Return the identifiers of all shields listed by the client."""
    return [shield.identifier for shield in llama_stack_client.shields.list()]


@pytest.fixture(scope="session")
def model_providers(llama_stack_client):
    """Return the set of provider ids whose API is ``inference``."""
    return {x.provider_id for x in llama_stack_client.providers.list() if x.api == "inference"}


@pytest.fixture(autouse=True)
def skip_if_no_model(request):
    """Skip a test when a model-id fixture it takes as a parameter is empty."""
    model_fixtures = ["text_model_id", "vision_model_id", "embedding_model_id", "judge_model_id"]
    test_func = request.node.function

    actual_params = inspect.signature(test_func).parameters.keys()
    for fixture in model_fixtures:
        # Only check fixtures that are actually in the test function's signature
        if fixture in actual_params and fixture in request.fixturenames and not request.getfixturevalue(fixture):
            pytest.skip(f"{fixture} empty - skipping test")


@pytest.fixture(scope="session")
def llama_stack_client(request, provider_data):
    """Build the client used by the test session from ``--stack-config`` / ``LLAMA_STACK_CONFIG``.

    The config string is interpreted as follows, in order:

    * ``server:<config>[:<port>]`` - connect to ``http://localhost:<port>``
      (port falls back to ``LLAMA_STACK_PORT`` and then ``DEFAULT_PORT``).
      If the port is free, a server is started first and recorded on the
      session so it can be terminated at session end.
    * a URL with scheme and network location - connect to it directly.
    * a string containing ``=`` - treated as an ad-hoc config spec, dumped to
      a temporary YAML file and run as a library client.
    * anything else - passed as-is to ``LlamaStackAsLibraryClient``.

    Raises:
        ValueError: if no config is specified.
        RuntimeError: if the started server does not become ready, or the
            library client fails to initialize.
    """
    config = request.config.getoption("--stack-config")
    if not config:
        config = get_env_or_fail("LLAMA_STACK_CONFIG")

    if not config:
        raise ValueError("You must specify either --stack-config or LLAMA_STACK_CONFIG")

    # Handle server:<config_name> format or server:<config_name>:<port>
    if config.startswith("server:"):
        parts = config.split(":")
        config_name = parts[1]
        port = int(parts[2]) if len(parts) > 2 else int(os.environ.get("LLAMA_STACK_PORT", DEFAULT_PORT))
        base_url = f"http://localhost:{port}"

        # Check if port is available
        if is_port_available(port):
            print(f"Starting llama stack server with config '{config_name}' on port {port}...")

            # Start the server on the same port we are about to poll.
            server_process = start_llama_stack_server(config_name, port=port)

            # Wait for server to be ready
            if not wait_for_server_ready(base_url, timeout=30, process=server_process):
                print("Server failed to start within timeout")
                server_process.terminate()
                raise RuntimeError(
                    f"Server failed to start within timeout. Check that config '{config_name}' exists and is valid. "
                    f"See server.log for details."
                )

            print(f"Server is ready at {base_url}")

            # Store process for potential cleanup (pytest will handle termination at session end)
            request.session._llama_stack_server_process = server_process
        else:
            print(f"Port {port} is already in use, assuming server is already running...")

        return LlamaStackClient(
            base_url=base_url,
            provider_data=provider_data,
        )

    # check if this looks like a URL using proper URL parsing
    try:
        parsed_url = urlparse(config)
    except ValueError:
        # If URL parsing fails, treat as non-URL config
        parsed_url = None

    # The client is constructed outside the try so its own errors are not
    # mistaken for "not a URL" and silently swallowed.
    if parsed_url is not None and parsed_url.scheme and parsed_url.netloc:
        return LlamaStackClient(
            base_url=config,
            provider_data=provider_data,
        )

    if "=" in config:
        run_config = run_config_from_adhoc_config_spec(config)
        # Write through the temp file's own handle so it is closed exactly
        # once; delete=False keeps the file around for the library client.
        with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".yaml") as run_config_file:
            yaml.dump(run_config.model_dump(), run_config_file)
        config = run_config_file.name

    client = LlamaStackAsLibraryClient(
        config,
        provider_data=provider_data,
        skip_logger_removal=True,
    )
    if not client.initialize():
        raise RuntimeError("Initialization failed")

    return client


@pytest.fixture(scope="session")
def openai_client(client_with_models):
    """Return an OpenAI client pointed at the stack's ``/v1/openai/v1`` endpoint."""
    base_url = f"{client_with_models.base_url}/v1/openai/v1"
    return OpenAI(base_url=base_url, api_key="fake")


@pytest.fixture(scope="session", autouse=True)
def cleanup_server_process(request):
    """Cleanup server process at the end of the test session."""
    yield  # Run tests

    # Nothing was recorded on the session: no server was started by us.
    if not hasattr(request.session, "_llama_stack_server_process"):
        return

    server_process = request.session._llama_stack_server_process
    if not server_process:
        print("Server process not found - won't be able to cleanup")
        return

    if server_process.poll() is not None:
        print(f"Server process already terminated with return code: {server_process.returncode}")
        return

    print("Terminating llama stack server process...")
    try:
        server_process.terminate()
        server_process.wait(timeout=10)
        print("Server process terminated gracefully")
    except subprocess.TimeoutExpired:
        print("Server process did not terminate gracefully, killing it")
        server_process.kill()
        server_process.wait()
        print("Server process killed")
    except Exception as e:
        # Best-effort cleanup: report but never fail the session teardown.
        print(f"Error during server cleanup: {e}")