mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-12 12:06:04 +00:00
feat: Add support for dynamically managing provider connections
This commit is contained in:
parent
63422e5b36
commit
d11edf6fee
9 changed files with 3176 additions and 8 deletions
117
llama_stack/apis/providers/connection.py
Normal file
117
llama_stack/apis/providers/connection.py
Normal file
|
|
@ -0,0 +1,117 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.core.datatypes import User
|
||||
from llama_stack.providers.datatypes import HealthStatus
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ProviderConnectionStatus(StrEnum):
|
||||
"""Status of a dynamic provider connection.
|
||||
|
||||
:cvar pending: Configuration stored, not yet initialized
|
||||
:cvar initializing: In the process of connecting
|
||||
:cvar connected: Successfully connected and healthy
|
||||
:cvar failed: Connection attempt failed
|
||||
:cvar disconnected: Previously connected, now disconnected
|
||||
:cvar testing: Health check in progress
|
||||
"""
|
||||
|
||||
pending = "pending"
|
||||
initializing = "initializing"
|
||||
connected = "connected"
|
||||
failed = "failed"
|
||||
disconnected = "disconnected"
|
||||
testing = "testing"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ProviderHealth(BaseModel):
|
||||
"""Structured wrapper around provider health status.
|
||||
|
||||
This wraps the existing dict-based HealthResponse for API responses
|
||||
while maintaining backward compatibility with existing provider implementations.
|
||||
|
||||
:param status: Health status (OK, ERROR, NOT_IMPLEMENTED)
|
||||
:param message: Optional error or status message
|
||||
:param metrics: Provider-specific health metrics
|
||||
:param last_checked: Timestamp of last health check
|
||||
"""
|
||||
|
||||
status: HealthStatus
|
||||
message: str | None = None
|
||||
metrics: dict[str, Any] = Field(default_factory=dict)
|
||||
last_checked: datetime
|
||||
|
||||
@classmethod
|
||||
def from_health_response(cls, response: dict[str, Any]) -> "ProviderHealth":
|
||||
"""Convert dict-based HealthResponse to ProviderHealth.
|
||||
|
||||
This allows us to maintain the existing dict[str, Any] return type
|
||||
for provider.health() methods while providing a structured model
|
||||
for API responses.
|
||||
|
||||
:param response: Dict with 'status' and optional 'message', 'metrics'
|
||||
:returns: ProviderHealth instance
|
||||
"""
|
||||
return cls(
|
||||
status=HealthStatus(response.get("status", HealthStatus.NOT_IMPLEMENTED)),
|
||||
message=response.get("message"),
|
||||
metrics=response.get("metrics", {}),
|
||||
last_checked=datetime.now(UTC),
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ProviderConnectionInfo(BaseModel):
|
||||
"""Information about a dynamically managed provider connection.
|
||||
|
||||
This model represents a provider that has been registered at runtime
|
||||
via the /providers API, as opposed to static providers configured in run.yaml.
|
||||
|
||||
Dynamic providers support full lifecycle management including registration,
|
||||
configuration updates, health monitoring, and removal.
|
||||
|
||||
:param provider_id: Unique identifier for this provider instance
|
||||
:param api: API namespace (e.g., "inference", "vector_io", "safety")
|
||||
:param provider_type: Provider type identifier (e.g., "remote::openai", "inline::faiss")
|
||||
:param config: Provider-specific configuration (API keys, endpoints, etc.)
|
||||
:param status: Current connection status
|
||||
:param health: Most recent health check result
|
||||
:param created_at: Timestamp when provider was registered
|
||||
:param updated_at: Timestamp of last update
|
||||
:param last_health_check: Timestamp of last health check
|
||||
:param error_message: Error message if status is failed
|
||||
:param metadata: User-defined metadata (deprecated, use attributes)
|
||||
:param owner: User who created this provider connection
|
||||
:param attributes: Key-value attributes for ABAC access control
|
||||
"""
|
||||
|
||||
provider_id: str
|
||||
api: str
|
||||
provider_type: str
|
||||
config: dict[str, Any]
|
||||
status: ProviderConnectionStatus
|
||||
health: ProviderHealth | None = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
last_health_check: datetime | None = None
|
||||
error_message: str | None = None
|
||||
metadata: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Deprecated: use attributes for access control",
|
||||
)
|
||||
|
||||
# ABAC fields (same as ResourceWithOwner)
|
||||
owner: User | None = None
|
||||
attributes: dict[str, list[str]] | None = None
|
||||
|
|
@ -8,6 +8,7 @@ from typing import Any, Protocol, runtime_checkable
|
|||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.providers.connection import ProviderConnectionInfo
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.providers.datatypes import HealthResponse
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
|
@ -40,6 +41,85 @@ class ListProvidersResponse(BaseModel):
|
|||
data: list[ProviderInfo]
|
||||
|
||||
|
||||
# ===== Dynamic Provider Management API Models =====
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RegisterProviderRequest(BaseModel):
|
||||
"""Request to register a new dynamic provider.
|
||||
|
||||
:param provider_id: Unique identifier for the provider instance
|
||||
:param api: API namespace (e.g., 'inference', 'vector_io', 'safety')
|
||||
:param provider_type: Provider type identifier (e.g., 'remote::openai', 'inline::faiss')
|
||||
:param config: Provider-specific configuration (API keys, endpoints, etc.)
|
||||
:param attributes: Optional key-value attributes for ABAC access control
|
||||
"""
|
||||
|
||||
provider_id: str
|
||||
api: str
|
||||
provider_type: str
|
||||
config: dict[str, Any]
|
||||
attributes: dict[str, list[str]] | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RegisterProviderResponse(BaseModel):
|
||||
"""Response after registering a provider.
|
||||
|
||||
:param provider: Information about the registered provider
|
||||
"""
|
||||
|
||||
provider: ProviderConnectionInfo
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class UpdateProviderRequest(BaseModel):
|
||||
"""Request to update an existing provider's configuration.
|
||||
|
||||
:param config: New configuration parameters (will be merged with existing)
|
||||
:param attributes: Optional updated attributes for access control
|
||||
"""
|
||||
|
||||
config: dict[str, Any] | None = None
|
||||
attributes: dict[str, list[str]] | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class UpdateProviderResponse(BaseModel):
|
||||
"""Response after updating a provider.
|
||||
|
||||
:param provider: Updated provider information
|
||||
"""
|
||||
|
||||
provider: ProviderConnectionInfo
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class UnregisterProviderResponse(BaseModel):
|
||||
"""Response after unregistering a provider.
|
||||
|
||||
:param success: Whether the operation succeeded
|
||||
:param message: Optional status message
|
||||
"""
|
||||
|
||||
success: bool
|
||||
message: str | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class TestProviderConnectionResponse(BaseModel):
|
||||
"""Response from testing a provider connection.
|
||||
|
||||
:param success: Whether the connection test succeeded
|
||||
:param health: Health status from the provider
|
||||
:param error_message: Error message if test failed
|
||||
"""
|
||||
|
||||
success: bool
|
||||
health: HealthResponse | None = None
|
||||
error_message: str | None = None
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Providers(Protocol):
|
||||
"""Providers
|
||||
|
|
@ -67,3 +147,71 @@ class Providers(Protocol):
|
|||
:returns: A ProviderInfo object containing the provider's details.
|
||||
"""
|
||||
...
|
||||
|
||||
# ===== Dynamic Provider Management Methods =====
|
||||
|
||||
@webmethod(route="/admin/providers", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def register_provider(
|
||||
self,
|
||||
provider_id: str,
|
||||
api: str,
|
||||
provider_type: str,
|
||||
config: dict[str, Any],
|
||||
attributes: dict[str, list[str]] | None = None,
|
||||
) -> RegisterProviderResponse:
|
||||
"""Register a new dynamic provider.
|
||||
|
||||
Register a new provider instance at runtime. The provider will be validated,
|
||||
instantiated, and persisted to the kvstore. Requires appropriate ABAC permissions.
|
||||
|
||||
:param provider_id: Unique identifier for this provider instance.
|
||||
:param api: API namespace this provider implements.
|
||||
:param provider_type: Provider type (e.g., 'remote::openai').
|
||||
:param config: Provider configuration (API keys, endpoints, etc.).
|
||||
:param attributes: Optional attributes for ABAC access control.
|
||||
:returns: RegisterProviderResponse with the registered provider info.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/admin/providers/{provider_id}", method="PUT", level=LLAMA_STACK_API_V1)
|
||||
async def update_provider(
|
||||
self,
|
||||
provider_id: str,
|
||||
config: dict[str, Any] | None = None,
|
||||
attributes: dict[str, list[str]] | None = None,
|
||||
) -> UpdateProviderResponse:
|
||||
"""Update an existing provider's configuration.
|
||||
|
||||
Update the configuration and/or attributes of a dynamic provider. The provider
|
||||
will be re-instantiated with the new configuration (hot-reload). Static providers
|
||||
from run.yaml cannot be updated.
|
||||
|
||||
:param provider_id: ID of the provider to update
|
||||
:param config: New configuration parameters (merged with existing)
|
||||
:param attributes: New attributes for access control
|
||||
:returns: UpdateProviderResponse with updated provider info
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/admin/providers/{provider_id}", method="DELETE", level=LLAMA_STACK_API_V1)
|
||||
async def unregister_provider(self, provider_id: str) -> None:
|
||||
"""Unregister a dynamic provider.
|
||||
|
||||
Remove a dynamic provider, shutting down its instance and removing it from
|
||||
the kvstore. Static providers from run.yaml cannot be unregistered.
|
||||
|
||||
:param provider_id: ID of the provider to unregister.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/providers/{provider_id}/test", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def test_provider_connection(self, provider_id: str) -> TestProviderConnectionResponse:
|
||||
"""Test a provider connection.
|
||||
|
||||
Execute a health check on a provider to verify it is reachable and functioning.
|
||||
Works for both static and dynamic providers.
|
||||
|
||||
:param provider_id: ID of the provider to test.
|
||||
:returns: TestProviderConnectionResponse with health status.
|
||||
"""
|
||||
...
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue