From 436f8ade9e7be9ce5116c0976a0991b7e355f533 Mon Sep 17 00:00:00 2001 From: Charlie Doern Date: Fri, 4 Apr 2025 15:53:21 -0400 Subject: [PATCH] feat: implement provider updating add `v1/providers/` which uses PUT to allow users to change their provider configuration this is a follow up to #1429 and related to #1359 a user can call something like: `llama_stack_client.providers.update(api="inference", provider_id="ollama", provider_type="remote::ollama", config={'url': 'http:/localhost:12345'})` or `llama-stack-client providers update inference ollama remote::ollama "{'url': 'http://localhost:12345'}"` this API works by adding a `RequestMiddleware` to the server which checks requests, and if the user is using PUT /v1/providers, the routes are re-registered with the re-initialized provider configurations/methods for the client, `self.impls` is updated to hold the proper methods+configurations this depends on a client PR, the CI will fail until then but succeeded locally Signed-off-by: Charlie Doern --- docs/_static/llama-stack-spec.html | 103 +++++++++ docs/_static/llama-stack-spec.yaml | 61 +++++ llama_stack/apis/providers/providers.py | 5 + llama_stack/distribution/library_client.py | 17 ++ llama_stack/distribution/providers.py | 86 ++++++- llama_stack/distribution/server/routes.py | 2 + llama_stack/distribution/server/server.py | 214 +++++++++++++----- tests/integration/providers/test_providers.py | 17 ++ 8 files changed, 449 insertions(+), 56 deletions(-) diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index ae9ad5d4c..eb31e634a 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -4946,6 +4946,74 @@ } } }, + "/v1/providers/{api}/{provider_id}/{provider_type}": { + "post": { + "responses": { + "200": { + "description": "OK", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ProviderInfo" + } + } + } + }, + "400": { + "$ref": "#/components/responses/BadRequest400" + }, + "429": { + "$ref": "#/components/responses/TooManyRequests429" + }, + "500": { + "$ref": "#/components/responses/InternalServerError500" + }, + "default": { + "$ref": "#/components/responses/DefaultError" + } + }, + "tags": [ + "Providers" + ], + "description": "", + "parameters": [ + { + "name": "api", + "in": "path", + "required": true, + "schema": { + "type": "string" + } + }, + { + "name": "provider_id", + "in": "path", + "required": true, + "schema": { + "type": "string" + } + }, + { + "name": "provider_type", + "in": "path", + "required": true, + "schema": { + "type": "string" + } + } + ], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/UpdateProviderRequest" + } + } + }, + "required": true + } + } + }, "/v1/version": { "get": { "responses": { @@ -16101,6 +16169,41 @@ "title": "SyntheticDataGenerationResponse", "description": "Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold." }, + "UpdateProviderRequest": { + "type": "object", + "properties": { + "config": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } + } + }, + "additionalProperties": false, + "required": [ + "config" + ], + "title": "UpdateProviderRequest" + }, "VersionInfo": { "type": "object", "properties": { diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 48cefe12b..3c4e9c865 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -3484,6 +3484,50 @@ paths: schema: $ref: '#/components/schemas/SyntheticDataGenerateRequest' required: true + /v1/providers/{api}/{provider_id}/{provider_type}: + post: + responses: + '200': + description: OK + content: + application/json: + schema: + $ref: '#/components/schemas/ProviderInfo' + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - Providers + description: '' + parameters: + - name: api + in: path + required: true + schema: + type: string + - name: provider_id + in: path + required: true + schema: + type: string + - name: provider_type + in: path + required: true + schema: + type: string + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/UpdateProviderRequest' + required: true /v1/version: get: responses: @@ -11234,6 +11278,23 @@ components: description: >- Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold. + UpdateProviderRequest: + type: object + properties: + config: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + additionalProperties: false + required: + - config + title: UpdateProviderRequest VersionInfo: type: object properties: diff --git a/llama_stack/apis/providers/providers.py b/llama_stack/apis/providers/providers.py index 4bc977bf1..0f29a6e05 100644 --- a/llama_stack/apis/providers/providers.py +++ b/llama_stack/apis/providers/providers.py @@ -47,3 +47,8 @@ class Providers(Protocol): :returns: A ProviderInfo object containing the provider's details. """ ... + + @webmethod(route="/providers/{api}/{provider_id}/{provider_type}", method="PUT") + async def update_provider( + self, api: str, provider_id: str, provider_type: str, config: dict[str, Any] + ) -> ProviderInfo: ... diff --git a/llama_stack/distribution/library_client.py b/llama_stack/distribution/library_client.py index cebfabba5..5a59706ac 100644 --- a/llama_stack/distribution/library_client.py +++ b/llama_stack/distribution/library_client.py @@ -25,6 +25,7 @@ from llama_stack_client import ( AsyncStream, LlamaStackClient, ) +from llama_stack_client.types import provider_info from pydantic import BaseModel, TypeAdapter from rich.console import Console from termcolor import cprint @@ -293,6 +294,22 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): cast_to=cast_to, options=options, ) + # Check if response is of a certain type + # this indicates we have done a provider update + if ( + isinstance(response, provider_info.ProviderInfo) + and hasattr(response, "config") + and options.method.lower() == "put" + ): + # patch in the new provider config + for api, providers in self.config.providers.items(): + if api != response.api: + continue + for prov in providers: + if prov.provider_id == response.provider_id: + prov.config = response.config + break + await self.initialize() return response async def _call_non_streaming( diff --git a/llama_stack/distribution/providers.py b/llama_stack/distribution/providers.py index 1d9c1f4e9..3fc61df73 100644 --- a/llama_stack/distribution/providers.py +++ b/llama_stack/distribution/providers.py @@ -5,6 +5,7 @@ # the root directory of this source tree. import asyncio +import copy from typing import Any from pydantic import BaseModel @@ -13,7 +14,7 @@ from llama_stack.apis.providers import ListProvidersResponse, ProviderInfo, Prov from llama_stack.log import get_logger from llama_stack.providers.datatypes import HealthResponse, HealthStatus -from .datatypes import StackRunConfig +from .datatypes import Provider, StackRunConfig from .utils.config import redact_sensitive_fields logger = get_logger(name=__name__, category="core") @@ -129,3 +130,86 @@ class ProviderImpl(Providers): providers_health[api_name] = health_response return providers_health + + async def update_provider( + self, api: str, provider_id: str, provider_type: str, config: dict[str, Any] + ) -> ProviderInfo: + # config = ast.literal_eval(provider_request.config) + prov = Provider( + provider_id=provider_id, + provider_type=provider_type, + config=config, + ) + assert prov.provider_id is not None + existing_provider = None + # if the provider isn't there or the API is invalid, we should not continue + for prov_api, providers in self.config.run_config.providers.items(): + if prov_api != api: + continue + for p in providers: + if p.provider_id == provider_id: + existing_provider = p + break + if existing_provider is not None: + break + + if existing_provider is None: + raise ValueError(f"Provider {provider_id} not found, you can only update already registered providers.") + + new_config = self.merge_providers(existing_provider, prov) + existing_provider.config = new_config + providers_health = await self.get_providers_health() + # takes a single provider, validates its in the registry + # if it is, merge the provider config with the existing one + ret = ProviderInfo( + api=api, + provider_id=prov.provider_id, + provider_type=prov.provider_type, + config=new_config, + health=providers_health.get(api, {}).get( + p.provider_id, + HealthResponse(status=HealthStatus.NOT_IMPLEMENTED, message="Provider does not implement health check"), + ), + ) + + return ret + + def merge_dicts(self, base: dict[str, Any], overrides: dict[str, Any]) -> dict[str, Any]: + """Recursively merges `overrides` into `base`, replacing only specified keys.""" + + merged = copy.deepcopy(base) # Preserve original dict + for key, value in overrides.items(): + if isinstance(value, dict) and isinstance(merged.get(key), dict): + # Recursively merge if both are dictionaries + merged[key] = self.merge_dicts(merged[key], value) + else: + # Otherwise, directly override + merged[key] = value + + return merged + + def merge_configs( + self, global_config: dict[str, list[Provider]], new_config: dict[str, list[Provider]] + ) -> dict[str, list[Provider]]: + merged_config = copy.deepcopy(global_config) # Preserve original structure + + for key, new_providers in new_config.items(): + if key in merged_config: + existing_providers = {p.provider_id: p for p in merged_config[key]} + + for new_provider in new_providers: + if new_provider.provider_id in existing_providers: + # Override settings of existing provider + existing = existing_providers[new_provider.provider_id] + existing.config = self.merge_dicts(existing.config, new_provider.config) + else: + # Append new provider + merged_config[key].append(new_provider) + else: + # Add new category entirely + merged_config[key] = new_providers + + return merged_config + + def merge_providers(self, current_provider: Provider, new_provider: Provider) -> dict[str, Any]: + return self.merge_dicts(current_provider.config, new_provider.config) diff --git a/llama_stack/distribution/server/routes.py b/llama_stack/distribution/server/routes.py index ea66fec5a..ee059a4a6 100644 --- a/llama_stack/distribution/server/routes.py +++ b/llama_stack/distribution/server/routes.py @@ -62,6 +62,8 @@ def get_all_api_routes() -> dict[Api, list[Route]]: http_method = hdrs.METH_GET elif webmethod.method == hdrs.METH_DELETE: http_method = hdrs.METH_DELETE + elif webmethod.method == hdrs.METH_PUT: + http_method = hdrs.METH_PUT else: http_method = hdrs.METH_POST routes.append( diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index 83407a25f..7a1df8fa1 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -27,8 +27,10 @@ from fastapi import Body, FastAPI, HTTPException, Request from fastapi import Path as FastapiPath from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse, StreamingResponse +from fastapi.routing import APIRoute from openai import BadRequestError from pydantic import BaseModel, ValidationError +from starlette.types import Message from llama_stack.apis.common.responses import PaginatedResponse from llama_stack.distribution.datatypes import AuthenticationRequiredError, LoggingConfig, StackRunConfig @@ -269,6 +271,84 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable: return route_handler +class RequestMiddleware: + def __init__(self, app, api, stack_run_config): + self.app = app + self.api = api + self.stack_run_config = stack_run_config + + async def __call__(self, scope, receive, send): + import json + + from fastapi import Request + + from llama_stack.apis.providers import ProviderInfo + + # from llama_stack.stack_utils import construct_stack # or wherever you define it + if scope["type"] != "http": + return await self.app(scope, receive, send) + + request = Request(scope, receive) + method = request.method + path = request.url.path + + # Only intercept PUT /v1/providers/update + if method == "PUT" and "/v1/providers" in path: + # Clone the request body so FastAPI doesn't break + + body = await request.body() + request = Request(scope, receive_from_body(body)) + + response_body = b"" + status_code = 500 + headers = [] + + async def send_wrapper(message: Message): + nonlocal response_body, status_code, headers + + if message["type"] == "http.response.start": + status_code = message["status"] + headers = message.get("headers", []) + + elif message["type"] == "http.response.body": + response_body += message.get("body", b"") + + if not message.get("more_body", False): + # Rebuild stack + try: + # Parse the request body (not response) + payload = json.loads(response_body.decode("utf-8")) + new_provider = ProviderInfo(**payload) + for api, providers in self.stack_run_config.providers.items(): + if api != new_provider.api: + continue + for prov in providers: + if prov.provider_id == new_provider.provider_id: + prov.config = new_provider.config + break + + _, impls = await construct(app=self.api, config=self.stack_run_config, reconstruct=True) + self.api.__llama_stack_impls__ = impls + print("✅ Stack rebuilt and updated.") + except Exception as e: + print(f"⚠️ Failed to rebuild stack: {e}") + + await send(message) + + return await self.app(scope, request.receive, send_wrapper) + + # All other requests go through normally + return await self.app(scope, receive, send) + + +# Helper to inject the saved body back into the request +def receive_from_body(body: bytes): + async def receive() -> Message: + return {"type": "http.request", "body": body, "more_body": False} + + return receive + + class TracingMiddleware: def __init__(self, app, impls): self.app = app @@ -482,67 +562,15 @@ def main(args: argparse.Namespace | None = None): window_seconds=window_seconds, ) - try: - impls = asyncio.run(construct_stack(config)) - except InvalidProviderError as e: - logger.error(f"Error: {str(e)}") - sys.exit(1) - - if Api.telemetry in impls: - setup_logger(impls[Api.telemetry]) - else: - setup_logger(TelemetryAdapter(TelemetryConfig(), {})) - - all_routes = get_all_api_routes() - - if config.apis: - apis_to_serve = set(config.apis) - else: - apis_to_serve = set(impls.keys()) - - for inf in builtin_automatically_routed_apis(): - # if we do not serve the corresponding router API, we should not serve the routing table API - if inf.router_api.value not in apis_to_serve: - continue - apis_to_serve.add(inf.routing_table_api.value) - - apis_to_serve.add("inspect") - apis_to_serve.add("providers") - for api_str in apis_to_serve: - api = Api(api_str) - - routes = all_routes[api] - impl = impls[api] - - for route in routes: - if not hasattr(impl, route.name): - # ideally this should be a typing violation already - raise ValueError(f"Could not find method {route.name} on {impl}!") - - impl_method = getattr(impl, route.name) - # Filter out HEAD method since it's automatically handled by FastAPI for GET routes - available_methods = [m for m in route.methods if m != "HEAD"] - if not available_methods: - raise ValueError(f"No methods found for {route.name} on {impl}") - method = available_methods[0] - logger.debug(f"{method} {route.path}") - - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=UserWarning, module="pydantic._internal._fields") - getattr(app, method.lower())(route.path, response_model=None)( - create_dynamic_typed_route( - impl_method, - method.lower(), - route.path, - ) - ) - + apis_to_serve, impls = asyncio.run(construct(app=app, config=config)) logger.debug(f"serving APIs: {apis_to_serve}") app.exception_handler(RequestValidationError)(global_exception_handler) app.exception_handler(Exception)(global_exception_handler) app.__llama_stack_impls__ = impls + # Add the custom middleware + app.add_middleware(RequestMiddleware, api=app, stack_run_config=config) app.add_middleware(TracingMiddleware, impls=impls) import uvicorn @@ -592,5 +620,81 @@ def extract_path_params(route: str) -> list[str]: return params +async def construct( + app: FastAPI, config: StackRunConfig, reconstruct: bool = False +) -> tuple[set[str] | set[Api], dict[Api, Any]]: + try: + impls = await construct_stack(config) + except InvalidProviderError as e: + logger.error(f"Error: {str(e)}") + sys.exit(1) + + if Api.telemetry in impls: + setup_logger(impls[Api.telemetry]) + else: + setup_logger(TelemetryAdapter(TelemetryConfig(), {})) + + all_routes = get_all_api_routes() + + if config.apis: + apis_to_serve = set(config.apis) + else: + apis_to_serve = set(impls.keys()) + + for inf in builtin_automatically_routed_apis(): + # if we do not serve the corresponding router API, we should not serve the routing table API + if inf.router_api.value not in apis_to_serve: + continue + apis_to_serve.add(inf.routing_table_api.value) + + apis_to_serve.add("inspect") + apis_to_serve.add("providers") + + for api_str in apis_to_serve: + api = Api(api_str) + + routes = all_routes[api] + impl = impls[api] + for route in routes: + if not hasattr(impl, route.name): + # ideally this should be a typing violation already + raise ValueError(f"Could not find method {route.name} on {impl}!!") + + impl_method = getattr(impl, route.name) + + # Filter out HEAD method since it's automatically handled by FastAPI for GET routes + available_methods = [m for m in route.methods if m != "HEAD"] + if not available_methods: + raise ValueError(f"No methods found for {route.name} on {impl}") + method = available_methods[0] + logger.debug(f"{method} {route.path}") + + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=UserWarning, module="pydantic._internal._fields") + # Remove old route + if reconstruct: + app.router.routes = [ + r for r in app.router.routes if not (r.path == route.path and method.lower() in r.methods) + ] + new_endpoint = create_dynamic_typed_route( + impl_method, + method.lower(), + route.path, # route.path + ) + getattr(app, method.lower())(route.path, response_model=None)(new_endpoint) + if reconstruct: + new_route = APIRoute( + response_model=None, + path=route.path, + endpoint=new_endpoint, + methods=[method.lower()], + name=impl_method.__name__, + ) + # Add new route + app.router.routes.append(new_route) + + return apis_to_serve, impls + + if __name__ == "__main__": main() diff --git a/tests/integration/providers/test_providers.py b/tests/integration/providers/test_providers.py index 8b153411c..2069d4b61 100644 --- a/tests/integration/providers/test_providers.py +++ b/tests/integration/providers/test_providers.py @@ -21,3 +21,20 @@ class TestProviders: pid = provider.provider_id provider = llama_stack_client.providers.retrieve(pid) assert provider is not None + + @pytest.mark.asyncio + def test_providers_update(self, llama_stack_client: LlamaStackAsLibraryClient | LlamaStackClient): + new_cfg = {"url": "http://localhost:12345"} + + _ = llama_stack_client.providers.retrieve("ollama") + + llama_stack_client.providers.update( + api="inference", + provider_id="ollama", + provider_type="remote::ollama", + config=new_cfg, + ) + + new_provider = llama_stack_client.providers.retrieve("ollama") + + assert new_provider.config == new_cfg