feat: implement provider updating

add `v1/providers/` which uses PUT to allow users to change their provider configuration

this is a follow-up to #1429 and is related to #1359

a user can call something like:

`llama_stack_client.providers.update(api="inference", provider_id="ollama", provider_type="remote::ollama", config={'url': 'http://localhost:12345'})`

or

`llama-stack-client providers update inference ollama remote::ollama "{'url': 'http://localhost:12345'}"`

this API works by adding a `RequestMiddleware` to the server which checks requests, and if the user is using PUT /v1/providers, the routes are re-registered with the re-initialized provider configurations/methods

for the client, `self.impls` is updated to hold the proper methods+configurations

this depends on a client PR; CI will fail until that PR lands, but the tests succeeded locally

Signed-off-by: Charlie Doern <cdoern@redhat.com>
This commit is contained in:
Charlie Doern 2025-04-04 15:53:21 -04:00
parent d165000bbc
commit 436f8ade9e
8 changed files with 449 additions and 56 deletions

View file

@ -27,8 +27,10 @@ from fastapi import Body, FastAPI, HTTPException, Request
from fastapi import Path as FastapiPath
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.routing import APIRoute
from openai import BadRequestError
from pydantic import BaseModel, ValidationError
from starlette.types import Message
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.distribution.datatypes import AuthenticationRequiredError, LoggingConfig, StackRunConfig
@ -269,6 +271,84 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
return route_handler
class RequestMiddleware:
    """ASGI middleware that rebuilds the stack after a provider update.

    Every request is forwarded to the wrapped application. For ``PUT``
    requests whose path contains ``/v1/providers`` the outgoing response is
    additionally buffered; once the final body chunk of a *successful*
    (2xx) response has been seen, the response JSON is parsed into a
    ``ProviderInfo``, the matching provider entry in ``stack_run_config`` is
    given the new config, and the stack is rebuilt via ``construct`` with
    ``reconstruct=True``.

    Args:
        app: The next ASGI application in the chain.
        api: The application object handed to ``construct`` and on which
            ``__llama_stack_impls__`` is refreshed.
        stack_run_config: The run configuration whose ``providers`` mapping
            is mutated in place when a provider is updated.
    """

    def __init__(self, app, api, stack_run_config):
        self.app = app
        self.api = api
        self.stack_run_config = stack_run_config

    async def __call__(self, scope, receive, send):
        # Non-HTTP scopes (lifespan, websocket) are never intercepted, so hand
        # them over before doing any other work.
        if scope["type"] != "http":
            return await self.app(scope, receive, send)

        from fastapi import Request

        request = Request(scope, receive)

        # NOTE: this is a substring match on the path, so it intercepts any
        # PUT whose path contains "/v1/providers" (including sub-paths).
        if request.method != "PUT" or "/v1/providers" not in request.url.path:
            return await self.app(scope, receive, send)

        # Reading the body consumes the ASGI receive channel, so the bytes are
        # replayed to the downstream app through a substitute receive callable.
        body = await request.body()
        request = Request(scope, receive_from_body(body))

        response_chunks = bytearray()
        status_code = 500

        async def send_wrapper(message: Message):
            nonlocal status_code
            if message["type"] == "http.response.start":
                status_code = message["status"]
            elif message["type"] == "http.response.body":
                response_chunks.extend(message.get("body", b""))
                is_last_chunk = not message.get("more_body", False)
                # Only a successful update should trigger a rebuild; an error
                # response carries an error document, not provider data.
                if is_last_chunk and 200 <= status_code < 300:
                    await self._rebuild_stack(bytes(response_chunks))
            # The rebuild happens before the final chunk is released, so the
            # client only sees the response once the new stack is in place.
            await send(message)

        return await self.app(scope, request.receive, send_wrapper)

    async def _rebuild_stack(self, response_body: bytes) -> None:
        """Apply the provider described by ``response_body`` and rebuild the stack.

        Failures are logged and swallowed so that the already-produced HTTP
        response can still be delivered to the client.
        """
        import json

        from llama_stack.apis.providers import ProviderInfo

        try:
            # The *response* payload is parsed here (not the request body).
            # NOTE(review): presumably the update handler returns the updated
            # provider in ProviderInfo shape -- confirm against the handler.
            payload = json.loads(response_body.decode("utf-8"))
            new_provider = ProviderInfo(**payload)

            for api, providers in self.stack_run_config.providers.items():
                if api != new_provider.api:
                    continue
                for prov in providers:
                    if prov.provider_id == new_provider.provider_id:
                        prov.config = new_provider.config
                        break

            _, impls = await construct(app=self.api, config=self.stack_run_config, reconstruct=True)
            self.api.__llama_stack_impls__ = impls
            logger.info("✅ Stack rebuilt and updated.")
        except Exception as e:
            # Best effort: keep serving with the previous implementations.
            logger.exception(f"⚠️ Failed to rebuild stack: {e}")
def receive_from_body(body: bytes):
    """Build an ASGI ``receive`` callable that replays an already-read body.

    Each call to the returned coroutine function yields a fresh
    ``http.request`` message carrying the full ``body`` with
    ``more_body`` set to ``False``.
    """

    async def _replay() -> Message:
        return dict(type="http.request", body=body, more_body=False)

    return _replay
