feat: implement provider updating

add `v1/providers/` which uses PUT to allow users to change their provider configuration

this is a follow up to #1429 and related to #1359

a user can call something like:

`llama_stack_client.providers.update(api="inference", provider_id="ollama", provider_type="remote::ollama", config={'url': 'http:/localhost:12345'})`

or

`llama-stack-client providers update inference ollama remote::ollama "{'url': 'http://localhost:12345'}"`

this API works by adding a `RequestMiddleware` to the server which checks requests, and if the user is using PUT /v1/providers, the routes are re-registered with the re-initialized provider configurations/methods

for the client, `self.impls` is updated to hold the proper methods+configurations

this depends on a client PR, the CI will fail until then but succeeded locally

Signed-off-by: Charlie Doern <cdoern@redhat.com>
This commit is contained in:
Charlie Doern 2025-04-04 15:53:21 -04:00
parent d165000bbc
commit 436f8ade9e
8 changed files with 449 additions and 56 deletions

View file

@ -4946,6 +4946,74 @@
} }
} }
}, },
"/v1/providers/{api}/{provider_id}/{provider_type}": {
"post": {
"responses": {
"200": {
"description": "OK",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ProviderInfo"
}
}
}
},
"400": {
"$ref": "#/components/responses/BadRequest400"
},
"429": {
"$ref": "#/components/responses/TooManyRequests429"
},
"500": {
"$ref": "#/components/responses/InternalServerError500"
},
"default": {
"$ref": "#/components/responses/DefaultError"
}
},
"tags": [
"Providers"
],
"description": "",
"parameters": [
{
"name": "api",
"in": "path",
"required": true,
"schema": {
"type": "string"
}
},
{
"name": "provider_id",
"in": "path",
"required": true,
"schema": {
"type": "string"
}
},
{
"name": "provider_type",
"in": "path",
"required": true,
"schema": {
"type": "string"
}
}
],
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/UpdateProviderRequest"
}
}
},
"required": true
}
}
},
"/v1/version": { "/v1/version": {
"get": { "get": {
"responses": { "responses": {
@ -16101,6 +16169,41 @@
"title": "SyntheticDataGenerationResponse", "title": "SyntheticDataGenerationResponse",
"description": "Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold." "description": "Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold."
}, },
"UpdateProviderRequest": {
"type": "object",
"properties": {
"config": {
"type": "object",
"additionalProperties": {
"oneOf": [
{
"type": "null"
},
{
"type": "boolean"
},
{
"type": "number"
},
{
"type": "string"
},
{
"type": "array"
},
{
"type": "object"
}
]
}
}
},
"additionalProperties": false,
"required": [
"config"
],
"title": "UpdateProviderRequest"
},
"VersionInfo": { "VersionInfo": {
"type": "object", "type": "object",
"properties": { "properties": {

View file

@ -3484,6 +3484,50 @@ paths:
schema: schema:
$ref: '#/components/schemas/SyntheticDataGenerateRequest' $ref: '#/components/schemas/SyntheticDataGenerateRequest'
required: true required: true
/v1/providers/{api}/{provider_id}/{provider_type}:
post:
responses:
'200':
description: OK
content:
application/json:
schema:
$ref: '#/components/schemas/ProviderInfo'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- Providers
description: ''
parameters:
- name: api
in: path
required: true
schema:
type: string
- name: provider_id
in: path
required: true
schema:
type: string
- name: provider_type
in: path
required: true
schema:
type: string
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/UpdateProviderRequest'
required: true
/v1/version: /v1/version:
get: get:
responses: responses:
@ -11234,6 +11278,23 @@ components:
description: >- description: >-
Response from the synthetic data generation. Batch of (prompt, response, score) Response from the synthetic data generation. Batch of (prompt, response, score)
tuples that pass the threshold. tuples that pass the threshold.
UpdateProviderRequest:
type: object
properties:
config:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
additionalProperties: false
required:
- config
title: UpdateProviderRequest
VersionInfo: VersionInfo:
type: object type: object
properties: properties:

View file

@ -47,3 +47,8 @@ class Providers(Protocol):
:returns: A ProviderInfo object containing the provider's details. :returns: A ProviderInfo object containing the provider's details.
""" """
... ...
@webmethod(route="/providers/{api}/{provider_id}/{provider_type}", method="PUT")
async def update_provider(
self, api: str, provider_id: str, provider_type: str, config: dict[str, Any]
) -> ProviderInfo: ...

View file

@ -25,6 +25,7 @@ from llama_stack_client import (
AsyncStream, AsyncStream,
LlamaStackClient, LlamaStackClient,
) )
from llama_stack_client.types import provider_info
from pydantic import BaseModel, TypeAdapter from pydantic import BaseModel, TypeAdapter
from rich.console import Console from rich.console import Console
from termcolor import cprint from termcolor import cprint
@ -293,6 +294,22 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
cast_to=cast_to, cast_to=cast_to,
options=options, options=options,
) )
# Check if response is of a certain type
# this indicates we have done a provider update
if (
isinstance(response, provider_info.ProviderInfo)
and hasattr(response, "config")
and options.method.lower() == "put"
):
# patch in the new provider config
for api, providers in self.config.providers.items():
if api != response.api:
continue
for prov in providers:
if prov.provider_id == response.provider_id:
prov.config = response.config
break
await self.initialize()
return response return response
async def _call_non_streaming( async def _call_non_streaming(

View file

@ -5,6 +5,7 @@
# the root directory of this source tree. # the root directory of this source tree.
import asyncio import asyncio
import copy
from typing import Any from typing import Any
from pydantic import BaseModel from pydantic import BaseModel
@ -13,7 +14,7 @@ from llama_stack.apis.providers import ListProvidersResponse, ProviderInfo, Prov
from llama_stack.log import get_logger from llama_stack.log import get_logger
from llama_stack.providers.datatypes import HealthResponse, HealthStatus from llama_stack.providers.datatypes import HealthResponse, HealthStatus
from .datatypes import StackRunConfig from .datatypes import Provider, StackRunConfig
from .utils.config import redact_sensitive_fields from .utils.config import redact_sensitive_fields
logger = get_logger(name=__name__, category="core") logger = get_logger(name=__name__, category="core")
@ -129,3 +130,86 @@ class ProviderImpl(Providers):
providers_health[api_name] = health_response providers_health[api_name] = health_response
return providers_health return providers_health
async def update_provider(
self, api: str, provider_id: str, provider_type: str, config: dict[str, Any]
) -> ProviderInfo:
# config = ast.literal_eval(provider_request.config)
prov = Provider(
provider_id=provider_id,
provider_type=provider_type,
config=config,
)
assert prov.provider_id is not None
existing_provider = None
# if the provider isn't there or the API is invalid, we should not continue
for prov_api, providers in self.config.run_config.providers.items():
if prov_api != api:
continue
for p in providers:
if p.provider_id == provider_id:
existing_provider = p
break
if existing_provider is not None:
break
if existing_provider is None:
raise ValueError(f"Provider {provider_id} not found, you can only update already registered providers.")
new_config = self.merge_providers(existing_provider, prov)
existing_provider.config = new_config
providers_health = await self.get_providers_health()
# takes a single provider, validates its in the registry
# if it is, merge the provider config with the existing one
ret = ProviderInfo(
api=api,
provider_id=prov.provider_id,
provider_type=prov.provider_type,
config=new_config,
health=providers_health.get(api, {}).get(
p.provider_id,
HealthResponse(status=HealthStatus.NOT_IMPLEMENTED, message="Provider does not implement health check"),
),
)
return ret
def merge_dicts(self, base: dict[str, Any], overrides: dict[str, Any]) -> dict[str, Any]:
"""Recursively merges `overrides` into `base`, replacing only specified keys."""
merged = copy.deepcopy(base) # Preserve original dict
for key, value in overrides.items():
if isinstance(value, dict) and isinstance(merged.get(key), dict):
# Recursively merge if both are dictionaries
merged[key] = self.merge_dicts(merged[key], value)
else:
# Otherwise, directly override
merged[key] = value
return merged
def merge_configs(
self, global_config: dict[str, list[Provider]], new_config: dict[str, list[Provider]]
) -> dict[str, list[Provider]]:
merged_config = copy.deepcopy(global_config) # Preserve original structure
for key, new_providers in new_config.items():
if key in merged_config:
existing_providers = {p.provider_id: p for p in merged_config[key]}
for new_provider in new_providers:
if new_provider.provider_id in existing_providers:
# Override settings of existing provider
existing = existing_providers[new_provider.provider_id]
existing.config = self.merge_dicts(existing.config, new_provider.config)
else:
# Append new provider
merged_config[key].append(new_provider)
else:
# Add new category entirely
merged_config[key] = new_providers
return merged_config
def merge_providers(self, current_provider: Provider, new_provider: Provider) -> dict[str, Any]:
return self.merge_dicts(current_provider.config, new_provider.config)

View file

@ -62,6 +62,8 @@ def get_all_api_routes() -> dict[Api, list[Route]]:
http_method = hdrs.METH_GET http_method = hdrs.METH_GET
elif webmethod.method == hdrs.METH_DELETE: elif webmethod.method == hdrs.METH_DELETE:
http_method = hdrs.METH_DELETE http_method = hdrs.METH_DELETE
elif webmethod.method == hdrs.METH_PUT:
http_method = hdrs.METH_PUT
else: else:
http_method = hdrs.METH_POST http_method = hdrs.METH_POST
routes.append( routes.append(

View file

@ -27,8 +27,10 @@ from fastapi import Body, FastAPI, HTTPException, Request
from fastapi import Path as FastapiPath from fastapi import Path as FastapiPath
from fastapi.exceptions import RequestValidationError from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, StreamingResponse from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.routing import APIRoute
from openai import BadRequestError from openai import BadRequestError
from pydantic import BaseModel, ValidationError from pydantic import BaseModel, ValidationError
from starlette.types import Message
from llama_stack.apis.common.responses import PaginatedResponse from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.distribution.datatypes import AuthenticationRequiredError, LoggingConfig, StackRunConfig from llama_stack.distribution.datatypes import AuthenticationRequiredError, LoggingConfig, StackRunConfig
@ -269,6 +271,84 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
return route_handler return route_handler
class RequestMiddleware:
def __init__(self, app, api, stack_run_config):
self.app = app
self.api = api
self.stack_run_config = stack_run_config
async def __call__(self, scope, receive, send):
import json
from fastapi import Request
from llama_stack.apis.providers import ProviderInfo
# from llama_stack.stack_utils import construct_stack # or wherever you define it
if scope["type"] != "http":
return await self.app(scope, receive, send)
request = Request(scope, receive)
method = request.method
path = request.url.path
# Only intercept PUT /v1/providers/update
if method == "PUT" and "/v1/providers" in path:
# Clone the request body so FastAPI doesn't break
body = await request.body()
request = Request(scope, receive_from_body(body))
response_body = b""
status_code = 500
headers = []
async def send_wrapper(message: Message):
nonlocal response_body, status_code, headers
if message["type"] == "http.response.start":
status_code = message["status"]
headers = message.get("headers", [])
elif message["type"] == "http.response.body":
response_body += message.get("body", b"")
if not message.get("more_body", False):
# Rebuild stack
try:
# Parse the request body (not response)
payload = json.loads(response_body.decode("utf-8"))
new_provider = ProviderInfo(**payload)
for api, providers in self.stack_run_config.providers.items():
if api != new_provider.api:
continue
for prov in providers:
if prov.provider_id == new_provider.provider_id:
prov.config = new_provider.config
break
_, impls = await construct(app=self.api, config=self.stack_run_config, reconstruct=True)
self.api.__llama_stack_impls__ = impls
print("✅ Stack rebuilt and updated.")
except Exception as e:
print(f"⚠️ Failed to rebuild stack: {e}")
await send(message)
return await self.app(scope, request.receive, send_wrapper)
# All other requests go through normally
return await self.app(scope, receive, send)
# Helper to inject the saved body back into the request
def receive_from_body(body: bytes):
async def receive() -> Message:
return {"type": "http.request", "body": body, "more_body": False}
return receive
class TracingMiddleware: class TracingMiddleware:
def __init__(self, app, impls): def __init__(self, app, impls):
self.app = app self.app = app
@ -482,67 +562,15 @@ def main(args: argparse.Namespace | None = None):
window_seconds=window_seconds, window_seconds=window_seconds,
) )
try: apis_to_serve, impls = asyncio.run(construct(app=app, config=config))
impls = asyncio.run(construct_stack(config))
except InvalidProviderError as e:
logger.error(f"Error: {str(e)}")
sys.exit(1)
if Api.telemetry in impls:
setup_logger(impls[Api.telemetry])
else:
setup_logger(TelemetryAdapter(TelemetryConfig(), {}))
all_routes = get_all_api_routes()
if config.apis:
apis_to_serve = set(config.apis)
else:
apis_to_serve = set(impls.keys())
for inf in builtin_automatically_routed_apis():
# if we do not serve the corresponding router API, we should not serve the routing table API
if inf.router_api.value not in apis_to_serve:
continue
apis_to_serve.add(inf.routing_table_api.value)
apis_to_serve.add("inspect")
apis_to_serve.add("providers")
for api_str in apis_to_serve:
api = Api(api_str)
routes = all_routes[api]
impl = impls[api]
for route in routes:
if not hasattr(impl, route.name):
# ideally this should be a typing violation already
raise ValueError(f"Could not find method {route.name} on {impl}!")
impl_method = getattr(impl, route.name)
# Filter out HEAD method since it's automatically handled by FastAPI for GET routes
available_methods = [m for m in route.methods if m != "HEAD"]
if not available_methods:
raise ValueError(f"No methods found for {route.name} on {impl}")
method = available_methods[0]
logger.debug(f"{method} {route.path}")
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=UserWarning, module="pydantic._internal._fields")
getattr(app, method.lower())(route.path, response_model=None)(
create_dynamic_typed_route(
impl_method,
method.lower(),
route.path,
)
)
logger.debug(f"serving APIs: {apis_to_serve}") logger.debug(f"serving APIs: {apis_to_serve}")
app.exception_handler(RequestValidationError)(global_exception_handler) app.exception_handler(RequestValidationError)(global_exception_handler)
app.exception_handler(Exception)(global_exception_handler) app.exception_handler(Exception)(global_exception_handler)
app.__llama_stack_impls__ = impls app.__llama_stack_impls__ = impls
# Add the custom middleware
app.add_middleware(RequestMiddleware, api=app, stack_run_config=config)
app.add_middleware(TracingMiddleware, impls=impls) app.add_middleware(TracingMiddleware, impls=impls)
import uvicorn import uvicorn
@ -592,5 +620,81 @@ def extract_path_params(route: str) -> list[str]:
return params return params
async def construct(
app: FastAPI, config: StackRunConfig, reconstruct: bool = False
) -> tuple[set[str] | set[Api], dict[Api, Any]]:
try:
impls = await construct_stack(config)
except InvalidProviderError as e:
logger.error(f"Error: {str(e)}")
sys.exit(1)
if Api.telemetry in impls:
setup_logger(impls[Api.telemetry])
else:
setup_logger(TelemetryAdapter(TelemetryConfig(), {}))
all_routes = get_all_api_routes()
if config.apis:
apis_to_serve = set(config.apis)
else:
apis_to_serve = set(impls.keys())
for inf in builtin_automatically_routed_apis():
# if we do not serve the corresponding router API, we should not serve the routing table API
if inf.router_api.value not in apis_to_serve:
continue
apis_to_serve.add(inf.routing_table_api.value)
apis_to_serve.add("inspect")
apis_to_serve.add("providers")
for api_str in apis_to_serve:
api = Api(api_str)
routes = all_routes[api]
impl = impls[api]
for route in routes:
if not hasattr(impl, route.name):
# ideally this should be a typing violation already
raise ValueError(f"Could not find method {route.name} on {impl}!!")
impl_method = getattr(impl, route.name)
# Filter out HEAD method since it's automatically handled by FastAPI for GET routes
available_methods = [m for m in route.methods if m != "HEAD"]
if not available_methods:
raise ValueError(f"No methods found for {route.name} on {impl}")
method = available_methods[0]
logger.debug(f"{method} {route.path}")
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=UserWarning, module="pydantic._internal._fields")
# Remove old route
if reconstruct:
app.router.routes = [
r for r in app.router.routes if not (r.path == route.path and method.lower() in r.methods)
]
new_endpoint = create_dynamic_typed_route(
impl_method,
method.lower(),
route.path, # route.path
)
getattr(app, method.lower())(route.path, response_model=None)(new_endpoint)
if reconstruct:
new_route = APIRoute(
response_model=None,
path=route.path,
endpoint=new_endpoint,
methods=[method.lower()],
name=impl_method.__name__,
)
# Add new route
app.router.routes.append(new_route)
return apis_to_serve, impls
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View file

@ -21,3 +21,20 @@ class TestProviders:
pid = provider.provider_id pid = provider.provider_id
provider = llama_stack_client.providers.retrieve(pid) provider = llama_stack_client.providers.retrieve(pid)
assert provider is not None assert provider is not None
@pytest.mark.asyncio
def test_providers_update(self, llama_stack_client: LlamaStackAsLibraryClient | LlamaStackClient):
new_cfg = {"url": "http://localhost:12345"}
_ = llama_stack_client.providers.retrieve("ollama")
llama_stack_client.providers.update(
api="inference",
provider_id="ollama",
provider_type="remote::ollama",
config=new_cfg,
)
new_provider = llama_stack_client.providers.retrieve("ollama")
assert new_provider.config == new_cfg