Check version incompatibility (#738)

When we bump up `major.minor` we want to make sure clients can
immediately detect a version change and appropriately error out. It is
not reasonable to keep checking for API-level backwards compatibility
across such version bumps. Over time, we will make the check based only
on the major version perhaps.

### Test Plan

Manually updated `__version__` in the client SDK to be "0.1.0" which is
incompatible with server's current version "0.0.63", got the following
error:

<img width="1077" alt="image"
src="https://github.com/user-attachments/assets/06ae4659-0a25-4c4c-a999-ce44678d4e6f"
/>

Without this update, the CLI worked correctly.
This commit is contained in:
Ashwin Bharambe 2025-01-09 14:52:06 -08:00 committed by GitHub
parent ffc6bd4805
commit 4938f2fe5d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -16,6 +16,8 @@ import traceback
import warnings
from contextlib import asynccontextmanager
from importlib.metadata import version as parse_version
from pathlib import Path
from typing import Any, Union
@ -228,6 +230,52 @@ class TracingMiddleware:
await end_trace()
class ClientVersionMiddleware:
def __init__(self, app):
self.app = app
self.server_version = parse_version("llama-stack")
async def __call__(self, scope, receive, send):
if scope["type"] == "http":
headers = dict(scope.get("headers", []))
client_version = headers.get(b"x-llamastack-client-version", b"").decode()
if client_version:
try:
client_version_parts = tuple(
map(int, client_version.split(".")[:2])
)
server_version_parts = tuple(
map(int, self.server_version.split(".")[:2])
)
if client_version_parts != server_version_parts:
async def send_version_error(send):
await send(
{
"type": "http.response.start",
"status": 426,
"headers": [[b"content-type", b"application/json"]],
}
)
error_msg = json.dumps(
{
"error": {
"message": f"Client version {client_version} is not compatible with server version {self.server_version}. Please upgrade your client."
}
}
).encode()
await send(
{"type": "http.response.body", "body": error_msg}
)
return await send_version_error(send)
except (ValueError, IndexError):
# If version parsing fails, let the request through
pass
return await self.app(scope, receive, send)
def main():
"""Start the LlamaStack server."""
parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
@ -291,6 +339,7 @@ def main():
app = FastAPI(lifespan=lifespan)
app.add_middleware(TracingMiddleware)
app.add_middleware(ClientVersionMiddleware)
try:
impls = asyncio.run(construct_stack(config))