from importlib.metadata import version as parse_version


class ClientVersionMiddleware:
    """ASGI middleware that rejects HTTP clients on a different major.minor version.

    For every ``http`` scope the ``x-llamastack-client-version`` request
    header is compared with the installed ``llama-stack`` distribution
    version. Only the first two dot-separated components are compared. On a
    mismatch the request is answered with ``426`` and a JSON error body and
    the wrapped app is never called.

    Requests pass through untouched when the scope is not ``http``, when the
    header is missing or empty, or when either version cannot be parsed.

    ``main()`` registers this middleware with
    ``app.add_middleware(ClientVersionMiddleware)`` right after
    ``TracingMiddleware``.
    """

    def __init__(self, app):
        """Wrap *app* and record the installed ``llama-stack`` version.

        Args:
            app: The next ASGI application in the chain.
        """
        self.app = app
        self.server_version = parse_version("llama-stack")

    @staticmethod
    def _major_minor(version):
        """Return the first two dot-separated parts of *version* as ints.

        Raises:
            ValueError: If one of those parts is not an integer literal.
        """
        return tuple(map(int, version.split(".")[:2]))

    async def _send_version_error(self, send, client_version):
        """Send the 426 response describing the version mismatch."""
        await send(
            {
                "type": "http.response.start",
                "status": 426,
                "headers": [[b"content-type", b"application/json"]],
            }
        )
        error_msg = json.dumps(
            {
                "error": {
                    "message": (
                        f"Client version {client_version} is not compatible "
                        f"with server version {self.server_version}. "
                        "Please upgrade your client."
                    )
                }
            }
        ).encode()
        await send({"type": "http.response.body", "body": error_msg})

    async def __call__(self, scope, receive, send):
        """Handle one ASGI connection, short-circuiting on a version mismatch.

        Args:
            scope: ASGI connection scope; only ``type`` and ``headers`` are read.
            receive: ASGI receive callable, forwarded to the wrapped app.
            send: ASGI send callable, used directly for the 426 response.

        Returns:
            Whatever the wrapped app returns, or ``None`` after a 426 response.
        """
        if scope["type"] == "http":
            headers = dict(scope.get("headers", []))
            raw_version = headers.get(b"x-llamastack-client-version", b"")
            if raw_version:
                try:
                    # Decode inside the guarded section: header bytes come
                    # from the client and need not be valid UTF-8.
                    # UnicodeDecodeError is a ValueError subclass, so a
                    # garbled header is treated like any other unparseable
                    # version instead of failing the request.
                    client_version = raw_version.decode()
                    mismatch = self._major_minor(
                        client_version
                    ) != self._major_minor(self.server_version)
                except ValueError:
                    # If version parsing fails, let the request through.
                    mismatch = False
                if mismatch:
                    return await self._send_version_error(send, client_version)

        return await self.app(scope, receive, send)