REST API fixes (#789)

# What does this PR do?

Client SDK fixes

## Test Plan


LLAMA_STACK_CONFIG="/Users/dineshyv/.llama/distributions/llamastack-fireworks/fireworks-run.yaml"
pytest -v tests/client-sdk/safety/test_safety.py


LLAMA_STACK_CONFIG="/Users/dineshyv/.llama/distributions/llamastack-fireworks/fireworks-run.yaml"
pytest -v tests/client-sdk/memory/test_memory.py
This commit is contained in:
Dinesh Yeduguru 2025-01-16 13:47:08 -08:00 committed by GitHub
parent cee3816609
commit 12c994b5b2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 51 additions and 43 deletions

View file

@ -5,10 +5,8 @@
# the root directory of this source tree.
import pytest
from pydantic import BaseModel
PROVIDER_TOOL_PROMPT_FORMAT = {
"remote::ollama": "python_list",
"remote::together": "json",
@ -28,15 +26,16 @@ def provider_tool_format(inference_provider_type):
@pytest.fixture(scope="session")
def inference_provider_type(llama_stack_client):
providers = llama_stack_client.providers.list()
assert len(providers.inference) > 0
return providers.inference[0]["provider_type"]
inference_providers = [p for p in providers if p.api == "inference"]
assert len(inference_providers) > 0, "No inference providers found"
return inference_providers[0].provider_type
@pytest.fixture(scope="session")
def text_model_id(llama_stack_client):
available_models = [
model.identifier
for model in llama_stack_client.models.list().data
for model in llama_stack_client.models.list()
if model.identifier.startswith("meta-llama") and "405" not in model.identifier
]
assert len(available_models) > 0
@ -47,7 +46,7 @@ def text_model_id(llama_stack_client):
def vision_model_id(llama_stack_client):
available_models = [
model.identifier
for model in llama_stack_client.models.list().data
for model in llama_stack_client.models.list()
if "vision" in model.identifier.lower()
]
if len(available_models) == 0: