This commit is contained in:
Xi Yan 2025-03-13 17:11:49 -07:00
parent 98b1b15e0f
commit d479bc5fdc
2 changed files with 15 additions and 11 deletions

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List, Protocol, runtime_checkable
from typing import List, Optional, Protocol, runtime_checkable
from pydantic import BaseModel
@ -19,10 +19,6 @@ class ProviderInfo(BaseModel):
provider_type: str
class GetProviderResponse(BaseModel):
data: Provider | None
class ListProvidersResponse(BaseModel):
data: List[ProviderInfo]
@ -37,4 +33,4 @@ class Providers(Protocol):
async def list_providers(self) -> ListProvidersResponse: ...
@webmethod(route="/providers/{provider_id}", method="GET")
async def inspect_provider(self, provider_id: str) -> GetProviderResponse: ...
async def inspect_provider(self, provider_id: str) -> Optional[ProviderInfo]: ...

View file

@ -6,7 +6,12 @@
from pydantic import BaseModel
from llama_stack.apis.providers import GetProviderResponse, ListProvidersResponse, ProviderInfo, Providers
from llama_stack.apis.providers import (
GetProviderResponse,
ListProvidersResponse,
ProviderInfo,
Providers,
)
from .datatypes import StackRunConfig
from .stack import redact_sensitive_fields
@ -50,10 +55,13 @@ class ProviderImpl(Providers):
async def inspect_provider(self, provider_id: str) -> GetProviderResponse:
run_config = self.config.run_config
safe_config = StackRunConfig(**redact_sensitive_fields(run_config.model_dump()))
ret = None
for _, providers in safe_config.providers.items():
for api, providers in safe_config.providers.items():
for p in providers:
if p.provider_id == provider_id:
ret = p
return ProviderInfo(
api=api,
provider_id=p.provider_id,
provider_type=p.provider_type,
)
return GetProviderResponse(data=ret)
return None