From 4938f2fe5da7ecd9fe7a5f51b7d95868ca149b99 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Thu, 9 Jan 2025 14:52:06 -0800 Subject: [PATCH] Check version incompatibility (#738) When we bump up `major.minor` we want to make sure clients can immediately detect a version change and appropriately error out. It is not reasonable to keep checking for API-level backwards compatibility across such version bumps. Over time, we will make the check based only on the major version perhaps. ### Test Plan Manually updated `__version__` in the client SDK to be "0.1.0" which is incompatible with server's current version "0.0.63", got the following error: image Without this update, the CLI worked correctly. --- llama_stack/distribution/server/server.py | 49 +++++++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index 8c1e41dc0..1108d1049 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -16,6 +16,8 @@ import traceback import warnings from contextlib import asynccontextmanager + +from importlib.metadata import version as parse_version from pathlib import Path from typing import Any, Union @@ -228,6 +230,52 @@ class TracingMiddleware: await end_trace() +class ClientVersionMiddleware: + def __init__(self, app): + self.app = app + self.server_version = parse_version("llama-stack") + + async def __call__(self, scope, receive, send): + if scope["type"] == "http": + headers = dict(scope.get("headers", [])) + client_version = headers.get(b"x-llamastack-client-version", b"").decode() + if client_version: + try: + client_version_parts = tuple( + map(int, client_version.split(".")[:2]) + ) + server_version_parts = tuple( + map(int, self.server_version.split(".")[:2]) + ) + if client_version_parts != server_version_parts: + + async def send_version_error(send): + await send( + { + "type": "http.response.start", + "status": 426, + "headers": [[b"content-type", b"application/json"]], + } + ) + error_msg = json.dumps( + { + "error": { + "message": f"Client version {client_version} is not compatible with server version {self.server_version}. Please upgrade your client." + } + } + ).encode() + await send( + {"type": "http.response.body", "body": error_msg} + ) + + return await send_version_error(send) + except (ValueError, IndexError): + # If version parsing fails, let the request through + pass + + return await self.app(scope, receive, send) + + def main(): """Start the LlamaStack server.""" parser = argparse.ArgumentParser(description="Start the LlamaStack server.") @@ -291,6 +339,7 @@ def main(): app = FastAPI(lifespan=lifespan) app.add_middleware(TracingMiddleware) + app.add_middleware(ClientVersionMiddleware) try: impls = asyncio.run(construct_stack(config))