mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
Check version incompatibility (#738)
When we bump up `major.minor` we want to make sure clients can immediately detect a version change and appropriately error out. It is not reasonable to keep checking for API-level backwards compatibility across such version bumps. Over time, we will make the check based only on the major version perhaps. ### Test Plan Manually updated `__version__` in the client SDK to be "0.1.0" which is incompatible with server's current version "0.0.63", got the following error: <img width="1077" alt="image" src="https://github.com/user-attachments/assets/06ae4659-0a25-4c4c-a999-ce44678d4e6f" /> Without this update, the CLI worked correctly.
This commit is contained in:
parent
ffc6bd4805
commit
4938f2fe5d
1 changed files with 49 additions and 0 deletions
|
@ -16,6 +16,8 @@ import traceback
|
|||
import warnings
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from importlib.metadata import version as parse_version
|
||||
from pathlib import Path
|
||||
from typing import Any, Union
|
||||
|
||||
|
@ -228,6 +230,52 @@ class TracingMiddleware:
|
|||
await end_trace()
|
||||
|
||||
|
||||
class ClientVersionMiddleware:
|
||||
def __init__(self, app):
|
||||
self.app = app
|
||||
self.server_version = parse_version("llama-stack")
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
if scope["type"] == "http":
|
||||
headers = dict(scope.get("headers", []))
|
||||
client_version = headers.get(b"x-llamastack-client-version", b"").decode()
|
||||
if client_version:
|
||||
try:
|
||||
client_version_parts = tuple(
|
||||
map(int, client_version.split(".")[:2])
|
||||
)
|
||||
server_version_parts = tuple(
|
||||
map(int, self.server_version.split(".")[:2])
|
||||
)
|
||||
if client_version_parts != server_version_parts:
|
||||
|
||||
async def send_version_error(send):
|
||||
await send(
|
||||
{
|
||||
"type": "http.response.start",
|
||||
"status": 426,
|
||||
"headers": [[b"content-type", b"application/json"]],
|
||||
}
|
||||
)
|
||||
error_msg = json.dumps(
|
||||
{
|
||||
"error": {
|
||||
"message": f"Client version {client_version} is not compatible with server version {self.server_version}. Please upgrade your client."
|
||||
}
|
||||
}
|
||||
).encode()
|
||||
await send(
|
||||
{"type": "http.response.body", "body": error_msg}
|
||||
)
|
||||
|
||||
return await send_version_error(send)
|
||||
except (ValueError, IndexError):
|
||||
# If version parsing fails, let the request through
|
||||
pass
|
||||
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
|
||||
def main():
|
||||
"""Start the LlamaStack server."""
|
||||
parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
|
||||
|
@ -291,6 +339,7 @@ def main():
|
|||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
app.add_middleware(TracingMiddleware)
|
||||
app.add_middleware(ClientVersionMiddleware)
|
||||
|
||||
try:
|
||||
impls = asyncio.run(construct_stack(config))
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue