mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-21 19:38:42 +00:00
feat: migrate Providers API to FastAPI router pattern (#4405)
# What does this PR do? Convert Providers API from @webmethod decorators to FastAPI router pattern. Fixes: https://github.com/llamastack/llama-stack/issues/4350 <!-- Provide a short summary of what this PR does and why. Link to relevant issues if applicable. --> <!-- If resolving an issue, uncomment and update the line below --> <!-- Closes #[issue-number] --> ## Test Plan CI Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
parent
722d9c53e7
commit
cd5095a247
13 changed files with 287 additions and 121 deletions
|
|
@ -10,7 +10,14 @@ from typing import Any
|
|||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack_api import HealthResponse, HealthStatus, ListProvidersResponse, ProviderInfo, Providers
|
||||
from llama_stack_api import (
|
||||
HealthResponse,
|
||||
HealthStatus,
|
||||
InspectProviderRequest,
|
||||
ListProvidersResponse,
|
||||
ProviderInfo,
|
||||
Providers,
|
||||
)
|
||||
|
||||
from .datatypes import StackConfig
|
||||
from .utils.config import redact_sensitive_fields
|
||||
|
|
@ -67,13 +74,13 @@ class ProviderImpl(Providers):
|
|||
|
||||
return ListProvidersResponse(data=ret)
|
||||
|
||||
async def inspect_provider(self, provider_id: str) -> ProviderInfo:
|
||||
async def inspect_provider(self, request: InspectProviderRequest) -> ProviderInfo:
|
||||
all_providers = await self.list_providers()
|
||||
for p in all_providers.data:
|
||||
if p.provider_id == provider_id:
|
||||
if p.provider_id == request.provider_id:
|
||||
return p
|
||||
|
||||
raise ValueError(f"Provider {provider_id} not found")
|
||||
raise ValueError(f"Provider {request.provider_id} not found")
|
||||
|
||||
async def get_providers_health(self) -> dict[str, dict[str, HealthResponse]]:
|
||||
"""Get health status for all providers.
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ from fastapi import APIRouter
|
|||
from fastapi.routing import APIRoute
|
||||
from starlette.routing import Route
|
||||
|
||||
from llama_stack_api import batches, benchmarks, datasets
|
||||
from llama_stack_api import batches, benchmarks, datasets, providers
|
||||
|
||||
# Router factories for APIs that have FastAPI routers
|
||||
# Add new APIs here as they are migrated to the router system
|
||||
|
|
@ -27,6 +27,7 @@ _ROUTER_FACTORIES: dict[str, Callable[[Any], APIRouter]] = {
|
|||
"batches": batches.fastapi_routes.create_router,
|
||||
"benchmarks": benchmarks.fastapi_routes.create_router,
|
||||
"datasets": datasets.fastapi_routes.create_router,
|
||||
"providers": providers.fastapi_routes.create_router,
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -359,7 +359,12 @@ from .post_training import (
|
|||
TrainingConfig,
|
||||
)
|
||||
from .prompts import ListPromptsResponse, Prompt, Prompts
|
||||
from .providers import ListProvidersResponse, ProviderInfo, Providers
|
||||
from .providers import (
|
||||
InspectProviderRequest,
|
||||
ListProvidersResponse,
|
||||
ProviderInfo,
|
||||
Providers,
|
||||
)
|
||||
from .rag_tool import (
|
||||
DefaultRAGQueryGeneratorConfig,
|
||||
LLMRAGQueryGeneratorConfig,
|
||||
|
|
@ -611,6 +616,7 @@ __all__ = [
|
|||
"ListPostTrainingJobsResponse",
|
||||
"ListPromptsResponse",
|
||||
"ListProvidersResponse",
|
||||
"InspectProviderRequest",
|
||||
"ListRoutesResponse",
|
||||
"ListScoringFunctionsResponse",
|
||||
"ListShieldsResponse",
|
||||
|
|
|
|||
|
|
@ -1,70 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Protocol, runtime_checkable
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack_api.datatypes import HealthResponse
|
||||
from llama_stack_api.schema_utils import json_schema_type, webmethod
|
||||
from llama_stack_api.version import LLAMA_STACK_API_V1
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ProviderInfo(BaseModel):
|
||||
"""Information about a registered provider including its configuration and health status.
|
||||
|
||||
:param api: The API name this provider implements
|
||||
:param provider_id: Unique identifier for the provider
|
||||
:param provider_type: The type of provider implementation
|
||||
:param config: Configuration parameters for the provider
|
||||
:param health: Current health status of the provider
|
||||
"""
|
||||
|
||||
api: str
|
||||
provider_id: str
|
||||
provider_type: str
|
||||
config: dict[str, Any]
|
||||
health: HealthResponse
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ListProvidersResponse(BaseModel):
|
||||
"""Response containing a list of all available providers.
|
||||
|
||||
:param data: List of provider information objects
|
||||
"""
|
||||
|
||||
data: list[ProviderInfo]
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Providers(Protocol):
|
||||
"""Providers
|
||||
|
||||
Providers API for inspecting, listing, and modifying providers and their configurations.
|
||||
"""
|
||||
|
||||
@webmethod(route="/providers", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def list_providers(self) -> ListProvidersResponse:
|
||||
"""List providers.
|
||||
|
||||
List all available providers.
|
||||
|
||||
:returns: A ListProvidersResponse containing information about all providers.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/providers/{provider_id}", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def inspect_provider(self, provider_id: str) -> ProviderInfo:
|
||||
"""Get provider.
|
||||
|
||||
Get detailed information about a specific provider.
|
||||
|
||||
:param provider_id: The ID of the provider to inspect.
|
||||
:returns: A ProviderInfo object containing the provider's details.
|
||||
"""
|
||||
...
|
||||
33
src/llama_stack_api/providers/__init__.py
Normal file
33
src/llama_stack_api/providers/__init__.py
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""Providers API protocol and models.
|
||||
|
||||
This module contains the Providers protocol definition.
|
||||
Pydantic models are defined in llama_stack_api.providers.models.
|
||||
The FastAPI router is defined in llama_stack_api.providers.fastapi_routes.
|
||||
"""
|
||||
|
||||
# Import fastapi_routes for router factory access
|
||||
from . import fastapi_routes
|
||||
|
||||
# Import protocol for re-export
|
||||
from .api import Providers
|
||||
|
||||
# Import models for re-export
|
||||
from .models import (
|
||||
InspectProviderRequest,
|
||||
ListProvidersResponse,
|
||||
ProviderInfo,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Providers",
|
||||
"ProviderInfo",
|
||||
"ListProvidersResponse",
|
||||
"InspectProviderRequest",
|
||||
"fastapi_routes",
|
||||
]
|
||||
16
src/llama_stack_api/providers/api.py
Normal file
16
src/llama_stack_api/providers/api.py
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Protocol, runtime_checkable
|
||||
|
||||
from .models import InspectProviderRequest, ListProvidersResponse, ProviderInfo
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Providers(Protocol):
|
||||
async def list_providers(self) -> ListProvidersResponse: ...
|
||||
|
||||
async def inspect_provider(self, request: InspectProviderRequest) -> ProviderInfo: ...
|
||||
57
src/llama_stack_api/providers/fastapi_routes.py
Normal file
57
src/llama_stack_api/providers/fastapi_routes.py
Normal file
|
|
@ -0,0 +1,57 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""FastAPI router for the Providers API.
|
||||
|
||||
This module defines the FastAPI router for the Providers API using standard
|
||||
FastAPI route decorators.
|
||||
"""
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from llama_stack_api.router_utils import create_path_dependency, standard_responses
|
||||
from llama_stack_api.version import LLAMA_STACK_API_V1
|
||||
|
||||
from .api import Providers
|
||||
from .models import InspectProviderRequest, ListProvidersResponse, ProviderInfo
|
||||
|
||||
# Path parameter dependencies for single-field models
|
||||
get_inspect_provider_request = create_path_dependency(InspectProviderRequest)
|
||||
|
||||
|
||||
def create_router(impl: Providers) -> APIRouter:
|
||||
"""Create a FastAPI router for the Providers API."""
|
||||
router = APIRouter(
|
||||
prefix=f"/{LLAMA_STACK_API_V1}",
|
||||
tags=["Providers"],
|
||||
responses=standard_responses,
|
||||
)
|
||||
|
||||
@router.get(
|
||||
"/providers",
|
||||
response_model=ListProvidersResponse,
|
||||
summary="List providers.",
|
||||
description="List all available providers.",
|
||||
responses={200: {"description": "A ListProvidersResponse containing information about all providers."}},
|
||||
)
|
||||
async def list_providers() -> ListProvidersResponse:
|
||||
return await impl.list_providers()
|
||||
|
||||
@router.get(
|
||||
"/providers/{provider_id}",
|
||||
response_model=ProviderInfo,
|
||||
summary="Get provider.",
|
||||
description="Get detailed information about a specific provider.",
|
||||
responses={200: {"description": "A ProviderInfo object containing the provider's details."}},
|
||||
)
|
||||
async def inspect_provider(
|
||||
request: Annotated[InspectProviderRequest, Depends(get_inspect_provider_request)],
|
||||
) -> ProviderInfo:
|
||||
return await impl.inspect_provider(request)
|
||||
|
||||
return router
|
||||
43
src/llama_stack_api/providers/models.py
Normal file
43
src/llama_stack_api/providers/models.py
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""Pydantic models for Providers API requests and responses.
|
||||
|
||||
This module defines the request and response models for the Providers API
|
||||
using Pydantic with Field descriptions for OpenAPI schema generation.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack_api.datatypes import HealthResponse
|
||||
from llama_stack_api.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ProviderInfo(BaseModel):
|
||||
"""Information about a registered provider including its configuration and health status."""
|
||||
|
||||
api: str = Field(..., description="The API name this provider implements")
|
||||
provider_id: str = Field(..., description="Unique identifier for the provider")
|
||||
provider_type: str = Field(..., description="The type of provider implementation")
|
||||
config: dict[str, Any] = Field(..., description="Configuration parameters for the provider")
|
||||
health: HealthResponse = Field(..., description="Current health status of the provider")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ListProvidersResponse(BaseModel):
|
||||
"""Response containing a list of all available providers."""
|
||||
|
||||
data: list[ProviderInfo] = Field(..., description="List of provider information objects")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class InspectProviderRequest(BaseModel):
|
||||
"""Request model for inspecting a provider by ID."""
|
||||
|
||||
provider_id: str = Field(..., description="The ID of the provider to inspect.")
|
||||
Loading…
Add table
Add a link
Reference in a new issue