diff --git a/docs/resources/llama-stack-spec.html b/docs/resources/llama-stack-spec.html
index 750dce798..38cabdd3e 100644
--- a/docs/resources/llama-stack-spec.html
+++ b/docs/resources/llama-stack-spec.html
@@ -7196,6 +7196,9 @@
"ProviderInfo": {
"type": "object",
"properties": {
+ "api": {
+ "type": "string"
+ },
"provider_id": {
"type": "string"
},
@@ -7205,6 +7208,7 @@
},
"additionalProperties": false,
"required": [
+ "api",
"provider_id",
"provider_type"
]
diff --git a/docs/resources/llama-stack-spec.yaml b/docs/resources/llama-stack-spec.yaml
index 9d5f9cd60..75bc25e94 100644
--- a/docs/resources/llama-stack-spec.yaml
+++ b/docs/resources/llama-stack-spec.yaml
@@ -1678,11 +1678,14 @@ components:
ProviderInfo:
additionalProperties: false
properties:
+ api:
+ type: string
provider_id:
type: string
provider_type:
type: string
required:
+ - api
- provider_id
- provider_type
type: object
diff --git a/llama_stack/apis/inspect/inspect.py b/llama_stack/apis/inspect/inspect.py
index 9d20c27b3..cd51469c1 100644
--- a/llama_stack/apis/inspect/inspect.py
+++ b/llama_stack/apis/inspect/inspect.py
@@ -12,6 +12,7 @@ from pydantic import BaseModel
@json_schema_type
class ProviderInfo(BaseModel):
+ api: str
provider_id: str
provider_type: str
diff --git a/llama_stack/distribution/inspect.py b/llama_stack/distribution/inspect.py
index 08dfb329e..b7ee4a219 100644
--- a/llama_stack/distribution/inspect.py
+++ b/llama_stack/distribution/inspect.py
@@ -44,15 +44,18 @@ class DistributionInspectImpl(Inspect):
ret = []
for api, providers in run_config.providers.items():
- ret.append(
- ProviderInfo(
- provider_id=p.provider_id,
- provider_type=p.provider_type,
- )
- for p in providers
+ ret.extend(
+ [
+ ProviderInfo(
+ api=api,
+ provider_id=p.provider_id,
+ provider_type=p.provider_type,
+ )
+ for p in providers
+ ]
)
- return ret
+ return ListProvidersResponse(data=ret)
async def list_routes(self) -> ListRoutesResponse:
run_config = self.config.run_config
@@ -61,16 +64,18 @@ class DistributionInspectImpl(Inspect):
all_endpoints = get_all_api_endpoints()
for api, endpoints in all_endpoints.items():
providers = run_config.providers.get(api.value, [])
- ret.append(
- RouteInfo(
- route=e.route,
- method=e.method,
- provider_types=[p.provider_type for p in providers],
- )
- for e in endpoints
+ ret.extend(
+ [
+ RouteInfo(
+ route=e.route,
+ method=e.method,
+ provider_types=[p.provider_type for p in providers],
+ )
+ for e in endpoints
+ ]
)
- return ret
+ return ListRoutesResponse(data=ret)
async def health(self) -> HealthInfo:
return HealthInfo(status="OK")
diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py
index 747b64dd1..19a4064a0 100644
--- a/tests/client-sdk/agents/test_agents.py
+++ b/tests/client-sdk/agents/test_agents.py
@@ -83,13 +83,13 @@ class TestClientTool(ClientTool):
def agent_config(llama_stack_client):
available_models = [
model.identifier
- for model in llama_stack_client.models.list().data
+ for model in llama_stack_client.models.list()
if model.identifier.startswith("meta-llama") and "405" not in model.identifier
]
model_id = available_models[0]
print(f"Using model: {model_id}")
available_shields = [
- shield.identifier for shield in llama_stack_client.shields.list().data
+ shield.identifier for shield in llama_stack_client.shields.list()
]
available_shields = available_shields[:1]
print(f"Using shield: {available_shields}")
diff --git a/tests/client-sdk/inference/test_inference.py b/tests/client-sdk/inference/test_inference.py
index 5191a3f7f..671a37926 100644
--- a/tests/client-sdk/inference/test_inference.py
+++ b/tests/client-sdk/inference/test_inference.py
@@ -5,10 +5,8 @@
# the root directory of this source tree.
import pytest
-
from pydantic import BaseModel
-
PROVIDER_TOOL_PROMPT_FORMAT = {
"remote::ollama": "python_list",
"remote::together": "json",
@@ -28,15 +26,16 @@ def provider_tool_format(inference_provider_type):
@pytest.fixture(scope="session")
def inference_provider_type(llama_stack_client):
providers = llama_stack_client.providers.list()
- assert len(providers.inference) > 0
- return providers.inference[0]["provider_type"]
+ inference_providers = [p for p in providers if p.api == "inference"]
+ assert len(inference_providers) > 0, "No inference providers found"
+ return inference_providers[0].provider_type
@pytest.fixture(scope="session")
def text_model_id(llama_stack_client):
available_models = [
model.identifier
- for model in llama_stack_client.models.list().data
+ for model in llama_stack_client.models.list()
if model.identifier.startswith("meta-llama") and "405" not in model.identifier
]
assert len(available_models) > 0
@@ -47,7 +46,7 @@ def text_model_id(llama_stack_client):
def vision_model_id(llama_stack_client):
available_models = [
model.identifier
- for model in llama_stack_client.models.list().data
+ for model in llama_stack_client.models.list()
if "vision" in model.identifier.lower()
]
if len(available_models) == 0:
diff --git a/tests/client-sdk/memory/test_memory.py b/tests/client-sdk/memory/test_memory.py
index a5f154fda..1e9b34355 100644
--- a/tests/client-sdk/memory/test_memory.py
+++ b/tests/client-sdk/memory/test_memory.py
@@ -7,16 +7,15 @@
import random
import pytest
-from llama_stack.apis.memory import MemoryBankDocument
+from llama_stack.apis.memory import MemoryBankDocument
from llama_stack_client.types.memory_insert_params import Document
@pytest.fixture(scope="function")
def empty_memory_bank_registry(llama_stack_client):
memory_banks = [
- memory_bank.identifier
- for memory_bank in llama_stack_client.memory_banks.list().data
+ memory_bank.identifier for memory_bank in llama_stack_client.memory_banks.list()
]
for memory_bank_id in memory_banks:
llama_stack_client.memory_banks.unregister(memory_bank_id=memory_bank_id)
@@ -36,8 +35,7 @@ def single_entry_memory_bank_registry(llama_stack_client, empty_memory_bank_regi
provider_id="faiss",
)
memory_banks = [
- memory_bank.identifier
- for memory_bank in llama_stack_client.memory_banks.list().data
+ memory_bank.identifier for memory_bank in llama_stack_client.memory_banks.list()
]
return memory_banks
@@ -106,8 +104,7 @@ def test_memory_bank_retrieve(llama_stack_client, empty_memory_bank_registry):
def test_memory_bank_list(llama_stack_client, empty_memory_bank_registry):
memory_banks_after_register = [
- memory_bank.identifier
- for memory_bank in llama_stack_client.memory_banks.list().data
+ memory_bank.identifier for memory_bank in llama_stack_client.memory_banks.list()
]
assert len(memory_banks_after_register) == 0
@@ -127,16 +124,14 @@ def test_memory_bank_register(llama_stack_client, empty_memory_bank_registry):
)
memory_banks_after_register = [
- memory_bank.identifier
- for memory_bank in llama_stack_client.memory_banks.list().data
+ memory_bank.identifier for memory_bank in llama_stack_client.memory_banks.list()
]
assert memory_banks_after_register == [memory_bank_id]
def test_memory_bank_unregister(llama_stack_client, single_entry_memory_bank_registry):
memory_banks = [
- memory_bank.identifier
- for memory_bank in llama_stack_client.memory_banks.list().data
+ memory_bank.identifier for memory_bank in llama_stack_client.memory_banks.list()
]
assert len(memory_banks) == 1
@@ -144,8 +139,7 @@ def test_memory_bank_unregister(llama_stack_client, single_entry_memory_bank_reg
llama_stack_client.memory_banks.unregister(memory_bank_id=memory_bank_id)
memory_banks = [
- memory_bank.identifier
- for memory_bank in llama_stack_client.memory_banks.list().data
+ memory_bank.identifier for memory_bank in llama_stack_client.memory_banks.list()
]
assert len(memory_banks) == 0
@@ -201,10 +195,10 @@ def test_memory_bank_insert_inline_and_query(
def test_memory_bank_insert_from_url_and_query(
llama_stack_client, empty_memory_bank_registry
):
- providers = llama_stack_client.providers.list().memory
+ providers = [p for p in llama_stack_client.providers.list() if p.api == "memory"]
assert len(providers) > 0
- memory_provider_id = providers[0]["provider_id"]
+ memory_provider_id = providers[0].provider_id
memory_bank_id = "test_bank"
llama_stack_client.memory_banks.register(
@@ -220,8 +214,7 @@ def test_memory_bank_insert_from_url_and_query(
# list to check memory bank is successfully registered
available_memory_banks = [
- memory_bank.identifier
- for memory_bank in llama_stack_client.memory_banks.list().data
+ memory_bank.identifier for memory_bank in llama_stack_client.memory_banks.list()
]
assert memory_bank_id in available_memory_banks
diff --git a/tests/client-sdk/safety/test_safety.py b/tests/client-sdk/safety/test_safety.py
index 2d79bda5e..6af417a09 100644
--- a/tests/client-sdk/safety/test_safety.py
+++ b/tests/client-sdk/safety/test_safety.py
@@ -11,7 +11,6 @@ import pytest
from llama_stack.apis.safety import ViolationLevel
-
VISION_SHIELD_ENABLED_PROVIDERS = {"together"}
CODE_SCANNER_ENABLED_PROVIDERS = {"ollama", "together", "fireworks"}
@@ -30,7 +29,7 @@ def data_url_from_image(file_path):
@pytest.fixture(scope="session")
def available_shields(llama_stack_client):
- return [shield.identifier for shield in llama_stack_client.shields.list().data]
+ return [shield.identifier for shield in llama_stack_client.shields.list()]
@pytest.fixture(scope="session")
@@ -54,7 +53,11 @@ def code_scanner_shield_id(available_shields):
@pytest.fixture(scope="session")
def model_providers(llama_stack_client):
return set(
- [x["provider_id"] for x in llama_stack_client.providers.list().inference]
+ [
+ x.provider_id
+ for x in llama_stack_client.providers.list()
+ if x.api == "inference"
+ ]
)