From 678ab29129cb106d96f1612d1ba270a5aebb1af0 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Thu, 16 Jan 2025 10:39:42 -0800 Subject: [PATCH] Idiomatic REST API: Inspect (#779) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What does this PR do? Since provider list returns a map grouping providers by API, we should not be using data. This PR fixes the types to just be the plain dict, basically reverting back to previous behavior ## Test Plan llama-stack on  fix-provider-list [$] 🅒 stack❯ LLAMA_STACK_CONFIG="/Users/dineshyv/.llama/distributions/llamastack-together/together-run.yaml" pytest -v tests/client-sdk/safety/test_safety.py --- llama_stack/apis/inspect/inspect.py | 12 ++++++++---- llama_stack/distribution/inspect.py | 20 +++++++++++--------- 2 files changed, 19 insertions(+), 13 deletions(-) diff --git a/llama_stack/apis/inspect/inspect.py b/llama_stack/apis/inspect/inspect.py index e2bb98217..9d20c27b3 100644 --- a/llama_stack/apis/inspect/inspect.py +++ b/llama_stack/apis/inspect/inspect.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Dict, List, Protocol, runtime_checkable +from typing import List, Protocol, runtime_checkable from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel @@ -38,13 +38,17 @@ class ListProvidersResponse(BaseModel): data: List[ProviderInfo] +class ListRoutesResponse(BaseModel): + data: List[RouteInfo] + + @runtime_checkable class Inspect(Protocol): - @webmethod(route="/providers/list", method="GET") + @webmethod(route="/inspect/providers", method="GET") async def list_providers(self) -> ListProvidersResponse: ... - @webmethod(route="/routes/list", method="GET") - async def list_routes(self) -> Dict[str, List[RouteInfo]]: ... + @webmethod(route="/inspect/routes", method="GET") + async def list_routes(self) -> ListRoutesResponse: ... @webmethod(route="/health", method="GET") async def health(self) -> HealthInfo: ... diff --git a/llama_stack/distribution/inspect.py b/llama_stack/distribution/inspect.py index d275a5c2f..08dfb329e 100644 --- a/llama_stack/distribution/inspect.py +++ b/llama_stack/distribution/inspect.py @@ -5,13 +5,14 @@ # the root directory of this source tree. from importlib.metadata import version -from typing import Dict, List from pydantic import BaseModel from llama_stack.apis.inspect import ( HealthInfo, Inspect, + ListProvidersResponse, + ListRoutesResponse, ProviderInfo, RouteInfo, VersionInfo, @@ -38,36 +39,37 @@ class DistributionInspectImpl(Inspect): async def initialize(self) -> None: pass - async def list_providers(self) -> Dict[str, List[ProviderInfo]]: + async def list_providers(self) -> ListProvidersResponse: run_config = self.config.run_config - ret = {} + ret = [] for api, providers in run_config.providers.items(): - ret[api] = [ + ret.append( ProviderInfo( provider_id=p.provider_id, provider_type=p.provider_type, ) for p in providers - ] + ) return ret - async def list_routes(self) -> Dict[str, List[RouteInfo]]: + async def list_routes(self) -> ListRoutesResponse: run_config = self.config.run_config - ret = {} + ret = [] all_endpoints = get_all_api_endpoints() for api, endpoints in all_endpoints.items(): providers = run_config.providers.get(api.value, []) - ret[api.value] = [ + ret.append( RouteInfo( route=e.route, method=e.method, provider_types=[p.provider_type for p in providers], ) for e in endpoints - ] + ) + return ret async def health(self) -> HealthInfo: