mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-05 12:21:52 +00:00
fix(pr specific): passes pre-commit
This commit is contained in:
parent
4aa2dc110d
commit
2b7a765d02
20 changed files with 547 additions and 516 deletions
|
@ -14,7 +14,7 @@ HOW TO ADD A NEW MOCK SERVER:
|
|||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
@ -24,10 +24,10 @@ from .mock_base import MockServerBase
|
|||
class MockServerConfig(BaseModel):
|
||||
"""
|
||||
Configuration for a mock server to start.
|
||||
|
||||
|
||||
**TO ADD A NEW MOCK SERVER:**
|
||||
Just create a MockServerConfig instance with your server class.
|
||||
|
||||
|
||||
Example:
|
||||
MockServerConfig(
|
||||
name="Mock MyService",
|
||||
|
@ -35,73 +35,72 @@ class MockServerConfig(BaseModel):
|
|||
init_kwargs={"port": 9000, "config_param": "value"},
|
||||
)
|
||||
"""
|
||||
|
||||
|
||||
model_config = {"arbitrary_types_allowed": True}
|
||||
|
||||
|
||||
name: str = Field(description="Display name for logging")
|
||||
server_class: type = Field(description="Mock server class (must inherit from MockServerBase)")
|
||||
init_kwargs: Dict[str, Any] = Field(default_factory=dict, description="Kwargs to pass to server constructor")
|
||||
init_kwargs: dict[str, Any] = Field(default_factory=dict, description="Kwargs to pass to server constructor")
|
||||
|
||||
|
||||
async def start_mock_servers_async(mock_servers_config: List[MockServerConfig]) -> Dict[str, MockServerBase]:
|
||||
async def start_mock_servers_async(mock_servers_config: list[MockServerConfig]) -> dict[str, MockServerBase]:
|
||||
"""
|
||||
Start all mock servers in parallel and wait for them to be ready.
|
||||
|
||||
|
||||
**HOW IT WORKS:**
|
||||
1. Creates all server instances
|
||||
2. Calls await_start() on all servers in parallel
|
||||
3. Returns when all are ready
|
||||
|
||||
|
||||
**SIMPLE TO USE:**
|
||||
servers = await start_mock_servers_async([config1, config2, ...])
|
||||
|
||||
|
||||
Args:
|
||||
mock_servers_config: List of mock server configurations
|
||||
|
||||
|
||||
Returns:
|
||||
Dict mapping server name to server instance
|
||||
"""
|
||||
servers = {}
|
||||
start_tasks = []
|
||||
|
||||
|
||||
# Create all servers and prepare start tasks
|
||||
for config in mock_servers_config:
|
||||
server = config.server_class(**config.init_kwargs)
|
||||
servers[config.name] = server
|
||||
start_tasks.append(server.await_start())
|
||||
|
||||
|
||||
# Start all servers in parallel
|
||||
try:
|
||||
await asyncio.gather(*start_tasks)
|
||||
|
||||
|
||||
# Print readiness confirmation
|
||||
for name in servers.keys():
|
||||
print(f"[INFO] {name} ready")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
# If any server fails, stop all servers
|
||||
for server in servers.values():
|
||||
try:
|
||||
server.stop()
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
raise RuntimeError(f"Failed to start mock servers: {e}")
|
||||
|
||||
raise RuntimeError(f"Failed to start mock servers: {e}") from None
|
||||
|
||||
return servers
|
||||
|
||||
|
||||
def stop_mock_servers(servers: Dict[str, Any]):
|
||||
def stop_mock_servers(servers: dict[str, Any]):
|
||||
"""
|
||||
Stop all mock servers.
|
||||
|
||||
|
||||
Args:
|
||||
servers: Dict of server instances from start_mock_servers_async()
|
||||
"""
|
||||
for name, server in servers.items():
|
||||
try:
|
||||
if hasattr(server, 'get_request_count'):
|
||||
if hasattr(server, "get_request_count"):
|
||||
print(f"\n[INFO] {name} received {server.get_request_count()} requests")
|
||||
server.stop()
|
||||
except Exception as e:
|
||||
print(f"[WARN] Error stopping {name}: {e}")
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue