mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 19:04:19 +00:00
REST API fixes (#789)
# What does this PR do?

Client SDK fixes.

## Test Plan

LLAMA_STACK_CONFIG="/Users/dineshyv/.llama/distributions/llamastack-fireworks/fireworks-run.yaml" pytest -v tests/client-sdk/safety/test_safety.py

LLAMA_STACK_CONFIG="/Users/dineshyv/.llama/distributions/llamastack-fireworks/fireworks-run.yaml" pytest -v tests/client-sdk/memory/test_memory.py
This commit is contained in:
parent cee3816609
commit 12c994b5b2

8 changed files with 51 additions and 43 deletions
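Taken together, the hunks below add an `api` field to `ProviderInfo`, wrap the inspect endpoints' return values in list-response objects, and update the client SDK tests to the new shapes. As rough orientation, here is a minimal sketch of the data shapes the updated code assumes; `ListProvidersResponse` and `ListRoutesResponse` only appear in this diff through their `data=` keyword, so everything beyond that field is an assumption:

```python
from pydantic import BaseModel


class ProviderInfo(BaseModel):
    api: str
    provider_id: str
    provider_type: str


class RouteInfo(BaseModel):
    route: str
    method: str
    provider_types: list[str]


class ListProvidersResponse(BaseModel):
    # Only `data` is visible in this diff (ListProvidersResponse(data=ret)).
    data: list[ProviderInfo]


class ListRoutesResponse(BaseModel):
    # Likewise, only `data` is visible here (ListRoutesResponse(data=ret)).
    data: list[RouteInfo]
```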
OpenAPI spec (JSON):

@@ -7196,6 +7196,9 @@
     "ProviderInfo": {
       "type": "object",
       "properties": {
+        "api": {
+          "type": "string"
+        },
         "provider_id": {
           "type": "string"
         },
@@ -7205,6 +7208,7 @@
       },
       "additionalProperties": false,
       "required": [
+        "api",
         "provider_id",
         "provider_type"
       ]
OpenAPI spec (YAML):

@@ -1678,11 +1678,14 @@ components:
     ProviderInfo:
       additionalProperties: false
       properties:
+        api:
+          type: string
         provider_id:
           type: string
         provider_type:
           type: string
       required:
+      - api
       - provider_id
       - provider_type
       type: object
Inspect API model (Python):

@@ -12,6 +12,7 @@ from pydantic import BaseModel

 @json_schema_type
 class ProviderInfo(BaseModel):
+    api: str
     provider_id: str
     provider_type: str

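A hedged sketch of what the added field means in practice: because `api` is now in the spec's `required` lists, omitting it fails validation. The model is re-declared here only to keep the snippet self-contained; the real class carries the `@json_schema_type` decorator shown above.

```python
from pydantic import BaseModel, ValidationError


class ProviderInfo(BaseModel):
    api: str
    provider_id: str
    provider_type: str


# Values are illustrative; "remote::ollama" appears elsewhere in this PR's test config.
ok = ProviderInfo(api="inference", provider_id="ollama", provider_type="remote::ollama")

try:
    ProviderInfo(provider_id="ollama", provider_type="remote::ollama")  # no `api`
except ValidationError as e:
    print("missing required field:", e.errors()[0]["loc"])  # ('api',)
```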
Distribution inspect implementation (Python):

@@ -44,15 +44,18 @@ class DistributionInspectImpl(Inspect):

         ret = []
         for api, providers in run_config.providers.items():
-            ret.append(
-                ProviderInfo(
-                    provider_id=p.provider_id,
-                    provider_type=p.provider_type,
-                )
-                for p in providers
+            ret.extend(
+                [
+                    ProviderInfo(
+                        api=api,
+                        provider_id=p.provider_id,
+                        provider_type=p.provider_type,
+                    )
+                    for p in providers
+                ]
             )

-        return ret
+        return ListProvidersResponse(data=ret)

     async def list_routes(self) -> ListRoutesResponse:
         run_config = self.config.run_config
@@ -61,16 +64,18 @@ class DistributionInspectImpl(Inspect):
         all_endpoints = get_all_api_endpoints()
         for api, endpoints in all_endpoints.items():
             providers = run_config.providers.get(api.value, [])
-            ret.append(
-                RouteInfo(
-                    route=e.route,
-                    method=e.method,
-                    provider_types=[p.provider_type for p in providers],
-                )
-                for e in endpoints
+            ret.extend(
+                [
+                    RouteInfo(
+                        route=e.route,
+                        method=e.method,
+                        provider_types=[p.provider_type for p in providers],
+                    )
+                    for e in endpoints
+                ]
             )

-        return ret
+        return ListRoutesResponse(data=ret)

     async def health(self) -> HealthInfo:
         return HealthInfo(status="OK")
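Worth noting why the old implementation was buggy rather than just stylistically different: `list.append(...)` called with a generator expression stores the generator object itself as a single element, whereas `list.extend([...])` adds the generated items. A small standalone sketch of the difference (plain strings stand in for `ProviderInfo`):

```python
providers = ["remote::ollama", "remote::together"]

# Old pattern: append a generator expression -> one generator object ends up in the list.
broken = []
broken.append(f"info({p})" for p in providers)
print(len(broken), type(broken[0]).__name__)  # 1 generator

# New pattern: extend with a list comprehension -> the items themselves end up in the list.
fixed = []
fixed.extend([f"info({p})" for p in providers])
print(len(fixed), fixed)  # 2 ['info(remote::ollama)', 'info(remote::together)']
```

On top of that, the fix threads `api` into each `ProviderInfo` and returns the wrapper objects (`ListProvidersResponse` / `ListRoutesResponse`) instead of bare lists.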
Client SDK agents test (Python):

@@ -83,13 +83,13 @@ class TestClientTool(ClientTool):
 def agent_config(llama_stack_client):
     available_models = [
         model.identifier
-        for model in llama_stack_client.models.list().data
+        for model in llama_stack_client.models.list()
         if model.identifier.startswith("meta-llama") and "405" not in model.identifier
     ]
     model_id = available_models[0]
     print(f"Using model: {model_id}")
     available_shields = [
-        shield.identifier for shield in llama_stack_client.shields.list().data
+        shield.identifier for shield in llama_stack_client.shields.list()
     ]
     available_shields = available_shields[:1]
     print(f"Using shield: {available_shields}")
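The same access-pattern change repeats in the inference, memory, and safety tests further down: the client's `list()` calls are now used as returning the items directly, so the `.data` hop is dropped. A compact before/after sketch; the `LlamaStackClient` construction and URL are assumptions and require a running distribution:

```python
from llama_stack_client import LlamaStackClient

# Assumes a stack is running locally; adjust the URL for your distribution.
client = LlamaStackClient(base_url="http://localhost:5000")

# Before this PR the tests went through a `.data` wrapper:
#   models = [m.identifier for m in client.models.list().data]
# After this PR the list() result is iterated directly:
models = [m.identifier for m in client.models.list()]
shields = [s.identifier for s in client.shields.list()]
```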
Client SDK inference test fixtures (Python):

@@ -5,10 +5,8 @@
 # the root directory of this source tree.

 import pytest
-
 from pydantic import BaseModel

-
 PROVIDER_TOOL_PROMPT_FORMAT = {
     "remote::ollama": "python_list",
     "remote::together": "json",
@@ -28,15 +26,16 @@ def provider_tool_format(inference_provider_type):
 @pytest.fixture(scope="session")
 def inference_provider_type(llama_stack_client):
     providers = llama_stack_client.providers.list()
-    assert len(providers.inference) > 0
-    return providers.inference[0]["provider_type"]
+    inference_providers = [p for p in providers if p.api == "inference"]
+    assert len(inference_providers) > 0, "No inference providers found"
+    return inference_providers[0].provider_type


 @pytest.fixture(scope="session")
 def text_model_id(llama_stack_client):
     available_models = [
         model.identifier
-        for model in llama_stack_client.models.list().data
+        for model in llama_stack_client.models.list()
         if model.identifier.startswith("meta-llama") and "405" not in model.identifier
     ]
     assert len(available_models) > 0
@@ -47,7 +46,7 @@ def text_model_id(llama_stack_client):
 def vision_model_id(llama_stack_client):
     available_models = [
         model.identifier
-        for model in llama_stack_client.models.list().data
+        for model in llama_stack_client.models.list()
         if "vision" in model.identifier.lower()
     ]
     if len(available_models) == 0:
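The fixture change above relies on the new `api` field: instead of reading a per-API attribute off the old response (`providers.inference`), the tests filter a flat provider list by `p.api`. A hedged sketch of that pattern, grouping providers by API; the `LlamaStackClient` usage and URL are assumptions, and the printed result depends on the running distribution:

```python
from collections import defaultdict

from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:5000")  # assumed local stack

by_api: dict[str, list[str]] = defaultdict(list)
for p in client.providers.list():
    by_api[p.api].append(p.provider_type)

inference_providers = by_api.get("inference", [])
assert inference_providers, "No inference providers found"
print(inference_providers[0])  # e.g. "remote::fireworks" for the fireworks distribution
```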
Client SDK memory tests (exercised by tests/client-sdk/memory/test_memory.py in the test plan):

@@ -7,16 +7,15 @@
 import random

 import pytest
-from llama_stack.apis.memory import MemoryBankDocument

+from llama_stack.apis.memory import MemoryBankDocument
 from llama_stack_client.types.memory_insert_params import Document


 @pytest.fixture(scope="function")
 def empty_memory_bank_registry(llama_stack_client):
     memory_banks = [
-        memory_bank.identifier
-        for memory_bank in llama_stack_client.memory_banks.list().data
+        memory_bank.identifier for memory_bank in llama_stack_client.memory_banks.list()
     ]
     for memory_bank_id in memory_banks:
         llama_stack_client.memory_banks.unregister(memory_bank_id=memory_bank_id)
@@ -36,8 +35,7 @@ def single_entry_memory_bank_registry(llama_stack_client, empty_memory_bank_regi
         provider_id="faiss",
     )
     memory_banks = [
-        memory_bank.identifier
-        for memory_bank in llama_stack_client.memory_banks.list().data
+        memory_bank.identifier for memory_bank in llama_stack_client.memory_banks.list()
     ]
     return memory_banks

@@ -106,8 +104,7 @@ def test_memory_bank_retrieve(llama_stack_client, empty_memory_bank_registry):

 def test_memory_bank_list(llama_stack_client, empty_memory_bank_registry):
     memory_banks_after_register = [
-        memory_bank.identifier
-        for memory_bank in llama_stack_client.memory_banks.list().data
+        memory_bank.identifier for memory_bank in llama_stack_client.memory_banks.list()
     ]
     assert len(memory_banks_after_register) == 0

@@ -127,16 +124,14 @@ def test_memory_bank_register(llama_stack_client, empty_memory_bank_registry):
     )

     memory_banks_after_register = [
-        memory_bank.identifier
-        for memory_bank in llama_stack_client.memory_banks.list().data
+        memory_bank.identifier for memory_bank in llama_stack_client.memory_banks.list()
     ]
     assert memory_banks_after_register == [memory_bank_id]


 def test_memory_bank_unregister(llama_stack_client, single_entry_memory_bank_registry):
     memory_banks = [
-        memory_bank.identifier
-        for memory_bank in llama_stack_client.memory_banks.list().data
+        memory_bank.identifier for memory_bank in llama_stack_client.memory_banks.list()
     ]
     assert len(memory_banks) == 1

@@ -144,8 +139,7 @@ def test_memory_bank_unregister(llama_stack_client, single_entry_memory_bank_reg
     llama_stack_client.memory_banks.unregister(memory_bank_id=memory_bank_id)

     memory_banks = [
-        memory_bank.identifier
-        for memory_bank in llama_stack_client.memory_banks.list().data
+        memory_bank.identifier for memory_bank in llama_stack_client.memory_banks.list()
     ]
     assert len(memory_banks) == 0

@@ -201,10 +195,10 @@ def test_memory_bank_insert_inline_and_query(
 def test_memory_bank_insert_from_url_and_query(
     llama_stack_client, empty_memory_bank_registry
 ):
-    providers = llama_stack_client.providers.list().memory
+    providers = [p for p in llama_stack_client.providers.list() if p.api == "memory"]
     assert len(providers) > 0

-    memory_provider_id = providers[0]["provider_id"]
+    memory_provider_id = providers[0].provider_id
     memory_bank_id = "test_bank"

     llama_stack_client.memory_banks.register(
@@ -220,8 +214,7 @@ def test_memory_bank_insert_from_url_and_query(

     # list to check memory bank is successfully registered
     available_memory_banks = [
-        memory_bank.identifier
-        for memory_bank in llama_stack_client.memory_banks.list().data
+        memory_bank.identifier for memory_bank in llama_stack_client.memory_banks.list()
     ]
     assert memory_bank_id in available_memory_banks

Client SDK safety tests (exercised by tests/client-sdk/safety/test_safety.py in the test plan):

@@ -11,7 +11,6 @@ import pytest

 from llama_stack.apis.safety import ViolationLevel

-
 VISION_SHIELD_ENABLED_PROVIDERS = {"together"}
 CODE_SCANNER_ENABLED_PROVIDERS = {"ollama", "together", "fireworks"}

@@ -30,7 +29,7 @@ def data_url_from_image(file_path):

 @pytest.fixture(scope="session")
 def available_shields(llama_stack_client):
-    return [shield.identifier for shield in llama_stack_client.shields.list().data]
+    return [shield.identifier for shield in llama_stack_client.shields.list()]


 @pytest.fixture(scope="session")
@@ -54,7 +53,11 @@ def code_scanner_shield_id(available_shields):
 @pytest.fixture(scope="session")
 def model_providers(llama_stack_client):
     return set(
-        [x["provider_id"] for x in llama_stack_client.providers.list().inference]
+        [
+            x.provider_id
+            for x in llama_stack_client.providers.list()
+            if x.api == "inference"
+        ]
     )
