Check version incompatibility

This commit is contained in:
Ashwin Bharambe 2025-01-09 12:18:37 -08:00
parent ffc6bd4805
commit e09d5efe87

View file

@ -16,6 +16,8 @@ import traceback
import warnings
from contextlib import asynccontextmanager
from importlib.metadata import version as parse_version
from pathlib import Path
from typing import Any, Union
@ -228,6 +230,52 @@ class TracingMiddleware:
await end_trace()
class ClientVersionMiddleware:
def __init__(self, app):
self.app = app
self.server_version = parse_version("llama-stack")
async def __call__(self, scope, receive, send):
if scope["type"] == "http":
headers = dict(scope.get("headers", []))
client_version = headers.get(b"x-llamastack-client-version", b"").decode()
if client_version:
try:
client_version_parts = tuple(
map(int, client_version.split(".")[:2])
)
server_version_parts = tuple(
map(int, self.server_version.split(".")[:2])
)
if client_version_parts != server_version_parts:
async def send_version_error(send):
await send(
{
"type": "http.response.start",
"status": 426,
"headers": [[b"content-type", b"application/json"]],
}
)
error_msg = json.dumps(
{
"error": {
"message": f"Client version {client_version} is not compatible with server version {self.server_version}. Please upgrade your client."
}
}
).encode()
await send(
{"type": "http.response.body", "body": error_msg}
)
return await send_version_error(send)
except (ValueError, IndexError):
# If version parsing fails, let the request through
pass
return await self.app(scope, receive, send)
def main():
"""Start the LlamaStack server."""
parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
@ -291,6 +339,7 @@ def main():
app = FastAPI(lifespan=lifespan)
app.add_middleware(TracingMiddleware)
app.add_middleware(ClientVersionMiddleware)
try:
impls = asyncio.run(construct_stack(config))