Mirror of https://github.com/meta-llama/llama-stack.git
Synced 2025-07-29 15:23:51 +00:00
more fixes, plug shutdown handlers
Still, FastAPI's SIGINT handler is not calling ours.
This commit is contained in:
parent 60dead6196
commit e45a417543

4 changed files with 32 additions and 12 deletions
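A note on the commit message: uvicorn, which serves the FastAPI app here, installs its own SIGINT/SIGTERM handlers when the server starts, and those can displace a handler registered earlier with signal.signal. That is one plausible reason the custom handler below never fires; the commit itself does not confirm the cause. A minimal sketch of a workaround under that assumption, using uvicorn's Server/Config classes; NoSignalServer and serve_app are illustrative names, not part of llama-stack:

import signal

import uvicorn


class NoSignalServer(uvicorn.Server):
    def install_signal_handlers(self) -> None:
        # Skip uvicorn's handlers so an application-level handler stays in place.
        pass


def serve_app(app, sigint_handler) -> None:
    signal.signal(signal.SIGINT, sigint_handler)  # survives server startup
    config = uvicorn.Config(app, host="0.0.0.0", port=5000)
    NoSignalServer(config).run()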
@@ -42,7 +42,8 @@ class CommonRoutingTableImpl(RoutingTable):
             await self.register_object(obj, p)

     async def shutdown(self) -> None:
-        pass
+        for p in self.impls_by_provider_id.values():
+            await p.shutdown()

     def get_provider_impl(self, routing_key: str) -> Any:
         if routing_key not in self.routing_key_to_object:
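The hunk above replaces a no-op shutdown with a fan-out over every registered provider. A small self-contained sketch of the same pattern, with Provider and the dict of impls as simplified stand-ins for the real llama-stack types:

import asyncio
from typing import Dict


class Provider:
    """Simplified stand-in for a provider implementation."""

    def __init__(self, name: str) -> None:
        self.name = name

    async def shutdown(self) -> None:
        print(f"{self.name}: shut down")


class RoutingTableSketch:
    def __init__(self, impls_by_provider_id: Dict[str, Provider]) -> None:
        self.impls_by_provider_id = impls_by_provider_id

    async def shutdown(self) -> None:
        # Sequentially await each provider's shutdown, as in the diff above.
        for p in self.impls_by_provider_id.values():
            await p.shutdown()


asyncio.run(RoutingTableSketch({"tgi": Provider("tgi")}).shutdown())

Awaiting sequentially keeps ordering deterministic; asyncio.gather would shut providers down concurrently, but one failure then interrupts the rest unless return_exceptions=True is passed.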
@@ -5,6 +5,7 @@
 # the root directory of this source tree.

 import asyncio
+import functools
 import inspect
 import json
 import signal
@@ -169,11 +170,20 @@ async def passthrough(
         await end_trace(SpanStatus.OK if not erred else SpanStatus.ERROR)


-def handle_sigint(*args, **kwargs):
+def handle_sigint(app, *args, **kwargs):
     print("SIGINT or CTRL-C detected. Exiting gracefully...")

+    async def run_shutdown():
+        for impl in app.__llama_stack_impls__.values():
+            print(f"Shutting down {impl}")
+            await impl.shutdown()
+
+    asyncio.run(run_shutdown())
+
     loop = asyncio.get_event_loop()
     for task in asyncio.all_tasks(loop):
         task.cancel()

     loop.stop()
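A hedged caveat on the new handler: signal.signal handlers run on the main thread, and if uvicorn's event loop is running there when the handler fires, asyncio.run() refuses to start because a loop is already running, so the nested asyncio.run(run_shutdown()) would likely raise RuntimeError. A sketch of an alternative that schedules the cleanup on whichever loop exists; shutdown_all and handle_sigint_alt are illustrative names, not what the commit does:

import asyncio


async def shutdown_all(app) -> None:
    for impl in app.__llama_stack_impls__.values():
        print(f"Shutting down {impl}")
        await impl.shutdown()


def handle_sigint_alt(app, *args, **kwargs):
    print("SIGINT or CTRL-C detected. Exiting gracefully...")
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        # No loop is running here: drive the coroutine to completion directly.
        asyncio.run(shutdown_all(app))
    else:
        # A loop is running: schedule the cleanup instead of nesting asyncio.run().
        loop.create_task(shutdown_all(app))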
@@ -181,7 +191,10 @@ def handle_sigint(*args, **kwargs):
 async def lifespan(app: FastAPI):
     print("Starting up")
     yield
+
     print("Shutting down")
+    for impl in app.__llama_stack_impls__.values():
+        await impl.shutdown()


 def create_dynamic_passthrough(
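For context on the lifespan hunk: the generator only takes effect if it is registered on the app, and FastAPI expects it as an async context manager. A minimal sketch of that wiring, assuming the @asynccontextmanager decorator that is not visible in this hunk:

from contextlib import asynccontextmanager

from fastapi import FastAPI


@asynccontextmanager
async def lifespan(app: FastAPI):
    print("Starting up")
    yield
    # Runs when the server shuts down cleanly.
    print("Shutting down")
    for impl in getattr(app, "__llama_stack_impls__", {}).values():
        await impl.shutdown()


app = FastAPI(lifespan=lifespan)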
@@ -333,7 +346,9 @@ def main(
     print("")
     app.exception_handler(RequestValidationError)(global_exception_handler)
     app.exception_handler(Exception)(global_exception_handler)
-    signal.signal(signal.SIGINT, handle_sigint)
+    signal.signal(signal.SIGINT, functools.partial(handle_sigint, app))
+
+    app.__llama_stack_impls__ = impls

     import uvicorn
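The functools.partial call is what lets the app-aware handler satisfy the (signum, frame) signature that signal.signal expects: app is bound up front and the OS supplies the remaining arguments at call time. Stashing impls on the app object is what makes them reachable from both the signal handler and the lifespan hook. A tiny standalone illustration (DummyApp is just a placeholder):

import functools
import signal


def handle_sigint(app, *args, **kwargs):
    # args will be (signum, frame) when the OS delivers the signal.
    print(f"Signal {args[0]} received; app is {app!r}")


class DummyApp:
    pass


signal.signal(signal.SIGINT, functools.partial(handle_sigint, DummyApp()))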
@@ -13,8 +13,6 @@ from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import StopReason
 from llama_models.llama3.api.tokenizer import Tokenizer

-from llama_stack.distribution.datatypes import RoutableProvider
-
 from llama_stack.apis.inference import * # noqa: F403
 from llama_stack.providers.utils.inference.augment_messages import (
     augment_messages_for_tools,
@@ -25,7 +23,7 @@ from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImpl
 logger = logging.getLogger(__name__)


-class _HfAdapter(Inference, RoutableProvider):
+class _HfAdapter(Inference):
     client: AsyncInferenceClient
     max_tokens: int
     model_id: str
|
@ -34,11 +32,17 @@ class _HfAdapter(Inference, RoutableProvider):
|
||||||
self.tokenizer = Tokenizer.get_instance()
|
self.tokenizer = Tokenizer.get_instance()
|
||||||
self.formatter = ChatFormat(self.tokenizer)
|
self.formatter = ChatFormat(self.tokenizer)
|
||||||
|
|
||||||
async def validate_routing_keys(self, routing_keys: list[str]) -> None:
|
# TODO: make this work properly by checking this against the model_id being
|
||||||
# these are the model names the Llama Stack will use to route requests to this provider
|
# served by the remote endpoint
|
||||||
# perform validation here if necessary
|
async def register_model(self, model: ModelDef) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
async def list_models(self) -> List[ModelDef]:
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def get_model(self, identifier: str) -> Optional[ModelDef]:
|
||||||
|
return None
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
|
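The adapter now exposes stubbed register_model/list_models/get_model hooks in place of the old validate_routing_keys. A hedged sketch of what a non-stub registry could look like, with ModelDef reduced to a dataclass stand-in rather than the real llama-stack type:

from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class ModelDef:
    """Stand-in for the real llama-stack ModelDef."""
    identifier: str
    llama_model: str


class InMemoryModelRegistry:
    def __init__(self) -> None:
        self._models: Dict[str, ModelDef] = {}

    async def register_model(self, model: ModelDef) -> None:
        # A real adapter would also check `model` against the model_id served
        # by the remote endpoint, as the TODO in the hunk above notes.
        self._models[model.identifier] = model

    async def list_models(self) -> List[ModelDef]:
        return list(self._models.values())

    async def get_model(self, identifier: str) -> Optional[ModelDef]:
        return self._models.get(identifier)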
@@ -42,7 +42,7 @@ from llama_stack.apis.inference.inference import (
 from llama_stack.providers.utils.inference.augment_messages import (
     augment_messages_for_tools,
 )
-from llama_stack.providers.utils.inference.routable import RoutableProviderForModels
+from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper

 from .config import VLLMConfig
@@ -75,7 +75,7 @@ def _vllm_sampling_params(sampling_params: Any) -> SamplingParams:
     return SamplingParams().from_optional(**kwargs)


-class VLLMInferenceImpl(Inference, RoutableProviderForModels):
+class VLLMInferenceImpl(Inference, ModelRegistryHelper):
     """Inference implementation for vLLM."""

     HF_MODEL_MAPPINGS = {
@@ -109,7 +109,7 @@ class VLLMInferenceImpl(Inference, RoutableProviderForModels):

     def __init__(self, config: VLLMConfig):
         Inference.__init__(self)
-        RoutableProviderForModels.__init__(
+        ModelRegistryHelper.__init__(
             self,
             stack_to_provider_models_map=self.HF_MODEL_MAPPINGS,
         )
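ModelRegistryHelper is used as a mixin initialized with stack_to_provider_models_map, the same keyword the old RoutableProviderForModels took. A rough sketch of the mapping behavior such a helper plausibly provides; this is an assumption about its semantics, not the real implementation in llama_stack.providers.utils.inference.model_registry:

from typing import Dict


class ModelRegistryHelperSketch:
    """Illustrative only: maps Llama Stack model names to provider model names."""

    def __init__(self, stack_to_provider_models_map: Dict[str, str]) -> None:
        self.stack_to_provider_models_map = stack_to_provider_models_map

    def map_to_provider_model(self, identifier: str) -> str:
        if identifier not in self.stack_to_provider_models_map:
            raise ValueError(f"Unknown model: {identifier}")
        return self.stack_to_provider_models_map[identifier]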