mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
feat: implement provider updating
add `v1/providers/` which uses PUT to allow users to change their provider configuration this is a follow up to #1429 and related to #1359 a user can call something like: `llama_stack_client.providers.update(api="inference", provider_id="ollama", provider_type="remote::ollama", config={'url': 'http:/localhost:12345'})` or `llama-stack-client providers update inference ollama remote::ollama "{'url': 'http://localhost:12345'}"` this API works by adding a `RequestMiddleware` to the server which checks requests, and if the user is using PUT /v1/providers, the routes are re-registered with the re-initialized provider configurations/methods for the client, `self.impls` is updated to hold the proper methods+configurations this depends on a client PR, the CI will fail until then but succeeded locally Signed-off-by: Charlie Doern <cdoern@redhat.com>
This commit is contained in:
parent
d165000bbc
commit
436f8ade9e
8 changed files with 449 additions and 56 deletions
103
docs/_static/llama-stack-spec.html
vendored
103
docs/_static/llama-stack-spec.html
vendored
|
@ -4946,6 +4946,74 @@
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"/v1/providers/{api}/{provider_id}/{provider_type}": {
|
||||||
|
"post": {
|
||||||
|
"responses": {
|
||||||
|
"200": {
|
||||||
|
"description": "OK",
|
||||||
|
"content": {
|
||||||
|
"application/json": {
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/components/schemas/ProviderInfo"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"400": {
|
||||||
|
"$ref": "#/components/responses/BadRequest400"
|
||||||
|
},
|
||||||
|
"429": {
|
||||||
|
"$ref": "#/components/responses/TooManyRequests429"
|
||||||
|
},
|
||||||
|
"500": {
|
||||||
|
"$ref": "#/components/responses/InternalServerError500"
|
||||||
|
},
|
||||||
|
"default": {
|
||||||
|
"$ref": "#/components/responses/DefaultError"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"tags": [
|
||||||
|
"Providers"
|
||||||
|
],
|
||||||
|
"description": "",
|
||||||
|
"parameters": [
|
||||||
|
{
|
||||||
|
"name": "api",
|
||||||
|
"in": "path",
|
||||||
|
"required": true,
|
||||||
|
"schema": {
|
||||||
|
"type": "string"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "provider_id",
|
||||||
|
"in": "path",
|
||||||
|
"required": true,
|
||||||
|
"schema": {
|
||||||
|
"type": "string"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "provider_type",
|
||||||
|
"in": "path",
|
||||||
|
"required": true,
|
||||||
|
"schema": {
|
||||||
|
"type": "string"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"requestBody": {
|
||||||
|
"content": {
|
||||||
|
"application/json": {
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/components/schemas/UpdateProviderRequest"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
"/v1/version": {
|
"/v1/version": {
|
||||||
"get": {
|
"get": {
|
||||||
"responses": {
|
"responses": {
|
||||||
|
@ -16101,6 +16169,41 @@
|
||||||
"title": "SyntheticDataGenerationResponse",
|
"title": "SyntheticDataGenerationResponse",
|
||||||
"description": "Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold."
|
"description": "Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold."
|
||||||
},
|
},
|
||||||
|
"UpdateProviderRequest": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"config": {
|
||||||
|
"type": "object",
|
||||||
|
"additionalProperties": {
|
||||||
|
"oneOf": [
|
||||||
|
{
|
||||||
|
"type": "null"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "boolean"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "number"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "array"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "object"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"additionalProperties": false,
|
||||||
|
"required": [
|
||||||
|
"config"
|
||||||
|
],
|
||||||
|
"title": "UpdateProviderRequest"
|
||||||
|
},
|
||||||
"VersionInfo": {
|
"VersionInfo": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
|
|
61
docs/_static/llama-stack-spec.yaml
vendored
61
docs/_static/llama-stack-spec.yaml
vendored
|
@ -3484,6 +3484,50 @@ paths:
|
||||||
schema:
|
schema:
|
||||||
$ref: '#/components/schemas/SyntheticDataGenerateRequest'
|
$ref: '#/components/schemas/SyntheticDataGenerateRequest'
|
||||||
required: true
|
required: true
|
||||||
|
/v1/providers/{api}/{provider_id}/{provider_type}:
|
||||||
|
post:
|
||||||
|
responses:
|
||||||
|
'200':
|
||||||
|
description: OK
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: '#/components/schemas/ProviderInfo'
|
||||||
|
'400':
|
||||||
|
$ref: '#/components/responses/BadRequest400'
|
||||||
|
'429':
|
||||||
|
$ref: >-
|
||||||
|
#/components/responses/TooManyRequests429
|
||||||
|
'500':
|
||||||
|
$ref: >-
|
||||||
|
#/components/responses/InternalServerError500
|
||||||
|
default:
|
||||||
|
$ref: '#/components/responses/DefaultError'
|
||||||
|
tags:
|
||||||
|
- Providers
|
||||||
|
description: ''
|
||||||
|
parameters:
|
||||||
|
- name: api
|
||||||
|
in: path
|
||||||
|
required: true
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
- name: provider_id
|
||||||
|
in: path
|
||||||
|
required: true
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
- name: provider_type
|
||||||
|
in: path
|
||||||
|
required: true
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
requestBody:
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: '#/components/schemas/UpdateProviderRequest'
|
||||||
|
required: true
|
||||||
/v1/version:
|
/v1/version:
|
||||||
get:
|
get:
|
||||||
responses:
|
responses:
|
||||||
|
@ -11234,6 +11278,23 @@ components:
|
||||||
description: >-
|
description: >-
|
||||||
Response from the synthetic data generation. Batch of (prompt, response, score)
|
Response from the synthetic data generation. Batch of (prompt, response, score)
|
||||||
tuples that pass the threshold.
|
tuples that pass the threshold.
|
||||||
|
UpdateProviderRequest:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
config:
|
||||||
|
type: object
|
||||||
|
additionalProperties:
|
||||||
|
oneOf:
|
||||||
|
- type: 'null'
|
||||||
|
- type: boolean
|
||||||
|
- type: number
|
||||||
|
- type: string
|
||||||
|
- type: array
|
||||||
|
- type: object
|
||||||
|
additionalProperties: false
|
||||||
|
required:
|
||||||
|
- config
|
||||||
|
title: UpdateProviderRequest
|
||||||
VersionInfo:
|
VersionInfo:
|
||||||
type: object
|
type: object
|
||||||
properties:
|
properties:
|
||||||
|
|
|
@ -47,3 +47,8 @@ class Providers(Protocol):
|
||||||
:returns: A ProviderInfo object containing the provider's details.
|
:returns: A ProviderInfo object containing the provider's details.
|
||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
|
@webmethod(route="/providers/{api}/{provider_id}/{provider_type}", method="PUT")
|
||||||
|
async def update_provider(
|
||||||
|
self, api: str, provider_id: str, provider_type: str, config: dict[str, Any]
|
||||||
|
) -> ProviderInfo: ...
|
||||||
|
|
|
@ -25,6 +25,7 @@ from llama_stack_client import (
|
||||||
AsyncStream,
|
AsyncStream,
|
||||||
LlamaStackClient,
|
LlamaStackClient,
|
||||||
)
|
)
|
||||||
|
from llama_stack_client.types import provider_info
|
||||||
from pydantic import BaseModel, TypeAdapter
|
from pydantic import BaseModel, TypeAdapter
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
from termcolor import cprint
|
from termcolor import cprint
|
||||||
|
@ -293,6 +294,22 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
||||||
cast_to=cast_to,
|
cast_to=cast_to,
|
||||||
options=options,
|
options=options,
|
||||||
)
|
)
|
||||||
|
# Check if response is of a certain type
|
||||||
|
# this indicates we have done a provider update
|
||||||
|
if (
|
||||||
|
isinstance(response, provider_info.ProviderInfo)
|
||||||
|
and hasattr(response, "config")
|
||||||
|
and options.method.lower() == "put"
|
||||||
|
):
|
||||||
|
# patch in the new provider config
|
||||||
|
for api, providers in self.config.providers.items():
|
||||||
|
if api != response.api:
|
||||||
|
continue
|
||||||
|
for prov in providers:
|
||||||
|
if prov.provider_id == response.provider_id:
|
||||||
|
prov.config = response.config
|
||||||
|
break
|
||||||
|
await self.initialize()
|
||||||
return response
|
return response
|
||||||
|
|
||||||
async def _call_non_streaming(
|
async def _call_non_streaming(
|
||||||
|
|
|
@ -5,6 +5,7 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import copy
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
@ -13,7 +14,7 @@ from llama_stack.apis.providers import ListProvidersResponse, ProviderInfo, Prov
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
from llama_stack.providers.datatypes import HealthResponse, HealthStatus
|
from llama_stack.providers.datatypes import HealthResponse, HealthStatus
|
||||||
|
|
||||||
from .datatypes import StackRunConfig
|
from .datatypes import Provider, StackRunConfig
|
||||||
from .utils.config import redact_sensitive_fields
|
from .utils.config import redact_sensitive_fields
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="core")
|
logger = get_logger(name=__name__, category="core")
|
||||||
|
@ -129,3 +130,86 @@ class ProviderImpl(Providers):
|
||||||
providers_health[api_name] = health_response
|
providers_health[api_name] = health_response
|
||||||
|
|
||||||
return providers_health
|
return providers_health
|
||||||
|
|
||||||
|
async def update_provider(
|
||||||
|
self, api: str, provider_id: str, provider_type: str, config: dict[str, Any]
|
||||||
|
) -> ProviderInfo:
|
||||||
|
# config = ast.literal_eval(provider_request.config)
|
||||||
|
prov = Provider(
|
||||||
|
provider_id=provider_id,
|
||||||
|
provider_type=provider_type,
|
||||||
|
config=config,
|
||||||
|
)
|
||||||
|
assert prov.provider_id is not None
|
||||||
|
existing_provider = None
|
||||||
|
# if the provider isn't there or the API is invalid, we should not continue
|
||||||
|
for prov_api, providers in self.config.run_config.providers.items():
|
||||||
|
if prov_api != api:
|
||||||
|
continue
|
||||||
|
for p in providers:
|
||||||
|
if p.provider_id == provider_id:
|
||||||
|
existing_provider = p
|
||||||
|
break
|
||||||
|
if existing_provider is not None:
|
||||||
|
break
|
||||||
|
|
||||||
|
if existing_provider is None:
|
||||||
|
raise ValueError(f"Provider {provider_id} not found, you can only update already registered providers.")
|
||||||
|
|
||||||
|
new_config = self.merge_providers(existing_provider, prov)
|
||||||
|
existing_provider.config = new_config
|
||||||
|
providers_health = await self.get_providers_health()
|
||||||
|
# takes a single provider, validates its in the registry
|
||||||
|
# if it is, merge the provider config with the existing one
|
||||||
|
ret = ProviderInfo(
|
||||||
|
api=api,
|
||||||
|
provider_id=prov.provider_id,
|
||||||
|
provider_type=prov.provider_type,
|
||||||
|
config=new_config,
|
||||||
|
health=providers_health.get(api, {}).get(
|
||||||
|
p.provider_id,
|
||||||
|
HealthResponse(status=HealthStatus.NOT_IMPLEMENTED, message="Provider does not implement health check"),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
return ret
|
||||||
|
|
||||||
|
def merge_dicts(self, base: dict[str, Any], overrides: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
"""Recursively merges `overrides` into `base`, replacing only specified keys."""
|
||||||
|
|
||||||
|
merged = copy.deepcopy(base) # Preserve original dict
|
||||||
|
for key, value in overrides.items():
|
||||||
|
if isinstance(value, dict) and isinstance(merged.get(key), dict):
|
||||||
|
# Recursively merge if both are dictionaries
|
||||||
|
merged[key] = self.merge_dicts(merged[key], value)
|
||||||
|
else:
|
||||||
|
# Otherwise, directly override
|
||||||
|
merged[key] = value
|
||||||
|
|
||||||
|
return merged
|
||||||
|
|
||||||
|
def merge_configs(
|
||||||
|
self, global_config: dict[str, list[Provider]], new_config: dict[str, list[Provider]]
|
||||||
|
) -> dict[str, list[Provider]]:
|
||||||
|
merged_config = copy.deepcopy(global_config) # Preserve original structure
|
||||||
|
|
||||||
|
for key, new_providers in new_config.items():
|
||||||
|
if key in merged_config:
|
||||||
|
existing_providers = {p.provider_id: p for p in merged_config[key]}
|
||||||
|
|
||||||
|
for new_provider in new_providers:
|
||||||
|
if new_provider.provider_id in existing_providers:
|
||||||
|
# Override settings of existing provider
|
||||||
|
existing = existing_providers[new_provider.provider_id]
|
||||||
|
existing.config = self.merge_dicts(existing.config, new_provider.config)
|
||||||
|
else:
|
||||||
|
# Append new provider
|
||||||
|
merged_config[key].append(new_provider)
|
||||||
|
else:
|
||||||
|
# Add new category entirely
|
||||||
|
merged_config[key] = new_providers
|
||||||
|
|
||||||
|
return merged_config
|
||||||
|
|
||||||
|
def merge_providers(self, current_provider: Provider, new_provider: Provider) -> dict[str, Any]:
|
||||||
|
return self.merge_dicts(current_provider.config, new_provider.config)
|
||||||
|
|
|
@ -62,6 +62,8 @@ def get_all_api_routes() -> dict[Api, list[Route]]:
|
||||||
http_method = hdrs.METH_GET
|
http_method = hdrs.METH_GET
|
||||||
elif webmethod.method == hdrs.METH_DELETE:
|
elif webmethod.method == hdrs.METH_DELETE:
|
||||||
http_method = hdrs.METH_DELETE
|
http_method = hdrs.METH_DELETE
|
||||||
|
elif webmethod.method == hdrs.METH_PUT:
|
||||||
|
http_method = hdrs.METH_PUT
|
||||||
else:
|
else:
|
||||||
http_method = hdrs.METH_POST
|
http_method = hdrs.METH_POST
|
||||||
routes.append(
|
routes.append(
|
||||||
|
|
|
@ -27,8 +27,10 @@ from fastapi import Body, FastAPI, HTTPException, Request
|
||||||
from fastapi import Path as FastapiPath
|
from fastapi import Path as FastapiPath
|
||||||
from fastapi.exceptions import RequestValidationError
|
from fastapi.exceptions import RequestValidationError
|
||||||
from fastapi.responses import JSONResponse, StreamingResponse
|
from fastapi.responses import JSONResponse, StreamingResponse
|
||||||
|
from fastapi.routing import APIRoute
|
||||||
from openai import BadRequestError
|
from openai import BadRequestError
|
||||||
from pydantic import BaseModel, ValidationError
|
from pydantic import BaseModel, ValidationError
|
||||||
|
from starlette.types import Message
|
||||||
|
|
||||||
from llama_stack.apis.common.responses import PaginatedResponse
|
from llama_stack.apis.common.responses import PaginatedResponse
|
||||||
from llama_stack.distribution.datatypes import AuthenticationRequiredError, LoggingConfig, StackRunConfig
|
from llama_stack.distribution.datatypes import AuthenticationRequiredError, LoggingConfig, StackRunConfig
|
||||||
|
@ -269,6 +271,84 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
|
||||||
return route_handler
|
return route_handler
|
||||||
|
|
||||||
|
|
||||||
|
class RequestMiddleware:
|
||||||
|
def __init__(self, app, api, stack_run_config):
|
||||||
|
self.app = app
|
||||||
|
self.api = api
|
||||||
|
self.stack_run_config = stack_run_config
|
||||||
|
|
||||||
|
async def __call__(self, scope, receive, send):
|
||||||
|
import json
|
||||||
|
|
||||||
|
from fastapi import Request
|
||||||
|
|
||||||
|
from llama_stack.apis.providers import ProviderInfo
|
||||||
|
|
||||||
|
# from llama_stack.stack_utils import construct_stack # or wherever you define it
|
||||||
|
if scope["type"] != "http":
|
||||||
|
return await self.app(scope, receive, send)
|
||||||
|
|
||||||
|
request = Request(scope, receive)
|
||||||
|
method = request.method
|
||||||
|
path = request.url.path
|
||||||
|
|
||||||
|
# Only intercept PUT /v1/providers/update
|
||||||
|
if method == "PUT" and "/v1/providers" in path:
|
||||||
|
# Clone the request body so FastAPI doesn't break
|
||||||
|
|
||||||
|
body = await request.body()
|
||||||
|
request = Request(scope, receive_from_body(body))
|
||||||
|
|
||||||
|
response_body = b""
|
||||||
|
status_code = 500
|
||||||
|
headers = []
|
||||||
|
|
||||||
|
async def send_wrapper(message: Message):
|
||||||
|
nonlocal response_body, status_code, headers
|
||||||
|
|
||||||
|
if message["type"] == "http.response.start":
|
||||||
|
status_code = message["status"]
|
||||||
|
headers = message.get("headers", [])
|
||||||
|
|
||||||
|
elif message["type"] == "http.response.body":
|
||||||
|
response_body += message.get("body", b"")
|
||||||
|
|
||||||
|
if not message.get("more_body", False):
|
||||||
|
# Rebuild stack
|
||||||
|
try:
|
||||||
|
# Parse the request body (not response)
|
||||||
|
payload = json.loads(response_body.decode("utf-8"))
|
||||||
|
new_provider = ProviderInfo(**payload)
|
||||||
|
for api, providers in self.stack_run_config.providers.items():
|
||||||
|
if api != new_provider.api:
|
||||||
|
continue
|
||||||
|
for prov in providers:
|
||||||
|
if prov.provider_id == new_provider.provider_id:
|
||||||
|
prov.config = new_provider.config
|
||||||
|
break
|
||||||
|
|
||||||
|
_, impls = await construct(app=self.api, config=self.stack_run_config, reconstruct=True)
|
||||||
|
self.api.__llama_stack_impls__ = impls
|
||||||
|
print("✅ Stack rebuilt and updated.")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"⚠️ Failed to rebuild stack: {e}")
|
||||||
|
|
||||||
|
await send(message)
|
||||||
|
|
||||||
|
return await self.app(scope, request.receive, send_wrapper)
|
||||||
|
|
||||||
|
# All other requests go through normally
|
||||||
|
return await self.app(scope, receive, send)
|
||||||
|
|
||||||
|
|
||||||
|
# Helper to inject the saved body back into the request
|
||||||
|
def receive_from_body(body: bytes):
|
||||||
|
async def receive() -> Message:
|
||||||
|
return {"type": "http.request", "body": body, "more_body": False}
|
||||||
|
|
||||||
|
return receive
|
||||||
|
|
||||||
|
|
||||||
class TracingMiddleware:
|
class TracingMiddleware:
|
||||||
def __init__(self, app, impls):
|
def __init__(self, app, impls):
|
||||||
self.app = app
|
self.app = app
|
||||||
|
@ -482,67 +562,15 @@ def main(args: argparse.Namespace | None = None):
|
||||||
window_seconds=window_seconds,
|
window_seconds=window_seconds,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
apis_to_serve, impls = asyncio.run(construct(app=app, config=config))
|
||||||
impls = asyncio.run(construct_stack(config))
|
|
||||||
except InvalidProviderError as e:
|
|
||||||
logger.error(f"Error: {str(e)}")
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
if Api.telemetry in impls:
|
|
||||||
setup_logger(impls[Api.telemetry])
|
|
||||||
else:
|
|
||||||
setup_logger(TelemetryAdapter(TelemetryConfig(), {}))
|
|
||||||
|
|
||||||
all_routes = get_all_api_routes()
|
|
||||||
|
|
||||||
if config.apis:
|
|
||||||
apis_to_serve = set(config.apis)
|
|
||||||
else:
|
|
||||||
apis_to_serve = set(impls.keys())
|
|
||||||
|
|
||||||
for inf in builtin_automatically_routed_apis():
|
|
||||||
# if we do not serve the corresponding router API, we should not serve the routing table API
|
|
||||||
if inf.router_api.value not in apis_to_serve:
|
|
||||||
continue
|
|
||||||
apis_to_serve.add(inf.routing_table_api.value)
|
|
||||||
|
|
||||||
apis_to_serve.add("inspect")
|
|
||||||
apis_to_serve.add("providers")
|
|
||||||
for api_str in apis_to_serve:
|
|
||||||
api = Api(api_str)
|
|
||||||
|
|
||||||
routes = all_routes[api]
|
|
||||||
impl = impls[api]
|
|
||||||
|
|
||||||
for route in routes:
|
|
||||||
if not hasattr(impl, route.name):
|
|
||||||
# ideally this should be a typing violation already
|
|
||||||
raise ValueError(f"Could not find method {route.name} on {impl}!")
|
|
||||||
|
|
||||||
impl_method = getattr(impl, route.name)
|
|
||||||
# Filter out HEAD method since it's automatically handled by FastAPI for GET routes
|
|
||||||
available_methods = [m for m in route.methods if m != "HEAD"]
|
|
||||||
if not available_methods:
|
|
||||||
raise ValueError(f"No methods found for {route.name} on {impl}")
|
|
||||||
method = available_methods[0]
|
|
||||||
logger.debug(f"{method} {route.path}")
|
|
||||||
|
|
||||||
with warnings.catch_warnings():
|
|
||||||
warnings.filterwarnings("ignore", category=UserWarning, module="pydantic._internal._fields")
|
|
||||||
getattr(app, method.lower())(route.path, response_model=None)(
|
|
||||||
create_dynamic_typed_route(
|
|
||||||
impl_method,
|
|
||||||
method.lower(),
|
|
||||||
route.path,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.debug(f"serving APIs: {apis_to_serve}")
|
logger.debug(f"serving APIs: {apis_to_serve}")
|
||||||
|
|
||||||
app.exception_handler(RequestValidationError)(global_exception_handler)
|
app.exception_handler(RequestValidationError)(global_exception_handler)
|
||||||
app.exception_handler(Exception)(global_exception_handler)
|
app.exception_handler(Exception)(global_exception_handler)
|
||||||
|
|
||||||
app.__llama_stack_impls__ = impls
|
app.__llama_stack_impls__ = impls
|
||||||
|
# Add the custom middleware
|
||||||
|
app.add_middleware(RequestMiddleware, api=app, stack_run_config=config)
|
||||||
app.add_middleware(TracingMiddleware, impls=impls)
|
app.add_middleware(TracingMiddleware, impls=impls)
|
||||||
|
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
@ -592,5 +620,81 @@ def extract_path_params(route: str) -> list[str]:
|
||||||
return params
|
return params
|
||||||
|
|
||||||
|
|
||||||
|
async def construct(
|
||||||
|
app: FastAPI, config: StackRunConfig, reconstruct: bool = False
|
||||||
|
) -> tuple[set[str] | set[Api], dict[Api, Any]]:
|
||||||
|
try:
|
||||||
|
impls = await construct_stack(config)
|
||||||
|
except InvalidProviderError as e:
|
||||||
|
logger.error(f"Error: {str(e)}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
if Api.telemetry in impls:
|
||||||
|
setup_logger(impls[Api.telemetry])
|
||||||
|
else:
|
||||||
|
setup_logger(TelemetryAdapter(TelemetryConfig(), {}))
|
||||||
|
|
||||||
|
all_routes = get_all_api_routes()
|
||||||
|
|
||||||
|
if config.apis:
|
||||||
|
apis_to_serve = set(config.apis)
|
||||||
|
else:
|
||||||
|
apis_to_serve = set(impls.keys())
|
||||||
|
|
||||||
|
for inf in builtin_automatically_routed_apis():
|
||||||
|
# if we do not serve the corresponding router API, we should not serve the routing table API
|
||||||
|
if inf.router_api.value not in apis_to_serve:
|
||||||
|
continue
|
||||||
|
apis_to_serve.add(inf.routing_table_api.value)
|
||||||
|
|
||||||
|
apis_to_serve.add("inspect")
|
||||||
|
apis_to_serve.add("providers")
|
||||||
|
|
||||||
|
for api_str in apis_to_serve:
|
||||||
|
api = Api(api_str)
|
||||||
|
|
||||||
|
routes = all_routes[api]
|
||||||
|
impl = impls[api]
|
||||||
|
for route in routes:
|
||||||
|
if not hasattr(impl, route.name):
|
||||||
|
# ideally this should be a typing violation already
|
||||||
|
raise ValueError(f"Could not find method {route.name} on {impl}!!")
|
||||||
|
|
||||||
|
impl_method = getattr(impl, route.name)
|
||||||
|
|
||||||
|
# Filter out HEAD method since it's automatically handled by FastAPI for GET routes
|
||||||
|
available_methods = [m for m in route.methods if m != "HEAD"]
|
||||||
|
if not available_methods:
|
||||||
|
raise ValueError(f"No methods found for {route.name} on {impl}")
|
||||||
|
method = available_methods[0]
|
||||||
|
logger.debug(f"{method} {route.path}")
|
||||||
|
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
warnings.filterwarnings("ignore", category=UserWarning, module="pydantic._internal._fields")
|
||||||
|
# Remove old route
|
||||||
|
if reconstruct:
|
||||||
|
app.router.routes = [
|
||||||
|
r for r in app.router.routes if not (r.path == route.path and method.lower() in r.methods)
|
||||||
|
]
|
||||||
|
new_endpoint = create_dynamic_typed_route(
|
||||||
|
impl_method,
|
||||||
|
method.lower(),
|
||||||
|
route.path, # route.path
|
||||||
|
)
|
||||||
|
getattr(app, method.lower())(route.path, response_model=None)(new_endpoint)
|
||||||
|
if reconstruct:
|
||||||
|
new_route = APIRoute(
|
||||||
|
response_model=None,
|
||||||
|
path=route.path,
|
||||||
|
endpoint=new_endpoint,
|
||||||
|
methods=[method.lower()],
|
||||||
|
name=impl_method.__name__,
|
||||||
|
)
|
||||||
|
# Add new route
|
||||||
|
app.router.routes.append(new_route)
|
||||||
|
|
||||||
|
return apis_to_serve, impls
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
|
@ -21,3 +21,20 @@ class TestProviders:
|
||||||
pid = provider.provider_id
|
pid = provider.provider_id
|
||||||
provider = llama_stack_client.providers.retrieve(pid)
|
provider = llama_stack_client.providers.retrieve(pid)
|
||||||
assert provider is not None
|
assert provider is not None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
def test_providers_update(self, llama_stack_client: LlamaStackAsLibraryClient | LlamaStackClient):
|
||||||
|
new_cfg = {"url": "http://localhost:12345"}
|
||||||
|
|
||||||
|
_ = llama_stack_client.providers.retrieve("ollama")
|
||||||
|
|
||||||
|
llama_stack_client.providers.update(
|
||||||
|
api="inference",
|
||||||
|
provider_id="ollama",
|
||||||
|
provider_type="remote::ollama",
|
||||||
|
config=new_cfg,
|
||||||
|
)
|
||||||
|
|
||||||
|
new_provider = llama_stack_client.providers.retrieve("ollama")
|
||||||
|
|
||||||
|
assert new_provider.config == new_cfg
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue