diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py
index 350ab05fe..e51534446 100644
--- a/llama_stack/distribution/routers/routing_tables.py
+++ b/llama_stack/distribution/routers/routing_tables.py
@@ -42,7 +42,8 @@ class CommonRoutingTableImpl(RoutingTable):
             await self.register_object(obj, p)

     async def shutdown(self) -> None:
-        pass
+        for p in self.impls_by_provider_id.values():
+            await p.shutdown()

     def get_provider_impl(self, routing_key: str) -> Any:
         if routing_key not in self.routing_key_to_object:
diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py
index ed3b4b9f2..dd3fafd0a 100644
--- a/llama_stack/distribution/server/server.py
+++ b/llama_stack/distribution/server/server.py
@@ -5,6 +5,7 @@
 # the root directory of this source tree.

 import asyncio
+import functools
 import inspect
 import json
 import signal
@@ -169,11 +170,20 @@ async def passthrough(
     await end_trace(SpanStatus.OK if not erred else SpanStatus.ERROR)


-def handle_sigint(*args, **kwargs):
+def handle_sigint(app, *args, **kwargs):
     print("SIGINT or CTRL-C detected. Exiting gracefully...")
+
+    async def run_shutdown():
+        for impl in app.__llama_stack_impls__.values():
+            print(f"Shutting down {impl}")
+            await impl.shutdown()
+
+    asyncio.run(run_shutdown())
+
     loop = asyncio.get_event_loop()
     for task in asyncio.all_tasks(loop):
         task.cancel()
+    loop.stop()
@@ -181,7 +191,10 @@
 async def lifespan(app: FastAPI):
     print("Starting up")
     yield
+    print("Shutting down")
+    for impl in app.__llama_stack_impls__.values():
+        await impl.shutdown()


 def create_dynamic_passthrough(
@@ -333,7 +346,9 @@ def main(
     print("")
     app.exception_handler(RequestValidationError)(global_exception_handler)
     app.exception_handler(Exception)(global_exception_handler)
-    signal.signal(signal.SIGINT, handle_sigint)
+    signal.signal(signal.SIGINT, functools.partial(handle_sigint, app))
+
+    app.__llama_stack_impls__ = impls

     import uvicorn
diff --git a/llama_stack/providers/adapters/inference/tgi/tgi.py b/llama_stack/providers/adapters/inference/tgi/tgi.py
index a5e5a99be..9868a9364 100644
--- a/llama_stack/providers/adapters/inference/tgi/tgi.py
+++ b/llama_stack/providers/adapters/inference/tgi/tgi.py
@@ -13,8 +13,6 @@ from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import StopReason
 from llama_models.llama3.api.tokenizer import Tokenizer

-from llama_stack.distribution.datatypes import RoutableProvider
-
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.providers.utils.inference.augment_messages import (
     augment_messages_for_tools,
@@ -25,7 +23,7 @@ from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImpl
 logger = logging.getLogger(__name__)


-class _HfAdapter(Inference, RoutableProvider):
+class _HfAdapter(Inference):
     client: AsyncInferenceClient
     max_tokens: int
     model_id: str
@@ -34,11 +32,17 @@ class _HfAdapter(Inference, RoutableProvider):
         self.tokenizer = Tokenizer.get_instance()
         self.formatter = ChatFormat(self.tokenizer)

-    async def validate_routing_keys(self, routing_keys: list[str]) -> None:
-        # these are the model names the Llama Stack will use to route requests to this provider
-        # perform validation here if necessary
+    # TODO: make this work properly by checking this against the model_id being
+    # served by the remote endpoint
+    async def register_model(self, model: ModelDef) -> None:
         pass

+    async def list_models(self) -> List[ModelDef]:
+        return []
+
+    async def get_model(self, identifier: str) -> Optional[ModelDef]:
+        return None
+
     async def shutdown(self) -> None:
         pass
diff --git a/llama_stack/providers/impls/vllm/vllm.py b/llama_stack/providers/impls/vllm/vllm.py
index ecaa6bc45..0f8e8d38c 100644
--- a/llama_stack/providers/impls/vllm/vllm.py
+++ b/llama_stack/providers/impls/vllm/vllm.py
@@ -42,7 +42,7 @@ from llama_stack.apis.inference.inference import (
 from llama_stack.providers.utils.inference.augment_messages import (
     augment_messages_for_tools,
 )
-from llama_stack.providers.utils.inference.routable import RoutableProviderForModels
+from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper

 from .config import VLLMConfig
@@ -75,7 +75,7 @@ def _vllm_sampling_params(sampling_params: Any) -> SamplingParams:
     return SamplingParams().from_optional(**kwargs)


-class VLLMInferenceImpl(Inference, RoutableProviderForModels):
+class VLLMInferenceImpl(Inference, ModelRegistryHelper):
     """Inference implementation for vLLM."""

     HF_MODEL_MAPPINGS = {
@@ -109,7 +109,7 @@ class VLLMInferenceImpl(Inference, RoutableProviderForModels):

     def __init__(self, config: VLLMConfig):
         Inference.__init__(self)
-        RoutableProviderForModels.__init__(
+        ModelRegistryHelper.__init__(
             self,
             stack_to_provider_models_map=self.HF_MODEL_MAPPINGS,
         )