diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index ae9ad5d4c..eb31e634a 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -4946,6 +4946,74 @@
}
}
},
+ "/v1/providers/{api}/{provider_id}/{provider_type}": {
+            "put": {
+ "responses": {
+ "200": {
+ "description": "OK",
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/ProviderInfo"
+ }
+ }
+ }
+ },
+ "400": {
+ "$ref": "#/components/responses/BadRequest400"
+ },
+ "429": {
+ "$ref": "#/components/responses/TooManyRequests429"
+ },
+ "500": {
+ "$ref": "#/components/responses/InternalServerError500"
+ },
+ "default": {
+ "$ref": "#/components/responses/DefaultError"
+ }
+ },
+ "tags": [
+ "Providers"
+ ],
+                "description": "Update the configuration of an existing provider.",
+ "parameters": [
+ {
+ "name": "api",
+ "in": "path",
+ "required": true,
+ "schema": {
+ "type": "string"
+ }
+ },
+ {
+ "name": "provider_id",
+ "in": "path",
+ "required": true,
+ "schema": {
+ "type": "string"
+ }
+ },
+ {
+ "name": "provider_type",
+ "in": "path",
+ "required": true,
+ "schema": {
+ "type": "string"
+ }
+ }
+ ],
+ "requestBody": {
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/UpdateProviderRequest"
+ }
+ }
+ },
+ "required": true
+ }
+ }
+ },
"/v1/version": {
"get": {
"responses": {
@@ -16101,6 +16169,41 @@
"title": "SyntheticDataGenerationResponse",
"description": "Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold."
},
+ "UpdateProviderRequest": {
+ "type": "object",
+ "properties": {
+ "config": {
+ "type": "object",
+ "additionalProperties": {
+ "oneOf": [
+ {
+ "type": "null"
+ },
+ {
+ "type": "boolean"
+ },
+ {
+ "type": "number"
+ },
+ {
+ "type": "string"
+ },
+ {
+ "type": "array"
+ },
+ {
+ "type": "object"
+ }
+ ]
+ }
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "config"
+ ],
+ "title": "UpdateProviderRequest"
+ },
"VersionInfo": {
"type": "object",
"properties": {
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index 48cefe12b..3c4e9c865 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -3484,6 +3484,50 @@ paths:
schema:
$ref: '#/components/schemas/SyntheticDataGenerateRequest'
required: true
+ /v1/providers/{api}/{provider_id}/{provider_type}:
+    put:
+ responses:
+ '200':
+ description: OK
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/ProviderInfo'
+ '400':
+ $ref: '#/components/responses/BadRequest400'
+ '429':
+ $ref: >-
+ #/components/responses/TooManyRequests429
+ '500':
+ $ref: >-
+ #/components/responses/InternalServerError500
+ default:
+ $ref: '#/components/responses/DefaultError'
+ tags:
+ - Providers
+      description: Update the configuration of an existing provider.
+ parameters:
+ - name: api
+ in: path
+ required: true
+ schema:
+ type: string
+ - name: provider_id
+ in: path
+ required: true
+ schema:
+ type: string
+ - name: provider_type
+ in: path
+ required: true
+ schema:
+ type: string
+ requestBody:
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/UpdateProviderRequest'
+ required: true
/v1/version:
get:
responses:
@@ -11234,6 +11278,23 @@ components:
description: >-
Response from the synthetic data generation. Batch of (prompt, response, score)
tuples that pass the threshold.
+ UpdateProviderRequest:
+ type: object
+ properties:
+ config:
+ type: object
+ additionalProperties:
+ oneOf:
+ - type: 'null'
+ - type: boolean
+ - type: number
+ - type: string
+ - type: array
+ - type: object
+ additionalProperties: false
+ required:
+ - config
+ title: UpdateProviderRequest
VersionInfo:
type: object
properties:
diff --git a/llama_stack/apis/providers/providers.py b/llama_stack/apis/providers/providers.py
index 4bc977bf1..0f29a6e05 100644
--- a/llama_stack/apis/providers/providers.py
+++ b/llama_stack/apis/providers/providers.py
@@ -47,3 +47,16 @@ class Providers(Protocol):
         :returns: A ProviderInfo object containing the provider's details.
         """
         ...
+
+    @webmethod(route="/providers/{api}/{provider_id}/{provider_type}", method="PUT")
+    async def update_provider(
+        self, api: str, provider_id: str, provider_type: str, config: dict[str, Any]
+    ) -> ProviderInfo:
+        """Update the configuration of an already-registered provider.
+
+        :param api: The API the provider implements (e.g. "inference").
+        :param provider_id: The ID of the provider to update.
+        :param provider_type: The type of the provider (e.g. "remote::ollama").
+        :param config: Config values merged into the provider's existing config.
+        :returns: A ProviderInfo object with the merged configuration.
+        """
diff --git a/llama_stack/distribution/library_client.py b/llama_stack/distribution/library_client.py
index cebfabba5..5a59706ac 100644
--- a/llama_stack/distribution/library_client.py
+++ b/llama_stack/distribution/library_client.py
@@ -25,6 +25,7 @@ from llama_stack_client import (
AsyncStream,
LlamaStackClient,
)
+from llama_stack_client.types import provider_info
from pydantic import BaseModel, TypeAdapter
from rich.console import Console
from termcolor import cprint
@@ -293,6 +294,22 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
cast_to=cast_to,
options=options,
)
+        # A ProviderInfo body returned from a PUT request indicates a provider
+        # update; mirror it into the in-memory run config and re-initialize.
+        if (
+            isinstance(response, provider_info.ProviderInfo)
+            and hasattr(response, "config")
+            and options.method.lower() == "put"
+        ):
+            # patch the updated config into the matching run-config provider entry
+            for api, providers in self.config.providers.items():
+                if api != response.api:
+                    continue
+                for prov in providers:
+                    if prov.provider_id == response.provider_id:
+                        prov.config = response.config
+                        break
+            await self.initialize()
return response
async def _call_non_streaming(
diff --git a/llama_stack/distribution/providers.py b/llama_stack/distribution/providers.py
index 1d9c1f4e9..3fc61df73 100644
--- a/llama_stack/distribution/providers.py
+++ b/llama_stack/distribution/providers.py
@@ -5,6 +5,7 @@
# the root directory of this source tree.
import asyncio
+import copy
from typing import Any
from pydantic import BaseModel
@@ -13,7 +14,7 @@ from llama_stack.apis.providers import ListProvidersResponse, ProviderInfo, Prov
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import HealthResponse, HealthStatus
-from .datatypes import StackRunConfig
+from .datatypes import Provider, StackRunConfig
from .utils.config import redact_sensitive_fields
logger = get_logger(name=__name__, category="core")
@@ -129,3 +130,86 @@ class ProviderImpl(Providers):
providers_health[api_name] = health_response
return providers_health
+
+    async def update_provider(
+        self, api: str, provider_id: str, provider_type: str, config: dict[str, Any]
+    ) -> ProviderInfo:
+        """Merge ``config`` into an already-registered provider and return its info."""
+        prov = Provider(
+            provider_id=provider_id,
+            provider_type=provider_type,
+            config=config,
+        )
+        assert prov.provider_id is not None
+        existing_provider = None
+        # if the provider isn't there or the API is invalid, we should not continue
+        for prov_api, providers in self.config.run_config.providers.items():
+            if prov_api != api:
+                continue
+            for p in providers:
+                if p.provider_id == provider_id:
+                    existing_provider = p
+                    break
+            if existing_provider is not None:
+                break
+
+        if existing_provider is None:
+            raise ValueError(f"Provider {provider_id} not found, you can only update already registered providers.")
+
+        new_config = self.merge_providers(existing_provider, prov)
+        existing_provider.config = new_config
+        providers_health = await self.get_providers_health()
+        # Report the merged config plus current health, falling back to
+        # NOT_IMPLEMENTED when the provider exposes no health check.
+        ret = ProviderInfo(
+            api=api,
+            provider_id=prov.provider_id,
+            provider_type=prov.provider_type,
+            config=new_config,
+            health=providers_health.get(api, {}).get(
+                provider_id,
+                HealthResponse(status=HealthStatus.NOT_IMPLEMENTED, message="Provider does not implement health check"),
+            ),
+        )
+
+        return ret
+
+ def merge_dicts(self, base: dict[str, Any], overrides: dict[str, Any]) -> dict[str, Any]:
+ """Recursively merges `overrides` into `base`, replacing only specified keys."""
+
+ merged = copy.deepcopy(base) # Preserve original dict
+ for key, value in overrides.items():
+ if isinstance(value, dict) and isinstance(merged.get(key), dict):
+ # Recursively merge if both are dictionaries
+ merged[key] = self.merge_dicts(merged[key], value)
+ else:
+ # Otherwise, directly override
+ merged[key] = value
+
+ return merged
+
+ def merge_configs(
+ self, global_config: dict[str, list[Provider]], new_config: dict[str, list[Provider]]
+ ) -> dict[str, list[Provider]]:
+ merged_config = copy.deepcopy(global_config) # Preserve original structure
+
+ for key, new_providers in new_config.items():
+ if key in merged_config:
+ existing_providers = {p.provider_id: p for p in merged_config[key]}
+
+ for new_provider in new_providers:
+ if new_provider.provider_id in existing_providers:
+ # Override settings of existing provider
+ existing = existing_providers[new_provider.provider_id]
+ existing.config = self.merge_dicts(existing.config, new_provider.config)
+ else:
+ # Append new provider
+ merged_config[key].append(new_provider)
+ else:
+ # Add new category entirely
+ merged_config[key] = new_providers
+
+ return merged_config
+
+ def merge_providers(self, current_provider: Provider, new_provider: Provider) -> dict[str, Any]:
+ return self.merge_dicts(current_provider.config, new_provider.config)
diff --git a/llama_stack/distribution/server/routes.py b/llama_stack/distribution/server/routes.py
index ea66fec5a..ee059a4a6 100644
--- a/llama_stack/distribution/server/routes.py
+++ b/llama_stack/distribution/server/routes.py
@@ -62,6 +62,8 @@ def get_all_api_routes() -> dict[Api, list[Route]]:
http_method = hdrs.METH_GET
elif webmethod.method == hdrs.METH_DELETE:
http_method = hdrs.METH_DELETE
+ elif webmethod.method == hdrs.METH_PUT:
+ http_method = hdrs.METH_PUT
else:
http_method = hdrs.METH_POST
routes.append(
diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py
index 83407a25f..7a1df8fa1 100644
--- a/llama_stack/distribution/server/server.py
+++ b/llama_stack/distribution/server/server.py
@@ -27,8 +27,10 @@ from fastapi import Body, FastAPI, HTTPException, Request
from fastapi import Path as FastapiPath
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, StreamingResponse
+from fastapi.routing import APIRoute
from openai import BadRequestError
from pydantic import BaseModel, ValidationError
+from starlette.types import Message
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.distribution.datatypes import AuthenticationRequiredError, LoggingConfig, StackRunConfig
@@ -269,6 +271,84 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
return route_handler
+class RequestMiddleware:
+    def __init__(self, app, api, stack_run_config):
+        self.app = app
+        self.api = api
+        self.stack_run_config = stack_run_config
+
+    async def __call__(self, scope, receive, send):
+        import json
+
+        from fastapi import Request
+
+        from llama_stack.apis.providers import ProviderInfo
+
+        # construct() is defined later in this module and rebuilds impls/routes
+        if scope["type"] != "http":
+            return await self.app(scope, receive, send)
+
+        request = Request(scope, receive)
+        method = request.method
+        path = request.url.path
+
+        # Only intercept provider updates (PUT /v1/providers/...)
+        if method == "PUT" and path.startswith("/v1/providers"):
+            # Clone the request body so FastAPI doesn't break
+
+            body = await request.body()
+            request = Request(scope, receive_from_body(body))
+
+            response_body = b""
+            status_code = 500
+            headers = []
+
+            async def send_wrapper(message: Message):
+                nonlocal response_body, status_code, headers
+
+                if message["type"] == "http.response.start":
+                    status_code = message["status"]
+                    headers = message.get("headers", [])
+
+                elif message["type"] == "http.response.body":
+                    response_body += message.get("body", b"")
+
+                    # Rebuild only once the response is complete and successful
+                    if not message.get("more_body", False) and status_code == 200:
+                        try:
+                            # Parse the buffered *response* body (a ProviderInfo)
+                            payload = json.loads(response_body.decode("utf-8"))
+                            new_provider = ProviderInfo(**payload)
+                            for api, providers in self.stack_run_config.providers.items():
+                                if api != new_provider.api:
+                                    continue
+                                for prov in providers:
+                                    if prov.provider_id == new_provider.provider_id:
+                                        prov.config = new_provider.config
+                                        break
+
+                            _, impls = await construct(app=self.api, config=self.stack_run_config, reconstruct=True)
+                            self.api.__llama_stack_impls__ = impls
+                            logger.info("Stack rebuilt and updated.")
+                        except Exception as e:
+                            logger.warning(f"Failed to rebuild stack: {e}")
+
+                await send(message)
+
+            return await self.app(scope, request.receive, send_wrapper)
+
+        # All other requests go through normally
+        return await self.app(scope, receive, send)
+
+
+# Helper to inject the saved body back into the request
+def receive_from_body(body: bytes):
+    async def receive() -> Message:
+        return {"type": "http.request", "body": body, "more_body": False}
+
+    return receive
+
+
class TracingMiddleware:
def __init__(self, app, impls):
self.app = app
@@ -482,67 +562,15 @@ def main(args: argparse.Namespace | None = None):
window_seconds=window_seconds,
)
- try:
- impls = asyncio.run(construct_stack(config))
- except InvalidProviderError as e:
- logger.error(f"Error: {str(e)}")
- sys.exit(1)
-
- if Api.telemetry in impls:
- setup_logger(impls[Api.telemetry])
- else:
- setup_logger(TelemetryAdapter(TelemetryConfig(), {}))
-
- all_routes = get_all_api_routes()
-
- if config.apis:
- apis_to_serve = set(config.apis)
- else:
- apis_to_serve = set(impls.keys())
-
- for inf in builtin_automatically_routed_apis():
- # if we do not serve the corresponding router API, we should not serve the routing table API
- if inf.router_api.value not in apis_to_serve:
- continue
- apis_to_serve.add(inf.routing_table_api.value)
-
- apis_to_serve.add("inspect")
- apis_to_serve.add("providers")
- for api_str in apis_to_serve:
- api = Api(api_str)
-
- routes = all_routes[api]
- impl = impls[api]
-
- for route in routes:
- if not hasattr(impl, route.name):
- # ideally this should be a typing violation already
- raise ValueError(f"Could not find method {route.name} on {impl}!")
-
- impl_method = getattr(impl, route.name)
- # Filter out HEAD method since it's automatically handled by FastAPI for GET routes
- available_methods = [m for m in route.methods if m != "HEAD"]
- if not available_methods:
- raise ValueError(f"No methods found for {route.name} on {impl}")
- method = available_methods[0]
- logger.debug(f"{method} {route.path}")
-
- with warnings.catch_warnings():
- warnings.filterwarnings("ignore", category=UserWarning, module="pydantic._internal._fields")
- getattr(app, method.lower())(route.path, response_model=None)(
- create_dynamic_typed_route(
- impl_method,
- method.lower(),
- route.path,
- )
- )
-
+ apis_to_serve, impls = asyncio.run(construct(app=app, config=config))
logger.debug(f"serving APIs: {apis_to_serve}")
app.exception_handler(RequestValidationError)(global_exception_handler)
app.exception_handler(Exception)(global_exception_handler)
app.__llama_stack_impls__ = impls
+ # Add the custom middleware
+ app.add_middleware(RequestMiddleware, api=app, stack_run_config=config)
app.add_middleware(TracingMiddleware, impls=impls)
import uvicorn
@@ -592,5 +620,81 @@ def extract_path_params(route: str) -> list[str]:
return params
+async def construct(
+ app: FastAPI, config: StackRunConfig, reconstruct: bool = False
+) -> tuple[set[str] | set[Api], dict[Api, Any]]:
+ try:
+ impls = await construct_stack(config)
+ except InvalidProviderError as e:
+ logger.error(f"Error: {str(e)}")
+ sys.exit(1)
+
+ if Api.telemetry in impls:
+ setup_logger(impls[Api.telemetry])
+ else:
+ setup_logger(TelemetryAdapter(TelemetryConfig(), {}))
+
+ all_routes = get_all_api_routes()
+
+ if config.apis:
+ apis_to_serve = set(config.apis)
+ else:
+ apis_to_serve = set(impls.keys())
+
+ for inf in builtin_automatically_routed_apis():
+ # if we do not serve the corresponding router API, we should not serve the routing table API
+ if inf.router_api.value not in apis_to_serve:
+ continue
+ apis_to_serve.add(inf.routing_table_api.value)
+
+ apis_to_serve.add("inspect")
+ apis_to_serve.add("providers")
+
+    for api_str in apis_to_serve:
+        api = Api(api_str)
+
+        routes = all_routes[api]
+        impl = impls[api]
+        for route in routes:
+            if not hasattr(impl, route.name):
+                # ideally this should be a typing violation already
+                raise ValueError(f"Could not find method {route.name} on {impl}!")
+
+            impl_method = getattr(impl, route.name)
+
+            # Filter out HEAD method since it's automatically handled by FastAPI for GET routes
+            available_methods = [m for m in route.methods if m != "HEAD"]
+            if not available_methods:
+                raise ValueError(f"No methods found for {route.name} on {impl}")
+            method = available_methods[0]
+            logger.debug(f"{method} {route.path}")
+
+            with warnings.catch_warnings():
+                warnings.filterwarnings("ignore", category=UserWarning, module="pydantic._internal._fields")
+                # Drop the stale route first; Starlette stores uppercase verbs in Route.methods
+                if reconstruct:
+                    app.router.routes = [
+                        r for r in app.router.routes if not (getattr(r, "path", None) == route.path and method.upper() in getattr(r, "methods", ()))
+                    ]
+                new_endpoint = create_dynamic_typed_route(
+                    impl_method,
+                    method.lower(),
+                    route.path,
+                )
+                if reconstruct:
+                    new_route = APIRoute(
+                        response_model=None,
+                        path=route.path,
+                        endpoint=new_endpoint,
+                        methods=[method],
+                        name=impl_method.__name__,
+                    )
+                    app.router.routes.append(new_route)
+                else:
+                    getattr(app, method.lower())(route.path, response_model=None)(new_endpoint)
+
+    return apis_to_serve, impls
+
+
if __name__ == "__main__":
main()
diff --git a/tests/integration/providers/test_providers.py b/tests/integration/providers/test_providers.py
index 8b153411c..2069d4b61 100644
--- a/tests/integration/providers/test_providers.py
+++ b/tests/integration/providers/test_providers.py
@@ -21,3 +21,21 @@
         pid = provider.provider_id
         provider = llama_stack_client.providers.retrieve(pid)
         assert provider is not None
+
+    def test_providers_update(self, llama_stack_client: LlamaStackAsLibraryClient | LlamaStackClient):
+        """A provider config update via the client API is reflected on retrieval."""
+        new_cfg = {"url": "http://localhost:12345"}
+
+        # The provider must already be registered for an update to succeed.
+        _ = llama_stack_client.providers.retrieve("ollama")
+
+        llama_stack_client.providers.update(
+            api="inference",
+            provider_id="ollama",
+            provider_type="remote::ollama",
+            config=new_cfg,
+        )
+
+        new_provider = llama_stack_client.providers.retrieve("ollama")
+
+        assert new_provider.config == new_cfg