mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
REST API fixes (#789)
# What does this PR do? Client SDK fixes ## Test Plan LLAMA_STACK_CONFIG="/Users/dineshyv/.llama/distributions/llamastack-fireworks/fireworks-run.yaml" pytest -v tests/client-sdk/safety/test_safety.py LLAMA_STACK_CONFIG="/Users/dineshyv/.llama/distributions/llamastack-fireworks/fireworks-run.yaml" pytest -v tests/client-sdk/memory/test_memory.py
This commit is contained in:
parent
cee3816609
commit
12c994b5b2
8 changed files with 51 additions and 43 deletions
|
@ -7196,6 +7196,9 @@
|
|||
"ProviderInfo": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"api": {
|
||||
"type": "string"
|
||||
},
|
||||
"provider_id": {
|
||||
"type": "string"
|
||||
},
|
||||
|
@ -7205,6 +7208,7 @@
|
|||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"api",
|
||||
"provider_id",
|
||||
"provider_type"
|
||||
]
|
||||
|
|
|
@ -1678,11 +1678,14 @@ components:
|
|||
ProviderInfo:
|
||||
additionalProperties: false
|
||||
properties:
|
||||
api:
|
||||
type: string
|
||||
provider_id:
|
||||
type: string
|
||||
provider_type:
|
||||
type: string
|
||||
required:
|
||||
- api
|
||||
- provider_id
|
||||
- provider_type
|
||||
type: object
|
||||
|
|
|
@ -12,6 +12,7 @@ from pydantic import BaseModel
|
|||
|
||||
@json_schema_type
|
||||
class ProviderInfo(BaseModel):
|
||||
api: str
|
||||
provider_id: str
|
||||
provider_type: str
|
||||
|
||||
|
|
|
@ -44,15 +44,18 @@ class DistributionInspectImpl(Inspect):
|
|||
|
||||
ret = []
|
||||
for api, providers in run_config.providers.items():
|
||||
ret.append(
|
||||
ProviderInfo(
|
||||
provider_id=p.provider_id,
|
||||
provider_type=p.provider_type,
|
||||
)
|
||||
for p in providers
|
||||
ret.extend(
|
||||
[
|
||||
ProviderInfo(
|
||||
api=api,
|
||||
provider_id=p.provider_id,
|
||||
provider_type=p.provider_type,
|
||||
)
|
||||
for p in providers
|
||||
]
|
||||
)
|
||||
|
||||
return ret
|
||||
return ListProvidersResponse(data=ret)
|
||||
|
||||
async def list_routes(self) -> ListRoutesResponse:
|
||||
run_config = self.config.run_config
|
||||
|
@ -61,16 +64,18 @@ class DistributionInspectImpl(Inspect):
|
|||
all_endpoints = get_all_api_endpoints()
|
||||
for api, endpoints in all_endpoints.items():
|
||||
providers = run_config.providers.get(api.value, [])
|
||||
ret.append(
|
||||
RouteInfo(
|
||||
route=e.route,
|
||||
method=e.method,
|
||||
provider_types=[p.provider_type for p in providers],
|
||||
)
|
||||
for e in endpoints
|
||||
ret.extend(
|
||||
[
|
||||
RouteInfo(
|
||||
route=e.route,
|
||||
method=e.method,
|
||||
provider_types=[p.provider_type for p in providers],
|
||||
)
|
||||
for e in endpoints
|
||||
]
|
||||
)
|
||||
|
||||
return ret
|
||||
return ListRoutesResponse(data=ret)
|
||||
|
||||
async def health(self) -> HealthInfo:
|
||||
return HealthInfo(status="OK")
|
||||
|
|
|
@ -83,13 +83,13 @@ class TestClientTool(ClientTool):
|
|||
def agent_config(llama_stack_client):
|
||||
available_models = [
|
||||
model.identifier
|
||||
for model in llama_stack_client.models.list().data
|
||||
for model in llama_stack_client.models.list()
|
||||
if model.identifier.startswith("meta-llama") and "405" not in model.identifier
|
||||
]
|
||||
model_id = available_models[0]
|
||||
print(f"Using model: {model_id}")
|
||||
available_shields = [
|
||||
shield.identifier for shield in llama_stack_client.shields.list().data
|
||||
shield.identifier for shield in llama_stack_client.shields.list()
|
||||
]
|
||||
available_shields = available_shields[:1]
|
||||
print(f"Using shield: {available_shields}")
|
||||
|
|
|
@ -5,10 +5,8 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import pytest
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
PROVIDER_TOOL_PROMPT_FORMAT = {
|
||||
"remote::ollama": "python_list",
|
||||
"remote::together": "json",
|
||||
|
@ -28,15 +26,16 @@ def provider_tool_format(inference_provider_type):
|
|||
@pytest.fixture(scope="session")
|
||||
def inference_provider_type(llama_stack_client):
|
||||
providers = llama_stack_client.providers.list()
|
||||
assert len(providers.inference) > 0
|
||||
return providers.inference[0]["provider_type"]
|
||||
inference_providers = [p for p in providers if p.api == "inference"]
|
||||
assert len(inference_providers) > 0, "No inference providers found"
|
||||
return inference_providers[0].provider_type
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def text_model_id(llama_stack_client):
|
||||
available_models = [
|
||||
model.identifier
|
||||
for model in llama_stack_client.models.list().data
|
||||
for model in llama_stack_client.models.list()
|
||||
if model.identifier.startswith("meta-llama") and "405" not in model.identifier
|
||||
]
|
||||
assert len(available_models) > 0
|
||||
|
@ -47,7 +46,7 @@ def text_model_id(llama_stack_client):
|
|||
def vision_model_id(llama_stack_client):
|
||||
available_models = [
|
||||
model.identifier
|
||||
for model in llama_stack_client.models.list().data
|
||||
for model in llama_stack_client.models.list()
|
||||
if "vision" in model.identifier.lower()
|
||||
]
|
||||
if len(available_models) == 0:
|
||||
|
|
|
@ -7,16 +7,15 @@
|
|||
import random
|
||||
|
||||
import pytest
|
||||
from llama_stack.apis.memory import MemoryBankDocument
|
||||
|
||||
from llama_stack.apis.memory import MemoryBankDocument
|
||||
from llama_stack_client.types.memory_insert_params import Document
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def empty_memory_bank_registry(llama_stack_client):
|
||||
memory_banks = [
|
||||
memory_bank.identifier
|
||||
for memory_bank in llama_stack_client.memory_banks.list().data
|
||||
memory_bank.identifier for memory_bank in llama_stack_client.memory_banks.list()
|
||||
]
|
||||
for memory_bank_id in memory_banks:
|
||||
llama_stack_client.memory_banks.unregister(memory_bank_id=memory_bank_id)
|
||||
|
@ -36,8 +35,7 @@ def single_entry_memory_bank_registry(llama_stack_client, empty_memory_bank_regi
|
|||
provider_id="faiss",
|
||||
)
|
||||
memory_banks = [
|
||||
memory_bank.identifier
|
||||
for memory_bank in llama_stack_client.memory_banks.list().data
|
||||
memory_bank.identifier for memory_bank in llama_stack_client.memory_banks.list()
|
||||
]
|
||||
return memory_banks
|
||||
|
||||
|
@ -106,8 +104,7 @@ def test_memory_bank_retrieve(llama_stack_client, empty_memory_bank_registry):
|
|||
|
||||
def test_memory_bank_list(llama_stack_client, empty_memory_bank_registry):
|
||||
memory_banks_after_register = [
|
||||
memory_bank.identifier
|
||||
for memory_bank in llama_stack_client.memory_banks.list().data
|
||||
memory_bank.identifier for memory_bank in llama_stack_client.memory_banks.list()
|
||||
]
|
||||
assert len(memory_banks_after_register) == 0
|
||||
|
||||
|
@ -127,16 +124,14 @@ def test_memory_bank_register(llama_stack_client, empty_memory_bank_registry):
|
|||
)
|
||||
|
||||
memory_banks_after_register = [
|
||||
memory_bank.identifier
|
||||
for memory_bank in llama_stack_client.memory_banks.list().data
|
||||
memory_bank.identifier for memory_bank in llama_stack_client.memory_banks.list()
|
||||
]
|
||||
assert memory_banks_after_register == [memory_bank_id]
|
||||
|
||||
|
||||
def test_memory_bank_unregister(llama_stack_client, single_entry_memory_bank_registry):
|
||||
memory_banks = [
|
||||
memory_bank.identifier
|
||||
for memory_bank in llama_stack_client.memory_banks.list().data
|
||||
memory_bank.identifier for memory_bank in llama_stack_client.memory_banks.list()
|
||||
]
|
||||
assert len(memory_banks) == 1
|
||||
|
||||
|
@ -144,8 +139,7 @@ def test_memory_bank_unregister(llama_stack_client, single_entry_memory_bank_reg
|
|||
llama_stack_client.memory_banks.unregister(memory_bank_id=memory_bank_id)
|
||||
|
||||
memory_banks = [
|
||||
memory_bank.identifier
|
||||
for memory_bank in llama_stack_client.memory_banks.list().data
|
||||
memory_bank.identifier for memory_bank in llama_stack_client.memory_banks.list()
|
||||
]
|
||||
assert len(memory_banks) == 0
|
||||
|
||||
|
@ -201,10 +195,10 @@ def test_memory_bank_insert_inline_and_query(
|
|||
def test_memory_bank_insert_from_url_and_query(
|
||||
llama_stack_client, empty_memory_bank_registry
|
||||
):
|
||||
providers = llama_stack_client.providers.list().memory
|
||||
providers = [p for p in llama_stack_client.providers.list() if p.api == "memory"]
|
||||
assert len(providers) > 0
|
||||
|
||||
memory_provider_id = providers[0]["provider_id"]
|
||||
memory_provider_id = providers[0].provider_id
|
||||
memory_bank_id = "test_bank"
|
||||
|
||||
llama_stack_client.memory_banks.register(
|
||||
|
@ -220,8 +214,7 @@ def test_memory_bank_insert_from_url_and_query(
|
|||
|
||||
# list to check memory bank is successfully registered
|
||||
available_memory_banks = [
|
||||
memory_bank.identifier
|
||||
for memory_bank in llama_stack_client.memory_banks.list().data
|
||||
memory_bank.identifier for memory_bank in llama_stack_client.memory_banks.list()
|
||||
]
|
||||
assert memory_bank_id in available_memory_banks
|
||||
|
||||
|
|
|
@ -11,7 +11,6 @@ import pytest
|
|||
|
||||
from llama_stack.apis.safety import ViolationLevel
|
||||
|
||||
|
||||
VISION_SHIELD_ENABLED_PROVIDERS = {"together"}
|
||||
CODE_SCANNER_ENABLED_PROVIDERS = {"ollama", "together", "fireworks"}
|
||||
|
||||
|
@ -30,7 +29,7 @@ def data_url_from_image(file_path):
|
|||
|
||||
@pytest.fixture(scope="session")
|
||||
def available_shields(llama_stack_client):
|
||||
return [shield.identifier for shield in llama_stack_client.shields.list().data]
|
||||
return [shield.identifier for shield in llama_stack_client.shields.list()]
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
|
@ -54,7 +53,11 @@ def code_scanner_shield_id(available_shields):
|
|||
@pytest.fixture(scope="session")
|
||||
def model_providers(llama_stack_client):
|
||||
return set(
|
||||
[x["provider_id"] for x in llama_stack_client.providers.list().inference]
|
||||
[
|
||||
x.provider_id
|
||||
for x in llama_stack_client.providers.list()
|
||||
if x.api == "inference"
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue