From 33b096cc21e48910cf05f0c3e513032adb99fa84 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 13 Mar 2025 19:56:32 -0700 Subject: [PATCH] fix: OpenAPI with provider get (#1627) # What does this PR do? - https://github.com/meta-llama/llama-stack/pull/1429 introduces GetProviderResponse in OpenAPI, which is not needed, and not correctly defined. cc @cdoern [//]: # (If resolving an issue, uncomment and update the line below) [//]: # (Closes #[issue-number]) ## Test Plan ``` llama-stack-client providers list ``` image [//]: # (## Documentation) --- docs/_static/llama-stack-spec.html | 189 ++++++++++++------------ docs/_static/llama-stack-spec.yaml | 124 ++++++++-------- llama_stack/apis/providers/providers.py | 10 +- llama_stack/distribution/providers.py | 22 +-- 4 files changed, 166 insertions(+), 179 deletions(-) diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index e62f66bd6..b5e4097d9 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -2151,6 +2151,48 @@ } } }, + "/v1/providers/{provider_id}": { + "get": { + "responses": { + "200": { + "description": "OK", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ProviderInfo" + } + } + } + }, + "400": { + "$ref": "#/components/responses/BadRequest400" + }, + "429": { + "$ref": "#/components/responses/TooManyRequests429" + }, + "500": { + "$ref": "#/components/responses/InternalServerError500" + }, + "default": { + "$ref": "#/components/responses/DefaultError" + } + }, + "tags": [ + "Providers" + ], + "description": "", + "parameters": [ + { + "name": "provider_id", + "in": "path", + "required": true, + "schema": { + "type": "string" + } + } + ] + } + }, "/v1/tool-runtime/invoke": { "post": { "responses": { @@ -2643,80 +2685,6 @@ } }, "/v1/providers": { - "get": { - "responses": { - "200": { - "description": "OK", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/ListProvidersResponse" - } - } - } - }, - "400": { - "$ref": "#/components/responses/BadRequest400" - }, - "429": { - "$ref": "#/components/responses/TooManyRequests429" - }, - "500": { - "$ref": "#/components/responses/InternalServerError500" - }, - "default": { - "$ref": "#/components/responses/DefaultError" - } - }, - "tags": [ - "Providers" - ], - "description": "", - "parameters": [] - } - }, - "/v1/providers/{provider_id}": { - "get": { - "responses": { - "200": { - "description": "OK", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/GetProviderResponse" - } - } - } - }, - "400": { - "$ref": "#/components/responses/BadRequest400" - }, - "429": { - "$ref": "#/components/responses/TooManyRequests429" - }, - "500": { - "$ref": "#/components/responses/InternalServerError500" - }, - "default": { - "$ref": "#/components/responses/DefaultError" - } - }, - "tags": [ - "Providers" - ], - "description": "", - "parameters": [ - { - "name": "provider_id", - "in": "path", - "required": true, - "schema": { - "type": "string" - } - } - ] - }, - "/v1/inspect/providers": { "get": { "responses": { "200": { @@ -7986,6 +7954,53 @@ ], "title": "InsertChunksRequest" }, + "ProviderInfo": { + "type": "object", + "properties": { + "api": { + "type": "string" + }, + "provider_id": { + "type": "string" + }, + "provider_type": { + "type": "string" + }, + "config": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } + } + }, + "additionalProperties": false, + "required": [ + "api", + "provider_id", + "provider_type", + "config" + ], + "title": "ProviderInfo" + }, "InvokeToolRequest": { "type": "object", "properties": { @@ -8198,27 +8213,6 @@ ], "title": "ListModelsResponse" }, - "ProviderInfo": { - "type": "object", - "properties": { - "api": { - "type": "string" - }, - "provider_id": { - "type": "string" - }, - "provider_type": { - "type": "string" - } - }, - "additionalProperties": false, - "required": [ - "api", - "provider_id", - "provider_type" - ], - "title": "ProviderInfo" - }, "ListProvidersResponse": { "type": "object", "properties": { @@ -10219,6 +10213,10 @@ { "name": "PostTraining (Coming Soon)" }, + { + "name": "Providers", + "x-displayName": "Providers API for inspecting, listing, and modifying providers and their configurations." + }, { "name": "Safety" }, @@ -10265,6 +10263,7 @@ "Inspect", "Models", "PostTraining (Coming Soon)", + "Providers", "Safety", "Scoring", "ScoringFunctions", diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index cb31848ee..bf2343ede 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -1444,6 +1444,34 @@ paths: schema: $ref: '#/components/schemas/InsertChunksRequest' required: true + /v1/providers/{provider_id}: + get: + responses: + '200': + description: OK + content: + application/json: + schema: + $ref: '#/components/schemas/ProviderInfo' + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - Providers + description: '' + parameters: + - name: provider_id + in: path + required: true + schema: + type: string /v1/tool-runtime/invoke: post: responses: @@ -1783,57 +1811,6 @@ paths: $ref: '#/components/schemas/RegisterModelRequest' required: true /v1/providers: - get: - responses: - '200': - description: OK - content: - application/json: - schema: - $ref: '#/components/schemas/ListProvidersResponse' - '400': - $ref: '#/components/responses/BadRequest400' - '429': - $ref: >- - #/components/responses/TooManyRequests429 - '500': - $ref: >- - #/components/responses/InternalServerError500 - default: - $ref: '#/components/responses/DefaultError' - tags: - - Providers - description: '' - parameters: [] - /v1/providers/{provider_id}: - get: - responses: - '200': - description: OK - content: - application/json: - schema: - $ref: '#/components/schemas/GetProviderResponse' - '400': - $ref: '#/components/responses/BadRequest400' - '429': - $ref: >- - #/components/responses/TooManyRequests429 - '500': - $ref: >- - #/components/responses/InternalServerError500 - default: - $ref: '#/components/responses/DefaultError' - tags: - - Providers - description: '' - parameters: - - name: provider_id - in: path - required: true - schema: - type: string - /v1/inspect/providers: get: responses: '200': @@ -5460,6 +5437,32 @@ components: - vector_db_id - chunks title: InsertChunksRequest + ProviderInfo: + type: object + properties: + api: + type: string + provider_id: + type: string + provider_type: + type: string + config: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + additionalProperties: false + required: + - api + - provider_id + - provider_type + - config + title: ProviderInfo InvokeToolRequest: type: object properties: @@ -5595,21 +5598,6 @@ components: required: - data title: ListModelsResponse - ProviderInfo: - type: object - properties: - api: - type: string - provider_id: - type: string - provider_type: - type: string - additionalProperties: false - required: - - api - - provider_id - - provider_type - title: ProviderInfo ListProvidersResponse: type: object properties: @@ -6883,6 +6871,9 @@ tags: - name: Inspect - name: Models - name: PostTraining (Coming Soon) + - name: Providers + x-displayName: >- + Providers API for inspecting, listing, and modifying providers and their configurations. - name: Safety - name: Scoring - name: ScoringFunctions @@ -6907,6 +6898,7 @@ x-tagGroups: - Inspect - Models - PostTraining (Coming Soon) + - Providers - Safety - Scoring - ScoringFunctions diff --git a/llama_stack/apis/providers/providers.py b/llama_stack/apis/providers/providers.py index fd37bd500..83d03d7c1 100644 --- a/llama_stack/apis/providers/providers.py +++ b/llama_stack/apis/providers/providers.py @@ -4,11 +4,10 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import List, Protocol, runtime_checkable +from typing import Any, Dict, List, Protocol, runtime_checkable from pydantic import BaseModel -from llama_stack.distribution.datatypes import Provider from llama_stack.schema_utils import json_schema_type, webmethod @@ -17,10 +16,7 @@ class ProviderInfo(BaseModel): api: str provider_id: str provider_type: str - - -class GetProviderResponse(BaseModel): - data: Provider | None + config: Dict[str, Any] class ListProvidersResponse(BaseModel): @@ -37,4 +33,4 @@ class Providers(Protocol): async def list_providers(self) -> ListProvidersResponse: ... @webmethod(route="/providers/{provider_id}", method="GET") - async def inspect_provider(self, provider_id: str) -> GetProviderResponse: ... + async def inspect_provider(self, provider_id: str) -> ProviderInfo: ... diff --git a/llama_stack/distribution/providers.py b/llama_stack/distribution/providers.py index 219384900..fb2476767 100644 --- a/llama_stack/distribution/providers.py +++ b/llama_stack/distribution/providers.py @@ -4,9 +4,10 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. + from pydantic import BaseModel -from llama_stack.apis.providers import GetProviderResponse, ListProvidersResponse, ProviderInfo, Providers +from llama_stack.apis.providers import ListProvidersResponse, ProviderInfo, Providers from .datatypes import StackRunConfig from .stack import redact_sensitive_fields @@ -32,14 +33,16 @@ class ProviderImpl(Providers): async def list_providers(self) -> ListProvidersResponse: run_config = self.config.run_config + safe_config = StackRunConfig(**redact_sensitive_fields(run_config.model_dump())) ret = [] - for api, providers in run_config.providers.items(): + for api, providers in safe_config.providers.items(): ret.extend( [ ProviderInfo( api=api, provider_id=p.provider_id, provider_type=p.provider_type, + config=p.config, ) for p in providers ] @@ -47,13 +50,10 @@ class ProviderImpl(Providers): return ListProvidersResponse(data=ret) - async def inspect_provider(self, provider_id: str) -> GetProviderResponse: - run_config = self.config.run_config - safe_config = StackRunConfig(**redact_sensitive_fields(run_config.model_dump())) - ret = None - for _, providers in safe_config.providers.items(): - for p in providers: - if p.provider_id == provider_id: - ret = p + async def inspect_provider(self, provider_id: str) -> ProviderInfo: + all_providers = await self.list_providers() + for p in all_providers.data: + if p.provider_id == provider_id: + return p - return GetProviderResponse(data=ret) + raise ValueError(f"Provider {provider_id} not found")