diff --git a/docs/resources/llama-stack-spec.html b/docs/resources/llama-stack-spec.html index 750dce798..38cabdd3e 100644 --- a/docs/resources/llama-stack-spec.html +++ b/docs/resources/llama-stack-spec.html @@ -7196,6 +7196,9 @@ "ProviderInfo": { "type": "object", "properties": { + "api": { + "type": "string" + }, "provider_id": { "type": "string" }, @@ -7205,6 +7208,7 @@ }, "additionalProperties": false, "required": [ + "api", "provider_id", "provider_type" ] diff --git a/docs/resources/llama-stack-spec.yaml b/docs/resources/llama-stack-spec.yaml index 9d5f9cd60..75bc25e94 100644 --- a/docs/resources/llama-stack-spec.yaml +++ b/docs/resources/llama-stack-spec.yaml @@ -1678,11 +1678,14 @@ components: ProviderInfo: additionalProperties: false properties: + api: + type: string provider_id: type: string provider_type: type: string required: + - api - provider_id - provider_type type: object diff --git a/llama_stack/apis/inspect/inspect.py b/llama_stack/apis/inspect/inspect.py index 9d20c27b3..cd51469c1 100644 --- a/llama_stack/apis/inspect/inspect.py +++ b/llama_stack/apis/inspect/inspect.py @@ -12,6 +12,7 @@ from pydantic import BaseModel @json_schema_type class ProviderInfo(BaseModel): + api: str provider_id: str provider_type: str diff --git a/llama_stack/distribution/inspect.py b/llama_stack/distribution/inspect.py index 08dfb329e..b7ee4a219 100644 --- a/llama_stack/distribution/inspect.py +++ b/llama_stack/distribution/inspect.py @@ -44,15 +44,18 @@ class DistributionInspectImpl(Inspect): ret = [] for api, providers in run_config.providers.items(): - ret.append( - ProviderInfo( - provider_id=p.provider_id, - provider_type=p.provider_type, - ) - for p in providers + ret.extend( + [ + ProviderInfo( + api=api, + provider_id=p.provider_id, + provider_type=p.provider_type, + ) + for p in providers + ] ) - return ret + return ListProvidersResponse(data=ret) async def list_routes(self) -> ListRoutesResponse: run_config = self.config.run_config @@ -61,16 +64,18 @@ class DistributionInspectImpl(Inspect): all_endpoints = get_all_api_endpoints() for api, endpoints in all_endpoints.items(): providers = run_config.providers.get(api.value, []) - ret.append( - RouteInfo( - route=e.route, - method=e.method, - provider_types=[p.provider_type for p in providers], - ) - for e in endpoints + ret.extend( + [ + RouteInfo( + route=e.route, + method=e.method, + provider_types=[p.provider_type for p in providers], + ) + for e in endpoints + ] ) - return ret + return ListRoutesResponse(data=ret) async def health(self) -> HealthInfo: return HealthInfo(status="OK") diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py index 747b64dd1..19a4064a0 100644 --- a/tests/client-sdk/agents/test_agents.py +++ b/tests/client-sdk/agents/test_agents.py @@ -83,13 +83,13 @@ class TestClientTool(ClientTool): def agent_config(llama_stack_client): available_models = [ model.identifier - for model in llama_stack_client.models.list().data + for model in llama_stack_client.models.list() if model.identifier.startswith("meta-llama") and "405" not in model.identifier ] model_id = available_models[0] print(f"Using model: {model_id}") available_shields = [ - shield.identifier for shield in llama_stack_client.shields.list().data + shield.identifier for shield in llama_stack_client.shields.list() ] available_shields = available_shields[:1] print(f"Using shield: {available_shields}") diff --git a/tests/client-sdk/inference/test_inference.py b/tests/client-sdk/inference/test_inference.py index 5191a3f7f..671a37926 100644 --- a/tests/client-sdk/inference/test_inference.py +++ b/tests/client-sdk/inference/test_inference.py @@ -5,10 +5,8 @@ # the root directory of this source tree. import pytest - from pydantic import BaseModel - PROVIDER_TOOL_PROMPT_FORMAT = { "remote::ollama": "python_list", "remote::together": "json", @@ -28,15 +26,16 @@ def provider_tool_format(inference_provider_type): @pytest.fixture(scope="session") def inference_provider_type(llama_stack_client): providers = llama_stack_client.providers.list() - assert len(providers.inference) > 0 - return providers.inference[0]["provider_type"] + inference_providers = [p for p in providers if p.api == "inference"] + assert len(inference_providers) > 0, "No inference providers found" + return inference_providers[0].provider_type @pytest.fixture(scope="session") def text_model_id(llama_stack_client): available_models = [ model.identifier - for model in llama_stack_client.models.list().data + for model in llama_stack_client.models.list() if model.identifier.startswith("meta-llama") and "405" not in model.identifier ] assert len(available_models) > 0 @@ -47,7 +46,7 @@ def text_model_id(llama_stack_client): def vision_model_id(llama_stack_client): available_models = [ model.identifier - for model in llama_stack_client.models.list().data + for model in llama_stack_client.models.list() if "vision" in model.identifier.lower() ] if len(available_models) == 0: diff --git a/tests/client-sdk/memory/test_memory.py b/tests/client-sdk/memory/test_memory.py index a5f154fda..1e9b34355 100644 --- a/tests/client-sdk/memory/test_memory.py +++ b/tests/client-sdk/memory/test_memory.py @@ -7,16 +7,15 @@ import random import pytest -from llama_stack.apis.memory import MemoryBankDocument +from llama_stack.apis.memory import MemoryBankDocument from llama_stack_client.types.memory_insert_params import Document @pytest.fixture(scope="function") def empty_memory_bank_registry(llama_stack_client): memory_banks = [ - memory_bank.identifier - for memory_bank in llama_stack_client.memory_banks.list().data + memory_bank.identifier for memory_bank in llama_stack_client.memory_banks.list() ] for memory_bank_id in memory_banks: llama_stack_client.memory_banks.unregister(memory_bank_id=memory_bank_id) @@ -36,8 +35,7 @@ def single_entry_memory_bank_registry(llama_stack_client, empty_memory_bank_regi provider_id="faiss", ) memory_banks = [ - memory_bank.identifier - for memory_bank in llama_stack_client.memory_banks.list().data + memory_bank.identifier for memory_bank in llama_stack_client.memory_banks.list() ] return memory_banks @@ -106,8 +104,7 @@ def test_memory_bank_retrieve(llama_stack_client, empty_memory_bank_registry): def test_memory_bank_list(llama_stack_client, empty_memory_bank_registry): memory_banks_after_register = [ - memory_bank.identifier - for memory_bank in llama_stack_client.memory_banks.list().data + memory_bank.identifier for memory_bank in llama_stack_client.memory_banks.list() ] assert len(memory_banks_after_register) == 0 @@ -127,16 +124,14 @@ def test_memory_bank_register(llama_stack_client, empty_memory_bank_registry): ) memory_banks_after_register = [ - memory_bank.identifier - for memory_bank in llama_stack_client.memory_banks.list().data + memory_bank.identifier for memory_bank in llama_stack_client.memory_banks.list() ] assert memory_banks_after_register == [memory_bank_id] def test_memory_bank_unregister(llama_stack_client, single_entry_memory_bank_registry): memory_banks = [ - memory_bank.identifier - for memory_bank in llama_stack_client.memory_banks.list().data + memory_bank.identifier for memory_bank in llama_stack_client.memory_banks.list() ] assert len(memory_banks) == 1 @@ -144,8 +139,7 @@ def test_memory_bank_unregister(llama_stack_client, single_entry_memory_bank_reg llama_stack_client.memory_banks.unregister(memory_bank_id=memory_bank_id) memory_banks = [ - memory_bank.identifier - for memory_bank in llama_stack_client.memory_banks.list().data + memory_bank.identifier for memory_bank in llama_stack_client.memory_banks.list() ] assert len(memory_banks) == 0 @@ -201,10 +195,10 @@ def test_memory_bank_insert_inline_and_query( def test_memory_bank_insert_from_url_and_query( llama_stack_client, empty_memory_bank_registry ): - providers = llama_stack_client.providers.list().memory + providers = [p for p in llama_stack_client.providers.list() if p.api == "memory"] assert len(providers) > 0 - memory_provider_id = providers[0]["provider_id"] + memory_provider_id = providers[0].provider_id memory_bank_id = "test_bank" llama_stack_client.memory_banks.register( @@ -220,8 +214,7 @@ def test_memory_bank_insert_from_url_and_query( # list to check memory bank is successfully registered available_memory_banks = [ - memory_bank.identifier - for memory_bank in llama_stack_client.memory_banks.list().data + memory_bank.identifier for memory_bank in llama_stack_client.memory_banks.list() ] assert memory_bank_id in available_memory_banks diff --git a/tests/client-sdk/safety/test_safety.py b/tests/client-sdk/safety/test_safety.py index 2d79bda5e..6af417a09 100644 --- a/tests/client-sdk/safety/test_safety.py +++ b/tests/client-sdk/safety/test_safety.py @@ -11,7 +11,6 @@ import pytest from llama_stack.apis.safety import ViolationLevel - VISION_SHIELD_ENABLED_PROVIDERS = {"together"} CODE_SCANNER_ENABLED_PROVIDERS = {"ollama", "together", "fireworks"} @@ -30,7 +29,7 @@ def data_url_from_image(file_path): @pytest.fixture(scope="session") def available_shields(llama_stack_client): - return [shield.identifier for shield in llama_stack_client.shields.list().data] + return [shield.identifier for shield in llama_stack_client.shields.list()] @pytest.fixture(scope="session") @@ -54,7 +53,11 @@ def code_scanner_shield_id(available_shields): @pytest.fixture(scope="session") def model_providers(llama_stack_client): return set( - [x["provider_id"] for x in llama_stack_client.providers.list().inference] + [ + x.provider_id + for x in llama_stack_client.providers.list() + if x.api == "inference" + ] )