more fixes, plug shutdown handlers

still, FastAPI's SIGINT handler is not calling ours (a possible workaround is sketched below)
Ashwin Bharambe 2024-10-05 23:48:18 -07:00 committed by Ashwin Bharambe
parent 60dead6196
commit e45a417543
4 changed files with 32 additions and 12 deletions
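
The note above about FastAPI's SIGINT handler most likely refers to uvicorn, which installs its own SIGINT/SIGTERM handlers inside Server.run() and thereby replaces the handler registered with signal.signal() in this commit. A minimal workaround sketch (not part of this commit; host and port are placeholders) is to run the server through uvicorn.Server and override its install_signal_handlers hook:

import uvicorn

# Sketch only: disable uvicorn's built-in signal handling so the handler
# registered via signal.signal(signal.SIGINT, ...) keeps firing.
config = uvicorn.Config(app, host="0.0.0.0", port=5000)  # placeholder values
server = uvicorn.Server(config)
server.install_signal_handlers = lambda: None  # no-op override on the instance
server.run()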


@@ -42,7 +42,8 @@ class CommonRoutingTableImpl(RoutingTable):
             await self.register_object(obj, p)
 
     async def shutdown(self) -> None:
-        pass
+        for p in self.impls_by_provider_id.values():
+            await p.shutdown()
 
     def get_provider_impl(self, routing_key: str) -> Any:
         if routing_key not in self.routing_key_to_object:
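
With this change, CommonRoutingTableImpl.shutdown() fans out to every registered provider, so each provider impl is expected to expose an async shutdown() that releases its resources. A minimal illustrative provider (hypothetical, not from this repository):

import httpx


class ExampleProvider:
    """Hypothetical provider holding a connection that must be cleaned up."""

    def __init__(self, base_url: str):
        self.client = httpx.AsyncClient(base_url=base_url)

    async def shutdown(self) -> None:
        # Invoked by the routing table's shutdown() fan-out above.
        await self.client.aclose()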


@@ -5,6 +5,7 @@
 # the root directory of this source tree.
 
 import asyncio
+import functools
 import inspect
 import json
 import signal
@@ -169,11 +170,20 @@ async def passthrough(
     await end_trace(SpanStatus.OK if not erred else SpanStatus.ERROR)
 
 
-def handle_sigint(*args, **kwargs):
+def handle_sigint(app, *args, **kwargs):
     print("SIGINT or CTRL-C detected. Exiting gracefully...")
+
+    async def run_shutdown():
+        for impl in app.__llama_stack_impls__.values():
+            print(f"Shutting down {impl}")
+            await impl.shutdown()
+
+    asyncio.run(run_shutdown())
+
     loop = asyncio.get_event_loop()
     for task in asyncio.all_tasks(loop):
         task.cancel()
     loop.stop()
@@ -181,7 +191,10 @@ def handle_sigint(*args, **kwargs):
 async def lifespan(app: FastAPI):
     print("Starting up")
     yield
+
     print("Shutting down")
+    for impl in app.__llama_stack_impls__.values():
+        await impl.shutdown()
 
 
 def create_dynamic_passthrough(
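
For context, FastAPI only runs such a startup/shutdown generator when it is passed as the app's lifespan; a minimal sketch of that wiring (the decorator and constructor call are assumed here, they are not visible in this hunk):

from contextlib import asynccontextmanager

from fastapi import FastAPI


@asynccontextmanager
async def lifespan(app: FastAPI):
    print("Starting up")
    yield  # the application serves requests while suspended here
    print("Shutting down")


app = FastAPI(lifespan=lifespan)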
@@ -333,7 +346,9 @@ def main(
     print("")
     app.exception_handler(RequestValidationError)(global_exception_handler)
     app.exception_handler(Exception)(global_exception_handler)
-    signal.signal(signal.SIGINT, handle_sigint)
+    signal.signal(signal.SIGINT, functools.partial(handle_sigint, app))
+
+    app.__llama_stack_impls__ = impls
 
     import uvicorn
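
The main() hunk binds the app into the handler with functools.partial because signal.signal passes only the signal number and stack frame to the callback. The same pattern in isolation, with a hypothetical cleanup payload:

import functools
import signal


def handle_signal(state: dict, signum, frame):
    # `state` is pre-bound via functools.partial; signal.signal itself only
    # supplies the signal number and the current stack frame.
    print(f"Received signal {signum}; cleaning up {len(state)} resources")


state = {"connections": object()}  # placeholder resources
signal.signal(signal.SIGINT, functools.partial(handle_signal, state))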


@@ -13,8 +13,6 @@ from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import StopReason
 from llama_models.llama3.api.tokenizer import Tokenizer
 
-from llama_stack.distribution.datatypes import RoutableProvider
-
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.providers.utils.inference.augment_messages import (
     augment_messages_for_tools,
@@ -25,7 +23,7 @@ from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImpl
 logger = logging.getLogger(__name__)
 
 
-class _HfAdapter(Inference, RoutableProvider):
+class _HfAdapter(Inference):
     client: AsyncInferenceClient
     max_tokens: int
     model_id: str
@@ -34,11 +32,17 @@ class _HfAdapter(Inference, RoutableProvider):
         self.tokenizer = Tokenizer.get_instance()
         self.formatter = ChatFormat(self.tokenizer)
 
-    async def validate_routing_keys(self, routing_keys: list[str]) -> None:
-        # these are the model names the Llama Stack will use to route requests to this provider
-        # perform validation here if necessary
+    # TODO: make this work properly by checking this against the model_id being
+    # served by the remote endpoint
+    async def register_model(self, model: ModelDef) -> None:
         pass
 
+    async def list_models(self) -> List[ModelDef]:
+        return []
+
+    async def get_model(self, identifier: str) -> Optional[ModelDef]:
+        return None
+
     async def shutdown(self) -> None:
         pass
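
The TGI adapter now stubs the ModelDef-based registration interface (register_model / list_models / get_model) in place of validate_routing_keys. For illustration, a self-contained in-memory version of that shape; the ModelDef below is a stand-in, not the real llama_stack type:

from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class ModelDef:  # stand-in for llama_stack's ModelDef; real fields may differ
    identifier: str


class InMemoryModelRegistry:
    """Illustrative registry implementing the same three async methods."""

    def __init__(self) -> None:
        self._models: Dict[str, ModelDef] = {}

    async def register_model(self, model: ModelDef) -> None:
        self._models[model.identifier] = model

    async def list_models(self) -> List[ModelDef]:
        return list(self._models.values())

    async def get_model(self, identifier: str) -> Optional[ModelDef]:
        return self._models.get(identifier)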


@@ -42,7 +42,7 @@ from llama_stack.apis.inference.inference import (
 from llama_stack.providers.utils.inference.augment_messages import (
     augment_messages_for_tools,
 )
-from llama_stack.providers.utils.inference.routable import RoutableProviderForModels
+from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 
 from .config import VLLMConfig
@@ -75,7 +75,7 @@ def _vllm_sampling_params(sampling_params: Any) -> SamplingParams:
     return SamplingParams().from_optional(**kwargs)
 
 
-class VLLMInferenceImpl(Inference, RoutableProviderForModels):
+class VLLMInferenceImpl(Inference, ModelRegistryHelper):
     """Inference implementation for vLLM."""
 
     HF_MODEL_MAPPINGS = {
@@ -109,7 +109,7 @@ class VLLMInferenceImpl(Inference, RoutableProviderForModels):
     def __init__(self, config: VLLMConfig):
         Inference.__init__(self)
-        RoutableProviderForModels.__init__(
+        ModelRegistryHelper.__init__(
             self,
             stack_to_provider_models_map=self.HF_MODEL_MAPPINGS,
         )
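
ModelRegistryHelper replaces RoutableProviderForModels and is constructed with stack_to_provider_models_map, as shown above. Its methods are not visible in this diff; the sketch below only illustrates the kind of lookup such a helper performs, with a hypothetical method name and placeholder mapping:

from typing import Dict


class IllustrativeModelRegistryHelper:
    """Sketch only: the real ModelRegistryHelper's method names may differ."""

    def __init__(self, stack_to_provider_models_map: Dict[str, str]):
        self.stack_to_provider_models_map = stack_to_provider_models_map

    def map_to_provider_model(self, identifier: str) -> str:  # hypothetical name
        if identifier not in self.stack_to_provider_models_map:
            raise ValueError(f"Unknown model identifier: {identifier}")
        return self.stack_to_provider_models_map[identifier]


# Placeholder mapping from stack model identifiers to provider-side model names.
helper = IllustrativeModelRegistryHelper({"stack-model-name": "org/provider-model-name"})
provider_model = helper.map_to_provider_model("stack-model-name")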