REST API fixes (#789)

# What does this PR do?

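Client SDK fixes: `ProviderInfo` now carries a required `api` field, the inspect endpoints return `ListProvidersResponse` / `ListRoutesResponse` wrappers instead of bare lists, and the client-sdk tests iterate over the `list()` results directly rather than through `.data` (providers are filtered by `p.api`).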

## Test Plan


```
LLAMA_STACK_CONFIG="/Users/dineshyv/.llama/distributions/llamastack-fireworks/fireworks-run.yaml" \
pytest -v tests/client-sdk/safety/test_safety.py
```


```
LLAMA_STACK_CONFIG="/Users/dineshyv/.llama/distributions/llamastack-fireworks/fireworks-run.yaml" \
pytest -v tests/client-sdk/memory/test_memory.py
```
This commit is contained in:
Dinesh Yeduguru 2025-01-16 13:47:08 -08:00 committed by GitHub
parent cee3816609
commit 12c994b5b2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 51 additions and 43 deletions

View file

@ -7196,6 +7196,9 @@
"ProviderInfo": {
"type": "object",
"properties": {
"api": {
"type": "string"
},
"provider_id": {
"type": "string"
},
@ -7205,6 +7208,7 @@
},
"additionalProperties": false,
"required": [
"api",
"provider_id",
"provider_type"
]

View file

@ -1678,11 +1678,14 @@ components:
ProviderInfo:
additionalProperties: false
properties:
api:
type: string
provider_id:
type: string
provider_type:
type: string
required:
- api
- provider_id
- provider_type
type: object

View file

@ -12,6 +12,7 @@ from pydantic import BaseModel
@json_schema_type
class ProviderInfo(BaseModel):
api: str
provider_id: str
provider_type: str

View file

@ -44,15 +44,18 @@ class DistributionInspectImpl(Inspect):
ret = []
for api, providers in run_config.providers.items():
ret.append(
ProviderInfo(
provider_id=p.provider_id,
provider_type=p.provider_type,
)
for p in providers
ret.extend(
[
ProviderInfo(
api=api,
provider_id=p.provider_id,
provider_type=p.provider_type,
)
for p in providers
]
)
return ret
return ListProvidersResponse(data=ret)
async def list_routes(self) -> ListRoutesResponse:
run_config = self.config.run_config
@ -61,16 +64,18 @@ class DistributionInspectImpl(Inspect):
all_endpoints = get_all_api_endpoints()
for api, endpoints in all_endpoints.items():
providers = run_config.providers.get(api.value, [])
ret.append(
RouteInfo(
route=e.route,
method=e.method,
provider_types=[p.provider_type for p in providers],
)
for e in endpoints
ret.extend(
[
RouteInfo(
route=e.route,
method=e.method,
provider_types=[p.provider_type for p in providers],
)
for e in endpoints
]
)
return ret
return ListRoutesResponse(data=ret)
async def health(self) -> HealthInfo:
return HealthInfo(status="OK")

View file

@ -83,13 +83,13 @@ class TestClientTool(ClientTool):
def agent_config(llama_stack_client):
available_models = [
model.identifier
for model in llama_stack_client.models.list().data
for model in llama_stack_client.models.list()
if model.identifier.startswith("meta-llama") and "405" not in model.identifier
]
model_id = available_models[0]
print(f"Using model: {model_id}")
available_shields = [
shield.identifier for shield in llama_stack_client.shields.list().data
shield.identifier for shield in llama_stack_client.shields.list()
]
available_shields = available_shields[:1]
print(f"Using shield: {available_shields}")

View file

@ -5,10 +5,8 @@
# the root directory of this source tree.
import pytest
from pydantic import BaseModel
PROVIDER_TOOL_PROMPT_FORMAT = {
"remote::ollama": "python_list",
"remote::together": "json",
@ -28,15 +26,16 @@ def provider_tool_format(inference_provider_type):
@pytest.fixture(scope="session")
def inference_provider_type(llama_stack_client):
providers = llama_stack_client.providers.list()
assert len(providers.inference) > 0
return providers.inference[0]["provider_type"]
inference_providers = [p for p in providers if p.api == "inference"]
assert len(inference_providers) > 0, "No inference providers found"
return inference_providers[0].provider_type
@pytest.fixture(scope="session")
def text_model_id(llama_stack_client):
available_models = [
model.identifier
for model in llama_stack_client.models.list().data
for model in llama_stack_client.models.list()
if model.identifier.startswith("meta-llama") and "405" not in model.identifier
]
assert len(available_models) > 0
@ -47,7 +46,7 @@ def text_model_id(llama_stack_client):
def vision_model_id(llama_stack_client):
available_models = [
model.identifier
for model in llama_stack_client.models.list().data
for model in llama_stack_client.models.list()
if "vision" in model.identifier.lower()
]
if len(available_models) == 0:

View file

@ -7,16 +7,15 @@
import random
import pytest
from llama_stack.apis.memory import MemoryBankDocument
from llama_stack.apis.memory import MemoryBankDocument
from llama_stack_client.types.memory_insert_params import Document
@pytest.fixture(scope="function")
def empty_memory_bank_registry(llama_stack_client):
memory_banks = [
memory_bank.identifier
for memory_bank in llama_stack_client.memory_banks.list().data
memory_bank.identifier for memory_bank in llama_stack_client.memory_banks.list()
]
for memory_bank_id in memory_banks:
llama_stack_client.memory_banks.unregister(memory_bank_id=memory_bank_id)
@ -36,8 +35,7 @@ def single_entry_memory_bank_registry(llama_stack_client, empty_memory_bank_regi
provider_id="faiss",
)
memory_banks = [
memory_bank.identifier
for memory_bank in llama_stack_client.memory_banks.list().data
memory_bank.identifier for memory_bank in llama_stack_client.memory_banks.list()
]
return memory_banks
@ -106,8 +104,7 @@ def test_memory_bank_retrieve(llama_stack_client, empty_memory_bank_registry):
def test_memory_bank_list(llama_stack_client, empty_memory_bank_registry):
memory_banks_after_register = [
memory_bank.identifier
for memory_bank in llama_stack_client.memory_banks.list().data
memory_bank.identifier for memory_bank in llama_stack_client.memory_banks.list()
]
assert len(memory_banks_after_register) == 0
@ -127,16 +124,14 @@ def test_memory_bank_register(llama_stack_client, empty_memory_bank_registry):
)
memory_banks_after_register = [
memory_bank.identifier
for memory_bank in llama_stack_client.memory_banks.list().data
memory_bank.identifier for memory_bank in llama_stack_client.memory_banks.list()
]
assert memory_banks_after_register == [memory_bank_id]
def test_memory_bank_unregister(llama_stack_client, single_entry_memory_bank_registry):
memory_banks = [
memory_bank.identifier
for memory_bank in llama_stack_client.memory_banks.list().data
memory_bank.identifier for memory_bank in llama_stack_client.memory_banks.list()
]
assert len(memory_banks) == 1
@ -144,8 +139,7 @@ def test_memory_bank_unregister(llama_stack_client, single_entry_memory_bank_reg
llama_stack_client.memory_banks.unregister(memory_bank_id=memory_bank_id)
memory_banks = [
memory_bank.identifier
for memory_bank in llama_stack_client.memory_banks.list().data
memory_bank.identifier for memory_bank in llama_stack_client.memory_banks.list()
]
assert len(memory_banks) == 0
@ -201,10 +195,10 @@ def test_memory_bank_insert_inline_and_query(
def test_memory_bank_insert_from_url_and_query(
llama_stack_client, empty_memory_bank_registry
):
providers = llama_stack_client.providers.list().memory
providers = [p for p in llama_stack_client.providers.list() if p.api == "memory"]
assert len(providers) > 0
memory_provider_id = providers[0]["provider_id"]
memory_provider_id = providers[0].provider_id
memory_bank_id = "test_bank"
llama_stack_client.memory_banks.register(
@ -220,8 +214,7 @@ def test_memory_bank_insert_from_url_and_query(
# list to check memory bank is successfully registered
available_memory_banks = [
memory_bank.identifier
for memory_bank in llama_stack_client.memory_banks.list().data
memory_bank.identifier for memory_bank in llama_stack_client.memory_banks.list()
]
assert memory_bank_id in available_memory_banks

View file

@ -11,7 +11,6 @@ import pytest
from llama_stack.apis.safety import ViolationLevel
VISION_SHIELD_ENABLED_PROVIDERS = {"together"}
CODE_SCANNER_ENABLED_PROVIDERS = {"ollama", "together", "fireworks"}
@ -30,7 +29,7 @@ def data_url_from_image(file_path):
@pytest.fixture(scope="session")
def available_shields(llama_stack_client):
return [shield.identifier for shield in llama_stack_client.shields.list().data]
return [shield.identifier for shield in llama_stack_client.shields.list()]
@pytest.fixture(scope="session")
@ -54,7 +53,11 @@ def code_scanner_shield_id(available_shields):
@pytest.fixture(scope="session")
def model_providers(llama_stack_client):
return set(
[x["provider_id"] for x in llama_stack_client.providers.list().inference]
[
x.provider_id
for x in llama_stack_client.providers.list()
if x.api == "inference"
]
)