mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 07:14:20 +00:00
feat: implement provider updating
add `v1/providers/` which uses PUT to allow users to change their provider configuration this is a follow up to #1429 and related to #1359 a user can call something like: `llama_stack_client.providers.update(api="inference", provider_id="ollama", provider_type="remote::ollama", config={'url': 'http:/localhost:12345'})` or `llama-stack-client providers update inference ollama remote::ollama "{'url': 'http://localhost:12345'}"` this API works by adding a `RequestMiddleware` to the server which checks requests, and if the user is using PUT /v1/providers, the routes are re-registered with the re-initialized provider configurations/methods for the client, `self.impls` is updated to hold the proper methods+configurations this depends on a client PR, the CI will fail until then but succeeded locally Signed-off-by: Charlie Doern <cdoern@redhat.com>
This commit is contained in:
parent
d165000bbc
commit
436f8ade9e
8 changed files with 449 additions and 56 deletions
103
docs/_static/llama-stack-spec.html
vendored
103
docs/_static/llama-stack-spec.html
vendored
|
@ -4946,6 +4946,74 @@
|
|||
}
|
||||
}
|
||||
},
|
||||
"/v1/providers/{api}/{provider_id}/{provider_type}": {
|
||||
"post": {
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/ProviderInfo"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"400": {
|
||||
"$ref": "#/components/responses/BadRequest400"
|
||||
},
|
||||
"429": {
|
||||
"$ref": "#/components/responses/TooManyRequests429"
|
||||
},
|
||||
"500": {
|
||||
"$ref": "#/components/responses/InternalServerError500"
|
||||
},
|
||||
"default": {
|
||||
"$ref": "#/components/responses/DefaultError"
|
||||
}
|
||||
},
|
||||
"tags": [
|
||||
"Providers"
|
||||
],
|
||||
"description": "",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "api",
|
||||
"in": "path",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "provider_id",
|
||||
"in": "path",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "provider_type",
|
||||
"in": "path",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
],
|
||||
"requestBody": {
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/UpdateProviderRequest"
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"/v1/version": {
|
||||
"get": {
|
||||
"responses": {
|
||||
|
@ -16101,6 +16169,41 @@
|
|||
"title": "SyntheticDataGenerationResponse",
|
||||
"description": "Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold."
|
||||
},
|
||||
"UpdateProviderRequest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"config": {
|
||||
"type": "object",
|
||||
"additionalProperties": {
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "null"
|
||||
},
|
||||
{
|
||||
"type": "boolean"
|
||||
},
|
||||
{
|
||||
"type": "number"
|
||||
},
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"type": "array"
|
||||
},
|
||||
{
|
||||
"type": "object"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"config"
|
||||
],
|
||||
"title": "UpdateProviderRequest"
|
||||
},
|
||||
"VersionInfo": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
|
61
docs/_static/llama-stack-spec.yaml
vendored
61
docs/_static/llama-stack-spec.yaml
vendored
|
@ -3484,6 +3484,50 @@ paths:
|
|||
schema:
|
||||
$ref: '#/components/schemas/SyntheticDataGenerateRequest'
|
||||
required: true
|
||||
/v1/providers/{api}/{provider_id}/{provider_type}:
|
||||
post:
|
||||
responses:
|
||||
'200':
|
||||
description: OK
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ProviderInfo'
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
'429':
|
||||
$ref: >-
|
||||
#/components/responses/TooManyRequests429
|
||||
'500':
|
||||
$ref: >-
|
||||
#/components/responses/InternalServerError500
|
||||
default:
|
||||
$ref: '#/components/responses/DefaultError'
|
||||
tags:
|
||||
- Providers
|
||||
description: ''
|
||||
parameters:
|
||||
- name: api
|
||||
in: path
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
- name: provider_id
|
||||
in: path
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
- name: provider_type
|
||||
in: path
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
requestBody:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/UpdateProviderRequest'
|
||||
required: true
|
||||
/v1/version:
|
||||
get:
|
||||
responses:
|
||||
|
@ -11234,6 +11278,23 @@ components:
|
|||
description: >-
|
||||
Response from the synthetic data generation. Batch of (prompt, response, score)
|
||||
tuples that pass the threshold.
|
||||
UpdateProviderRequest:
|
||||
type: object
|
||||
properties:
|
||||
config:
|
||||
type: object
|
||||
additionalProperties:
|
||||
oneOf:
|
||||
- type: 'null'
|
||||
- type: boolean
|
||||
- type: number
|
||||
- type: string
|
||||
- type: array
|
||||
- type: object
|
||||
additionalProperties: false
|
||||
required:
|
||||
- config
|
||||
title: UpdateProviderRequest
|
||||
VersionInfo:
|
||||
type: object
|
||||
properties:
|
||||
|
|
|
@ -47,3 +47,8 @@ class Providers(Protocol):
|
|||
:returns: A ProviderInfo object containing the provider's details.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/providers/{api}/{provider_id}/{provider_type}", method="PUT")
|
||||
async def update_provider(
|
||||
self, api: str, provider_id: str, provider_type: str, config: dict[str, Any]
|
||||
) -> ProviderInfo: ...
|
||||
|
|
|
@ -25,6 +25,7 @@ from llama_stack_client import (
|
|||
AsyncStream,
|
||||
LlamaStackClient,
|
||||
)
|
||||
from llama_stack_client.types import provider_info
|
||||
from pydantic import BaseModel, TypeAdapter
|
||||
from rich.console import Console
|
||||
from termcolor import cprint
|
||||
|
@ -293,6 +294,22 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
cast_to=cast_to,
|
||||
options=options,
|
||||
)
|
||||
# Check if response is of a certain type
|
||||
# this indicates we have done a provider update
|
||||
if (
|
||||
isinstance(response, provider_info.ProviderInfo)
|
||||
and hasattr(response, "config")
|
||||
and options.method.lower() == "put"
|
||||
):
|
||||
# patch in the new provider config
|
||||
for api, providers in self.config.providers.items():
|
||||
if api != response.api:
|
||||
continue
|
||||
for prov in providers:
|
||||
if prov.provider_id == response.provider_id:
|
||||
prov.config = response.config
|
||||
break
|
||||
await self.initialize()
|
||||
return response
|
||||
|
||||
async def _call_non_streaming(
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import copy
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
@ -13,7 +14,7 @@ from llama_stack.apis.providers import ListProvidersResponse, ProviderInfo, Prov
|
|||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import HealthResponse, HealthStatus
|
||||
|
||||
from .datatypes import StackRunConfig
|
||||
from .datatypes import Provider, StackRunConfig
|
||||
from .utils.config import redact_sensitive_fields
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
|
@ -129,3 +130,86 @@ class ProviderImpl(Providers):
|
|||
providers_health[api_name] = health_response
|
||||
|
||||
return providers_health
|
||||
|
||||
async def update_provider(
|
||||
self, api: str, provider_id: str, provider_type: str, config: dict[str, Any]
|
||||
) -> ProviderInfo:
|
||||
# config = ast.literal_eval(provider_request.config)
|
||||
prov = Provider(
|
||||
provider_id=provider_id,
|
||||
provider_type=provider_type,
|
||||
config=config,
|
||||
)
|
||||
assert prov.provider_id is not None
|
||||
existing_provider = None
|
||||
# if the provider isn't there or the API is invalid, we should not continue
|
||||
for prov_api, providers in self.config.run_config.providers.items():
|
||||
if prov_api != api:
|
||||
continue
|
||||
for p in providers:
|
||||
if p.provider_id == provider_id:
|
||||
existing_provider = p
|
||||
break
|
||||
if existing_provider is not None:
|
||||
break
|
||||
|
||||
if existing_provider is None:
|
||||
raise ValueError(f"Provider {provider_id} not found, you can only update already registered providers.")
|
||||
|
||||
new_config = self.merge_providers(existing_provider, prov)
|
||||
existing_provider.config = new_config
|
||||
providers_health = await self.get_providers_health()
|
||||
# takes a single provider, validates its in the registry
|
||||
# if it is, merge the provider config with the existing one
|
||||
ret = ProviderInfo(
|
||||
api=api,
|
||||
provider_id=prov.provider_id,
|
||||
provider_type=prov.provider_type,
|
||||
config=new_config,
|
||||
health=providers_health.get(api, {}).get(
|
||||
p.provider_id,
|
||||
HealthResponse(status=HealthStatus.NOT_IMPLEMENTED, message="Provider does not implement health check"),
|
||||
),
|
||||
)
|
||||
|
||||
return ret
|
||||
|
||||
def merge_dicts(self, base: dict[str, Any], overrides: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Recursively merges `overrides` into `base`, replacing only specified keys."""
|
||||
|
||||
merged = copy.deepcopy(base) # Preserve original dict
|
||||
for key, value in overrides.items():
|
||||
if isinstance(value, dict) and isinstance(merged.get(key), dict):
|
||||
# Recursively merge if both are dictionaries
|
||||
merged[key] = self.merge_dicts(merged[key], value)
|
||||
else:
|
||||
# Otherwise, directly override
|
||||
merged[key] = value
|
||||
|
||||
return merged
|
||||
|
||||
def merge_configs(
|
||||
self, global_config: dict[str, list[Provider]], new_config: dict[str, list[Provider]]
|
||||
) -> dict[str, list[Provider]]:
|
||||
merged_config = copy.deepcopy(global_config) # Preserve original structure
|
||||
|
||||
for key, new_providers in new_config.items():
|
||||
if key in merged_config:
|
||||
existing_providers = {p.provider_id: p for p in merged_config[key]}
|
||||
|
||||
for new_provider in new_providers:
|
||||
if new_provider.provider_id in existing_providers:
|
||||
# Override settings of existing provider
|
||||
existing = existing_providers[new_provider.provider_id]
|
||||
existing.config = self.merge_dicts(existing.config, new_provider.config)
|
||||
else:
|
||||
# Append new provider
|
||||
merged_config[key].append(new_provider)
|
||||
else:
|
||||
# Add new category entirely
|
||||
merged_config[key] = new_providers
|
||||
|
||||
return merged_config
|
||||
|
||||
def merge_providers(self, current_provider: Provider, new_provider: Provider) -> dict[str, Any]:
|
||||
return self.merge_dicts(current_provider.config, new_provider.config)
|
||||
|
|
|
@ -62,6 +62,8 @@ def get_all_api_routes() -> dict[Api, list[Route]]:
|
|||
http_method = hdrs.METH_GET
|
||||
elif webmethod.method == hdrs.METH_DELETE:
|
||||
http_method = hdrs.METH_DELETE
|
||||
elif webmethod.method == hdrs.METH_PUT:
|
||||
http_method = hdrs.METH_PUT
|
||||
else:
|
||||
http_method = hdrs.METH_POST
|
||||
routes.append(
|
||||
|
|
|
@ -27,8 +27,10 @@ from fastapi import Body, FastAPI, HTTPException, Request
|
|||
from fastapi import Path as FastapiPath
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from fastapi.routing import APIRoute
|
||||
from openai import BadRequestError
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from starlette.types import Message
|
||||
|
||||
from llama_stack.apis.common.responses import PaginatedResponse
|
||||
from llama_stack.distribution.datatypes import AuthenticationRequiredError, LoggingConfig, StackRunConfig
|
||||
|
@ -269,6 +271,84 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
|
|||
return route_handler
|
||||
|
||||
|
||||
class RequestMiddleware:
|
||||
def __init__(self, app, api, stack_run_config):
|
||||
self.app = app
|
||||
self.api = api
|
||||
self.stack_run_config = stack_run_config
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
import json
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from llama_stack.apis.providers import ProviderInfo
|
||||
|
||||
# from llama_stack.stack_utils import construct_stack # or wherever you define it
|
||||
if scope["type"] != "http":
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
request = Request(scope, receive)
|
||||
method = request.method
|
||||
path = request.url.path
|
||||
|
||||
# Only intercept PUT /v1/providers/update
|
||||
if method == "PUT" and "/v1/providers" in path:
|
||||
# Clone the request body so FastAPI doesn't break
|
||||
|
||||
body = await request.body()
|
||||
request = Request(scope, receive_from_body(body))
|
||||
|
||||
response_body = b""
|
||||
status_code = 500
|
||||
headers = []
|
||||
|
||||
async def send_wrapper(message: Message):
|
||||
nonlocal response_body, status_code, headers
|
||||
|
||||
if message["type"] == "http.response.start":
|
||||
status_code = message["status"]
|
||||
headers = message.get("headers", [])
|
||||
|
||||
elif message["type"] == "http.response.body":
|
||||
response_body += message.get("body", b"")
|
||||
|
||||
if not message.get("more_body", False):
|
||||
# Rebuild stack
|
||||
try:
|
||||
# Parse the request body (not response)
|
||||
payload = json.loads(response_body.decode("utf-8"))
|
||||
new_provider = ProviderInfo(**payload)
|
||||
for api, providers in self.stack_run_config.providers.items():
|
||||
if api != new_provider.api:
|
||||
continue
|
||||
for prov in providers:
|
||||
if prov.provider_id == new_provider.provider_id:
|
||||
prov.config = new_provider.config
|
||||
break
|
||||
|
||||
_, impls = await construct(app=self.api, config=self.stack_run_config, reconstruct=True)
|
||||
self.api.__llama_stack_impls__ = impls
|
||||
print("✅ Stack rebuilt and updated.")
|
||||
except Exception as e:
|
||||
print(f"⚠️ Failed to rebuild stack: {e}")
|
||||
|
||||
await send(message)
|
||||
|
||||
return await self.app(scope, request.receive, send_wrapper)
|
||||
|
||||
# All other requests go through normally
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
|
||||
# Helper to inject the saved body back into the request
|
||||
def receive_from_body(body: bytes):
|
||||
async def receive() -> Message:
|
||||
return {"type": "http.request", "body": body, "more_body": False}
|
||||
|
||||
return receive
|
||||
|
||||
|
||||
class TracingMiddleware:
|
||||
def __init__(self, app, impls):
|
||||
self.app = app
|
||||
|
@ -482,67 +562,15 @@ def main(args: argparse.Namespace | None = None):
|
|||
window_seconds=window_seconds,
|
||||
)
|
||||
|
||||
try:
|
||||
impls = asyncio.run(construct_stack(config))
|
||||
except InvalidProviderError as e:
|
||||
logger.error(f"Error: {str(e)}")
|
||||
sys.exit(1)
|
||||
|
||||
if Api.telemetry in impls:
|
||||
setup_logger(impls[Api.telemetry])
|
||||
else:
|
||||
setup_logger(TelemetryAdapter(TelemetryConfig(), {}))
|
||||
|
||||
all_routes = get_all_api_routes()
|
||||
|
||||
if config.apis:
|
||||
apis_to_serve = set(config.apis)
|
||||
else:
|
||||
apis_to_serve = set(impls.keys())
|
||||
|
||||
for inf in builtin_automatically_routed_apis():
|
||||
# if we do not serve the corresponding router API, we should not serve the routing table API
|
||||
if inf.router_api.value not in apis_to_serve:
|
||||
continue
|
||||
apis_to_serve.add(inf.routing_table_api.value)
|
||||
|
||||
apis_to_serve.add("inspect")
|
||||
apis_to_serve.add("providers")
|
||||
for api_str in apis_to_serve:
|
||||
api = Api(api_str)
|
||||
|
||||
routes = all_routes[api]
|
||||
impl = impls[api]
|
||||
|
||||
for route in routes:
|
||||
if not hasattr(impl, route.name):
|
||||
# ideally this should be a typing violation already
|
||||
raise ValueError(f"Could not find method {route.name} on {impl}!")
|
||||
|
||||
impl_method = getattr(impl, route.name)
|
||||
# Filter out HEAD method since it's automatically handled by FastAPI for GET routes
|
||||
available_methods = [m for m in route.methods if m != "HEAD"]
|
||||
if not available_methods:
|
||||
raise ValueError(f"No methods found for {route.name} on {impl}")
|
||||
method = available_methods[0]
|
||||
logger.debug(f"{method} {route.path}")
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings("ignore", category=UserWarning, module="pydantic._internal._fields")
|
||||
getattr(app, method.lower())(route.path, response_model=None)(
|
||||
create_dynamic_typed_route(
|
||||
impl_method,
|
||||
method.lower(),
|
||||
route.path,
|
||||
)
|
||||
)
|
||||
|
||||
apis_to_serve, impls = asyncio.run(construct(app=app, config=config))
|
||||
logger.debug(f"serving APIs: {apis_to_serve}")
|
||||
|
||||
app.exception_handler(RequestValidationError)(global_exception_handler)
|
||||
app.exception_handler(Exception)(global_exception_handler)
|
||||
|
||||
app.__llama_stack_impls__ = impls
|
||||
# Add the custom middleware
|
||||
app.add_middleware(RequestMiddleware, api=app, stack_run_config=config)
|
||||
app.add_middleware(TracingMiddleware, impls=impls)
|
||||
|
||||
import uvicorn
|
||||
|
@ -592,5 +620,81 @@ def extract_path_params(route: str) -> list[str]:
|
|||
return params
|
||||
|
||||
|
||||
async def construct(
|
||||
app: FastAPI, config: StackRunConfig, reconstruct: bool = False
|
||||
) -> tuple[set[str] | set[Api], dict[Api, Any]]:
|
||||
try:
|
||||
impls = await construct_stack(config)
|
||||
except InvalidProviderError as e:
|
||||
logger.error(f"Error: {str(e)}")
|
||||
sys.exit(1)
|
||||
|
||||
if Api.telemetry in impls:
|
||||
setup_logger(impls[Api.telemetry])
|
||||
else:
|
||||
setup_logger(TelemetryAdapter(TelemetryConfig(), {}))
|
||||
|
||||
all_routes = get_all_api_routes()
|
||||
|
||||
if config.apis:
|
||||
apis_to_serve = set(config.apis)
|
||||
else:
|
||||
apis_to_serve = set(impls.keys())
|
||||
|
||||
for inf in builtin_automatically_routed_apis():
|
||||
# if we do not serve the corresponding router API, we should not serve the routing table API
|
||||
if inf.router_api.value not in apis_to_serve:
|
||||
continue
|
||||
apis_to_serve.add(inf.routing_table_api.value)
|
||||
|
||||
apis_to_serve.add("inspect")
|
||||
apis_to_serve.add("providers")
|
||||
|
||||
for api_str in apis_to_serve:
|
||||
api = Api(api_str)
|
||||
|
||||
routes = all_routes[api]
|
||||
impl = impls[api]
|
||||
for route in routes:
|
||||
if not hasattr(impl, route.name):
|
||||
# ideally this should be a typing violation already
|
||||
raise ValueError(f"Could not find method {route.name} on {impl}!!")
|
||||
|
||||
impl_method = getattr(impl, route.name)
|
||||
|
||||
# Filter out HEAD method since it's automatically handled by FastAPI for GET routes
|
||||
available_methods = [m for m in route.methods if m != "HEAD"]
|
||||
if not available_methods:
|
||||
raise ValueError(f"No methods found for {route.name} on {impl}")
|
||||
method = available_methods[0]
|
||||
logger.debug(f"{method} {route.path}")
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings("ignore", category=UserWarning, module="pydantic._internal._fields")
|
||||
# Remove old route
|
||||
if reconstruct:
|
||||
app.router.routes = [
|
||||
r for r in app.router.routes if not (r.path == route.path and method.lower() in r.methods)
|
||||
]
|
||||
new_endpoint = create_dynamic_typed_route(
|
||||
impl_method,
|
||||
method.lower(),
|
||||
route.path, # route.path
|
||||
)
|
||||
getattr(app, method.lower())(route.path, response_model=None)(new_endpoint)
|
||||
if reconstruct:
|
||||
new_route = APIRoute(
|
||||
response_model=None,
|
||||
path=route.path,
|
||||
endpoint=new_endpoint,
|
||||
methods=[method.lower()],
|
||||
name=impl_method.__name__,
|
||||
)
|
||||
# Add new route
|
||||
app.router.routes.append(new_route)
|
||||
|
||||
return apis_to_serve, impls
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
@ -21,3 +21,20 @@ class TestProviders:
|
|||
pid = provider.provider_id
|
||||
provider = llama_stack_client.providers.retrieve(pid)
|
||||
assert provider is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
def test_providers_update(self, llama_stack_client: LlamaStackAsLibraryClient | LlamaStackClient):
|
||||
new_cfg = {"url": "http://localhost:12345"}
|
||||
|
||||
_ = llama_stack_client.providers.retrieve("ollama")
|
||||
|
||||
llama_stack_client.providers.update(
|
||||
api="inference",
|
||||
provider_id="ollama",
|
||||
provider_type="remote::ollama",
|
||||
config=new_cfg,
|
||||
)
|
||||
|
||||
new_provider = llama_stack_client.providers.retrieve("ollama")
|
||||
|
||||
assert new_provider.config == new_cfg
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue