REST API fixes (#789)

# What does this PR do?

Client SDK fixes

## Test Plan


LLAMA_STACK_CONFIG="/Users/dineshyv/.llama/distributions/llamastack-fireworks/fireworks-run.yaml"
pytest -v tests/client-sdk/safety/test_safety.py


LLAMA_STACK_CONFIG="/Users/dineshyv/.llama/distributions/llamastack-fireworks/fireworks-run.yaml"
pytest -v tests/client-sdk/memory/test_memory.py
This commit is contained in:
Dinesh Yeduguru 2025-01-16 13:47:08 -08:00 committed by GitHub
parent cee3816609
commit 12c994b5b2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 51 additions and 43 deletions

View file

@ -7196,6 +7196,9 @@
"ProviderInfo": { "ProviderInfo": {
"type": "object", "type": "object",
"properties": { "properties": {
"api": {
"type": "string"
},
"provider_id": { "provider_id": {
"type": "string" "type": "string"
}, },
@ -7205,6 +7208,7 @@
}, },
"additionalProperties": false, "additionalProperties": false,
"required": [ "required": [
"api",
"provider_id", "provider_id",
"provider_type" "provider_type"
] ]

View file

@ -1678,11 +1678,14 @@ components:
ProviderInfo: ProviderInfo:
additionalProperties: false additionalProperties: false
properties: properties:
api:
type: string
provider_id: provider_id:
type: string type: string
provider_type: provider_type:
type: string type: string
required: required:
- api
- provider_id - provider_id
- provider_type - provider_type
type: object type: object

View file

@ -12,6 +12,7 @@ from pydantic import BaseModel
@json_schema_type @json_schema_type
class ProviderInfo(BaseModel): class ProviderInfo(BaseModel):
api: str
provider_id: str provider_id: str
provider_type: str provider_type: str

View file

@ -44,15 +44,18 @@ class DistributionInspectImpl(Inspect):
ret = [] ret = []
for api, providers in run_config.providers.items(): for api, providers in run_config.providers.items():
ret.append( ret.extend(
ProviderInfo( [
provider_id=p.provider_id, ProviderInfo(
provider_type=p.provider_type, api=api,
) provider_id=p.provider_id,
for p in providers provider_type=p.provider_type,
)
for p in providers
]
) )
return ret return ListProvidersResponse(data=ret)
async def list_routes(self) -> ListRoutesResponse: async def list_routes(self) -> ListRoutesResponse:
run_config = self.config.run_config run_config = self.config.run_config
@ -61,16 +64,18 @@ class DistributionInspectImpl(Inspect):
all_endpoints = get_all_api_endpoints() all_endpoints = get_all_api_endpoints()
for api, endpoints in all_endpoints.items(): for api, endpoints in all_endpoints.items():
providers = run_config.providers.get(api.value, []) providers = run_config.providers.get(api.value, [])
ret.append( ret.extend(
RouteInfo( [
route=e.route, RouteInfo(
method=e.method, route=e.route,
provider_types=[p.provider_type for p in providers], method=e.method,
) provider_types=[p.provider_type for p in providers],
for e in endpoints )
for e in endpoints
]
) )
return ret return ListRoutesResponse(data=ret)
async def health(self) -> HealthInfo: async def health(self) -> HealthInfo:
return HealthInfo(status="OK") return HealthInfo(status="OK")

View file

@ -83,13 +83,13 @@ class TestClientTool(ClientTool):
def agent_config(llama_stack_client): def agent_config(llama_stack_client):
available_models = [ available_models = [
model.identifier model.identifier
for model in llama_stack_client.models.list().data for model in llama_stack_client.models.list()
if model.identifier.startswith("meta-llama") and "405" not in model.identifier if model.identifier.startswith("meta-llama") and "405" not in model.identifier
] ]
model_id = available_models[0] model_id = available_models[0]
print(f"Using model: {model_id}") print(f"Using model: {model_id}")
available_shields = [ available_shields = [
shield.identifier for shield in llama_stack_client.shields.list().data shield.identifier for shield in llama_stack_client.shields.list()
] ]
available_shields = available_shields[:1] available_shields = available_shields[:1]
print(f"Using shield: {available_shields}") print(f"Using shield: {available_shields}")

View file

@ -5,10 +5,8 @@
# the root directory of this source tree. # the root directory of this source tree.
import pytest import pytest
from pydantic import BaseModel from pydantic import BaseModel
PROVIDER_TOOL_PROMPT_FORMAT = { PROVIDER_TOOL_PROMPT_FORMAT = {
"remote::ollama": "python_list", "remote::ollama": "python_list",
"remote::together": "json", "remote::together": "json",
@ -28,15 +26,16 @@ def provider_tool_format(inference_provider_type):
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def inference_provider_type(llama_stack_client): def inference_provider_type(llama_stack_client):
providers = llama_stack_client.providers.list() providers = llama_stack_client.providers.list()
assert len(providers.inference) > 0 inference_providers = [p for p in providers if p.api == "inference"]
return providers.inference[0]["provider_type"] assert len(inference_providers) > 0, "No inference providers found"
return inference_providers[0].provider_type
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def text_model_id(llama_stack_client): def text_model_id(llama_stack_client):
available_models = [ available_models = [
model.identifier model.identifier
for model in llama_stack_client.models.list().data for model in llama_stack_client.models.list()
if model.identifier.startswith("meta-llama") and "405" not in model.identifier if model.identifier.startswith("meta-llama") and "405" not in model.identifier
] ]
assert len(available_models) > 0 assert len(available_models) > 0
@ -47,7 +46,7 @@ def text_model_id(llama_stack_client):
def vision_model_id(llama_stack_client): def vision_model_id(llama_stack_client):
available_models = [ available_models = [
model.identifier model.identifier
for model in llama_stack_client.models.list().data for model in llama_stack_client.models.list()
if "vision" in model.identifier.lower() if "vision" in model.identifier.lower()
] ]
if len(available_models) == 0: if len(available_models) == 0:

View file

@ -7,16 +7,15 @@
import random import random
import pytest import pytest
from llama_stack.apis.memory import MemoryBankDocument
from llama_stack.apis.memory import MemoryBankDocument
from llama_stack_client.types.memory_insert_params import Document from llama_stack_client.types.memory_insert_params import Document
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
def empty_memory_bank_registry(llama_stack_client): def empty_memory_bank_registry(llama_stack_client):
memory_banks = [ memory_banks = [
memory_bank.identifier memory_bank.identifier for memory_bank in llama_stack_client.memory_banks.list()
for memory_bank in llama_stack_client.memory_banks.list().data
] ]
for memory_bank_id in memory_banks: for memory_bank_id in memory_banks:
llama_stack_client.memory_banks.unregister(memory_bank_id=memory_bank_id) llama_stack_client.memory_banks.unregister(memory_bank_id=memory_bank_id)
@ -36,8 +35,7 @@ def single_entry_memory_bank_registry(llama_stack_client, empty_memory_bank_regi
provider_id="faiss", provider_id="faiss",
) )
memory_banks = [ memory_banks = [
memory_bank.identifier memory_bank.identifier for memory_bank in llama_stack_client.memory_banks.list()
for memory_bank in llama_stack_client.memory_banks.list().data
] ]
return memory_banks return memory_banks
@ -106,8 +104,7 @@ def test_memory_bank_retrieve(llama_stack_client, empty_memory_bank_registry):
def test_memory_bank_list(llama_stack_client, empty_memory_bank_registry): def test_memory_bank_list(llama_stack_client, empty_memory_bank_registry):
memory_banks_after_register = [ memory_banks_after_register = [
memory_bank.identifier memory_bank.identifier for memory_bank in llama_stack_client.memory_banks.list()
for memory_bank in llama_stack_client.memory_banks.list().data
] ]
assert len(memory_banks_after_register) == 0 assert len(memory_banks_after_register) == 0
@ -127,16 +124,14 @@ def test_memory_bank_register(llama_stack_client, empty_memory_bank_registry):
) )
memory_banks_after_register = [ memory_banks_after_register = [
memory_bank.identifier memory_bank.identifier for memory_bank in llama_stack_client.memory_banks.list()
for memory_bank in llama_stack_client.memory_banks.list().data
] ]
assert memory_banks_after_register == [memory_bank_id] assert memory_banks_after_register == [memory_bank_id]
def test_memory_bank_unregister(llama_stack_client, single_entry_memory_bank_registry): def test_memory_bank_unregister(llama_stack_client, single_entry_memory_bank_registry):
memory_banks = [ memory_banks = [
memory_bank.identifier memory_bank.identifier for memory_bank in llama_stack_client.memory_banks.list()
for memory_bank in llama_stack_client.memory_banks.list().data
] ]
assert len(memory_banks) == 1 assert len(memory_banks) == 1
@ -144,8 +139,7 @@ def test_memory_bank_unregister(llama_stack_client, single_entry_memory_bank_reg
llama_stack_client.memory_banks.unregister(memory_bank_id=memory_bank_id) llama_stack_client.memory_banks.unregister(memory_bank_id=memory_bank_id)
memory_banks = [ memory_banks = [
memory_bank.identifier memory_bank.identifier for memory_bank in llama_stack_client.memory_banks.list()
for memory_bank in llama_stack_client.memory_banks.list().data
] ]
assert len(memory_banks) == 0 assert len(memory_banks) == 0
@ -201,10 +195,10 @@ def test_memory_bank_insert_inline_and_query(
def test_memory_bank_insert_from_url_and_query( def test_memory_bank_insert_from_url_and_query(
llama_stack_client, empty_memory_bank_registry llama_stack_client, empty_memory_bank_registry
): ):
providers = llama_stack_client.providers.list().memory providers = [p for p in llama_stack_client.providers.list() if p.api == "memory"]
assert len(providers) > 0 assert len(providers) > 0
memory_provider_id = providers[0]["provider_id"] memory_provider_id = providers[0].provider_id
memory_bank_id = "test_bank" memory_bank_id = "test_bank"
llama_stack_client.memory_banks.register( llama_stack_client.memory_banks.register(
@ -220,8 +214,7 @@ def test_memory_bank_insert_from_url_and_query(
# list to check memory bank is successfully registered # list to check memory bank is successfully registered
available_memory_banks = [ available_memory_banks = [
memory_bank.identifier memory_bank.identifier for memory_bank in llama_stack_client.memory_banks.list()
for memory_bank in llama_stack_client.memory_banks.list().data
] ]
assert memory_bank_id in available_memory_banks assert memory_bank_id in available_memory_banks

View file

@ -11,7 +11,6 @@ import pytest
from llama_stack.apis.safety import ViolationLevel from llama_stack.apis.safety import ViolationLevel
VISION_SHIELD_ENABLED_PROVIDERS = {"together"} VISION_SHIELD_ENABLED_PROVIDERS = {"together"}
CODE_SCANNER_ENABLED_PROVIDERS = {"ollama", "together", "fireworks"} CODE_SCANNER_ENABLED_PROVIDERS = {"ollama", "together", "fireworks"}
@ -30,7 +29,7 @@ def data_url_from_image(file_path):
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def available_shields(llama_stack_client): def available_shields(llama_stack_client):
return [shield.identifier for shield in llama_stack_client.shields.list().data] return [shield.identifier for shield in llama_stack_client.shields.list()]
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
@ -54,7 +53,11 @@ def code_scanner_shield_id(available_shields):
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def model_providers(llama_stack_client): def model_providers(llama_stack_client):
return set( return set(
[x["provider_id"] for x in llama_stack_client.providers.list().inference] [
x.provider_id
for x in llama_stack_client.providers.list()
if x.api == "inference"
]
) )