From d479bc5fdc0c2954ae5ba527e8b2d74e06e9d845 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 13 Mar 2025 17:11:49 -0700 Subject: [PATCH] fix --- llama_stack/apis/providers/providers.py | 8 ++------ llama_stack/distribution/providers.py | 18 +++++++++++++----- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/llama_stack/apis/providers/providers.py b/llama_stack/apis/providers/providers.py index fd37bd500..3879c7fc4 100644 --- a/llama_stack/apis/providers/providers.py +++ b/llama_stack/apis/providers/providers.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import List, Protocol, runtime_checkable +from typing import List, Optional, Protocol, runtime_checkable from pydantic import BaseModel @@ -19,10 +19,6 @@ class ProviderInfo(BaseModel): provider_type: str -class GetProviderResponse(BaseModel): - data: Provider | None - - class ListProvidersResponse(BaseModel): data: List[ProviderInfo] @@ -37,4 +33,4 @@ class Providers(Protocol): async def list_providers(self) -> ListProvidersResponse: ... @webmethod(route="/providers/{provider_id}", method="GET") - async def inspect_provider(self, provider_id: str) -> GetProviderResponse: ... + async def inspect_provider(self, provider_id: str) -> Optional[ProviderInfo]: ... diff --git a/llama_stack/distribution/providers.py b/llama_stack/distribution/providers.py index 219384900..82b86be1d 100644 --- a/llama_stack/distribution/providers.py +++ b/llama_stack/distribution/providers.py @@ -6,7 +6,12 @@ from pydantic import BaseModel -from llama_stack.apis.providers import GetProviderResponse, ListProvidersResponse, ProviderInfo, Providers +from llama_stack.apis.providers import ( + GetProviderResponse, + ListProvidersResponse, + ProviderInfo, + Providers, +) from .datatypes import StackRunConfig from .stack import redact_sensitive_fields @@ -50,10 +55,13 @@ class ProviderImpl(Providers): async def inspect_provider(self, provider_id: str) -> GetProviderResponse: run_config = self.config.run_config safe_config = StackRunConfig(**redact_sensitive_fields(run_config.model_dump())) - ret = None - for _, providers in safe_config.providers.items(): + for api, providers in safe_config.providers.items(): for p in providers: if p.provider_id == provider_id: - ret = p + return ProviderInfo( + api=api, + provider_id=p.provider_id, + provider_type=p.provider_type, + ) - return GetProviderResponse(data=ret) + return None