class TracingMiddleware:
def __init__(self, app, impls):
self.app = app
@ -482,67 +562,15 @@ def main(args: argparse.Namespace | None = None):
window_seconds=window_seconds,
)
try:
impls = asyncio.run(construct_stack(config))
except InvalidProviderError as e:
logger.error(f"Error: {str(e)}")
sys.exit(1)
if Api.telemetry in impls:
setup_logger(impls[Api.telemetry])
else:
setup_logger(TelemetryAdapter(TelemetryConfig(), {}))
all_routes = get_all_api_routes()
if config.apis:
apis_to_serve = set(config.apis)
else:
apis_to_serve = set(impls.keys())
for inf in builtin_automatically_routed_apis():
# if we do not serve the corresponding router API, we should not serve the routing table API
if inf.router_api.value not in apis_to_serve:
continue
apis_to_serve.add(inf.routing_table_api.value)
apis_to_serve.add("inspect")
apis_to_serve.add("providers")
for api_str in apis_to_serve:
api = Api(api_str)
routes = all_routes[api]
impl = impls[api]
for route in routes:
if not hasattr(impl, route.name):
# ideally this should be a typing violation already
raise ValueError(f"Could not find method {route.name} on {impl}!")
impl_method = getattr(impl, route.name)
# Filter out HEAD method since it's automatically handled by FastAPI for GET routes
available_methods = [m for m in route.methods if m != "HEAD"]
if not available_methods:
raise ValueError(f"No methods found for {route.name} on {impl}")
method = available_methods[0]
logger.debug(f"{method} {route.path}")
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=UserWarning, module="pydantic._internal._fields")
getattr(app, method.lower())(route.path, response_model=None)(
create_dynamic_typed_route(
impl_method,
method.lower(),
route.path,
)
)
apis_to_serve, impls = asyncio.run(construct(app=app, config=config))
logger.debug(f"serving APIs: {apis_to_serve}")
app.exception_handler(RequestValidationError)(global_exception_handler)
app.exception_handler(Exception)(global_exception_handler)
app.__llama_stack_impls__ = impls
# Add the custom middleware
app.add_middleware(RequestMiddleware, api=app, stack_run_config=config)
app.add_middleware(TracingMiddleware, impls=impls)
import uvicorn
@ -592,5 +620,81 @@ def extract_path_params(route: str) -> list[str]:
return params
async def construct(
    app: FastAPI, config: StackRunConfig, reconstruct: bool = False
) -> tuple[set[str] | set[Api], dict[Api, Any]]:
    """Build the provider implementations and register their routes on ``app``.

    Args:
        app: Application on which the API routes are registered.
        config: Run configuration passed to ``construct_stack``.
        reconstruct: When ``True`` the call is treated as a rebuild of an
            already-configured app: routes previously registered for the
            same path and method are removed before the new handlers are
            added, and an invalid provider raises instead of exiting.

    Returns:
        A ``(apis_to_serve, impls)`` tuple: the set of APIs that were
        registered and the mapping of API to implementation.

    Raises:
        InvalidProviderError: Only when ``reconstruct`` is ``True``; on the
            initial build the process exits with status 1 instead.
        ValueError: If an implementation lacks a method for one of its
            routes, or a route declares no usable HTTP method.
    """
    try:
        impls = await construct_stack(config)
    except InvalidProviderError as e:
        logger.error(f"Error: {str(e)}")
        if reconstruct:
            # A rebuild runs inside a live server; exiting here would take the
            # whole process down, so let the caller decide how to handle it.
            raise
        sys.exit(1)

    if Api.telemetry in impls:
        setup_logger(impls[Api.telemetry])
    else:
        setup_logger(TelemetryAdapter(TelemetryConfig(), {}))

    all_routes = get_all_api_routes()

    if config.apis:
        apis_to_serve = set(config.apis)
    else:
        apis_to_serve = set(impls.keys())

    for inf in builtin_automatically_routed_apis():
        # if we do not serve the corresponding router API, we should not serve the routing table API
        if inf.router_api.value not in apis_to_serve:
            continue
        apis_to_serve.add(inf.routing_table_api.value)

    apis_to_serve.add("inspect")
    apis_to_serve.add("providers")

    # Resolve and validate every (method, path, handler) binding before the
    # route table is touched, so a failure cannot leave it half-updated.
    bindings = []
    for api_str in apis_to_serve:
        api = Api(api_str)
        routes = all_routes[api]
        impl = impls[api]

        for route in routes:
            if not hasattr(impl, route.name):
                # ideally this should be a typing violation already
                raise ValueError(f"Could not find method {route.name} on {impl}!!")

            impl_method = getattr(impl, route.name)
            # Filter out HEAD method since it's automatically handled by FastAPI for GET routes
            available_methods = [m for m in route.methods if m != "HEAD"]
            if not available_methods:
                raise ValueError(f"No methods found for {route.name} on {impl}")
            bindings.append((available_methods[0], route.path, impl_method))

    if reconstruct:
        # Registered routes keep their HTTP methods upper-cased, so the
        # comparison must be done in upper case or stale routes survive and
        # keep shadowing the new handlers. Routes lacking a path or methods
        # attribute (e.g. mounts) are left alone.
        stale = {(path, method.upper()) for method, path, _ in bindings}
        app.router.routes = [
            r
            for r in app.router.routes
            if not any(
                (getattr(r, "path", None), m) in stale for m in (getattr(r, "methods", None) or ())
            )
        ]

    for method, path, impl_method in bindings:
        logger.debug(f"{method} {path}")

        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=UserWarning, module="pydantic._internal._fields")
            # Register each handler exactly once; the decorator already
            # appends the route to the router.
            getattr(app, method.lower())(path, response_model=None)(
                create_dynamic_typed_route(
                    impl_method,
                    method.lower(),
                    path,
                )
            )

    return apis_to_serve, impls
# Script entry point: call main() only when this module is executed directly,
# not when it is imported.
if __name__ == "__main__":
    main()