improvements to include other ollama methods like ps, list and pull

Ashwin Bharambe 2025-07-29 10:56:57 -07:00
parent 1a21c4b695
commit b578f9aec1
23 changed files with 6748 additions and 5711 deletions

@@ -185,95 +185,70 @@ def llama_stack_client(request, provider_data):
     if not config:
         raise ValueError("You must specify either --stack-config or LLAMA_STACK_CONFIG")
 
-    # Set up inference recording if enabled
-    inference_mode = os.environ.get("LLAMA_STACK_INFERENCE_MODE", "live").lower()
-    recording_context = None
-
-    if inference_mode in ["record", "replay"]:
-        from llama_stack.testing.inference_recorder import setup_inference_recording
-
-        recording_context = setup_inference_recording()
-        recording_context.__enter__()
-        print(f"Inference recording enabled: mode={inference_mode}")
-
-    try:
-        # Handle server:<config_name> format or server:<config_name>:<port>
-        if config.startswith("server:"):
-            parts = config.split(":")
-            config_name = parts[1]
-            port = int(parts[2]) if len(parts) > 2 else int(os.environ.get("LLAMA_STACK_PORT", DEFAULT_PORT))
-            base_url = f"http://localhost:{port}"
-
-            # Check if port is available
-            if is_port_available(port):
-                print(f"Starting llama stack server with config '{config_name}' on port {port}...")
-
-                # Start server
-                server_process = start_llama_stack_server(config_name)
-
-                # Wait for server to be ready
-                if not wait_for_server_ready(base_url, timeout=120, process=server_process):
-                    print("Server failed to start within timeout")
-                    server_process.terminate()
-                    raise RuntimeError(
-                        f"Server failed to start within timeout. Check that config '{config_name}' exists and is valid. "
-                        f"See server.log for details."
-                    )
-
-                print(f"Server is ready at {base_url}")
-
-                # Store process for potential cleanup (pytest will handle termination at session end)
-                request.session._llama_stack_server_process = server_process
-            else:
-                print(f"Port {port} is already in use, assuming server is already running...")
-
-            client = LlamaStackClient(
-                base_url=base_url,
-                provider_data=provider_data,
-                timeout=int(os.environ.get("LLAMA_STACK_CLIENT_TIMEOUT", "30")),
-            )
-        else:
-            # check if this looks like a URL using proper URL parsing
-            try:
-                parsed_url = urlparse(config)
-                if parsed_url.scheme and parsed_url.netloc:
-                    client = LlamaStackClient(
-                        base_url=config,
-                        provider_data=provider_data,
-                    )
-                else:
-                    raise ValueError("Not a URL")
-            except Exception as e:
-                # If URL parsing fails, treat as library config
-                if "=" in config:
-                    run_config = run_config_from_adhoc_config_spec(config)
-                    run_config_file = tempfile.NamedTemporaryFile(delete=False, suffix=".yaml")
-                    with open(run_config_file.name, "w") as f:
-                        yaml.dump(run_config.model_dump(), f)
-                    config = run_config_file.name
-                client = LlamaStackAsLibraryClient(
-                    config,
-                    provider_data=provider_data,
-                    skip_logger_removal=True,
-                )
-                if not client.initialize():
-                    raise RuntimeError("Initialization failed") from e
-
-        # Store recording context for cleanup
-        if recording_context:
-            request.session._inference_recording_context = recording_context
-
-        return client
-    except Exception:
-        # Clean up recording context on error
-        if recording_context:
-            try:
-                recording_context.__exit__(None, None, None)
-            except Exception as cleanup_error:
-                print(f"Warning: Error cleaning up recording context: {cleanup_error}")
-        raise
+    # Handle server:<config_name> format or server:<config_name>:<port>
+    if config.startswith("server:"):
+        parts = config.split(":")
+        config_name = parts[1]
+        port = int(parts[2]) if len(parts) > 2 else int(os.environ.get("LLAMA_STACK_PORT", DEFAULT_PORT))
+        base_url = f"http://localhost:{port}"
+
+        # Check if port is available
+        if is_port_available(port):
+            print(f"Starting llama stack server with config '{config_name}' on port {port}...")
+
+            # Start server
+            server_process = start_llama_stack_server(config_name)
+
+            # Wait for server to be ready
+            if not wait_for_server_ready(base_url, timeout=120, process=server_process):
+                print("Server failed to start within timeout")
+                server_process.terminate()
+                raise RuntimeError(
+                    f"Server failed to start within timeout. Check that config '{config_name}' exists and is valid. "
+                    f"See server.log for details."
+                )
+
+            print(f"Server is ready at {base_url}")
+
+            # Store process for potential cleanup (pytest will handle termination at session end)
+            request.session._llama_stack_server_process = server_process
+        else:
+            print(f"Port {port} is already in use, assuming server is already running...")
+
+        return LlamaStackClient(
+            base_url=base_url,
+            provider_data=provider_data,
+            timeout=int(os.environ.get("LLAMA_STACK_CLIENT_TIMEOUT", "30")),
+        )
+
+    # check if this looks like a URL using proper URL parsing
+    try:
+        parsed_url = urlparse(config)
+        if parsed_url.scheme and parsed_url.netloc:
+            return LlamaStackClient(
+                base_url=config,
+                provider_data=provider_data,
+            )
+    except Exception:
+        # If URL parsing fails, treat as non-URL config
+        pass
+
+    if "=" in config:
+        run_config = run_config_from_adhoc_config_spec(config)
+        run_config_file = tempfile.NamedTemporaryFile(delete=False, suffix=".yaml")
+        with open(run_config_file.name, "w") as f:
+            yaml.dump(run_config.model_dump(), f)
+        config = run_config_file.name
+
+    client = LlamaStackAsLibraryClient(
+        config,
+        provider_data=provider_data,
+        skip_logger_removal=True,
+    )
+    if not client.initialize():
+        raise RuntimeError("Initialization failed")
+    return client
 
 
 @pytest.fixture(scope="session")
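
Note: the server-start path in the hunk above leans on an is_port_available() helper that is not shown in this diff. A minimal sketch of what such a probe could look like, assuming a simple TCP connect check rather than the repository's actual implementation:

    import socket

    def is_port_available(port: int, host: str = "localhost") -> bool:
        # connect_ex() returns 0 when something accepts the connection;
        # a nonzero result means the port is free for our server to claim.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.settimeout(1.0)
            return sock.connect_ex((host, port)) != 0

Using connect_ex() rather than connect() means a refused connection comes back as a return code instead of an exception, which keeps the probe branch-free.
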
@@ -289,20 +264,9 @@ def compat_client(request):
 
 @pytest.fixture(scope="session", autouse=True)
 def cleanup_server_process(request):
-    """Cleanup server process and inference recording at the end of the test session."""
+    """Cleanup server process at the end of the test session."""
     yield  # Run tests
 
-    # Clean up inference recording context
-    if hasattr(request.session, "_inference_recording_context"):
-        recording_context = request.session._inference_recording_context
-        if recording_context:
-            try:
-                print("Cleaning up inference recording context...")
-                recording_context.__exit__(None, None, None)
-            except Exception as e:
-                print(f"Error during inference recording cleanup: {e}")
-
-    # Clean up server process
     if hasattr(request.session, "_llama_stack_server_process"):
         server_process = request.session._llama_stack_server_process
         if server_process:
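
Note: the deleted code drove the recorder by calling __enter__() and __exit__() by hand and cleaning up in an except block. The same behavior falls out of using setup_inference_recording() as an ordinary context manager; a sketch based only on the calls visible in this diff, where run_tests() stands in for a hypothetical test driver:

    import os
    from llama_stack.testing.inference_recorder import setup_inference_recording

    inference_mode = os.environ.get("LLAMA_STACK_INFERENCE_MODE", "live").lower()
    if inference_mode in ["record", "replay"]:
        # The with statement guarantees __exit__() runs even when a test
        # errors out, replacing the manual try/except cleanup removed above.
        with setup_inference_recording():
            run_tests()  # hypothetical stand-in for the pytest session
    else:
        run_tests()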