diff --git a/llama_stack/distribution/client.py b/llama_stack/distribution/client.py index acc871f01..613c90bd6 100644 --- a/llama_stack/distribution/client.py +++ b/llama_stack/distribution/client.py @@ -83,6 +83,7 @@ def create_api_client_class(protocol, additional_protocol) -> Type: j = response.json() if j is None: return None + # print(f"({protocol.__name__}) Returning {j}, type {return_type}") return parse_obj_as(return_type, j) async def _call_streaming(self, method_name: str, *args, **kwargs) -> Any: @@ -102,14 +103,15 @@ def create_api_client_class(protocol, additional_protocol) -> Type: if line.startswith("data:"): data = line[len("data: ") :] try: + data = json.loads(data) if "error" in data: cprint(data, "red") continue - yield parse_obj_as(return_type, json.loads(data)) + yield parse_obj_as(return_type, data) except Exception as e: - print(data) print(f"Error with parsing or validation: {e}") + print(data) def httpx_request_params(self, method_name: str, *args, **kwargs) -> dict: webmethod, sig = self.routes[method_name] diff --git a/llama_stack/providers/tests/agents/conftest.py b/llama_stack/providers/tests/agents/conftest.py index 332efeed8..7b16242cf 100644 --- a/llama_stack/providers/tests/agents/conftest.py +++ b/llama_stack/providers/tests/agents/conftest.py @@ -46,11 +46,21 @@ DEFAULT_PROVIDER_COMBINATIONS = [ id="together", marks=pytest.mark.together, ), + pytest.param( + { + "inference": "remote", + "safety": "remote", + "memory": "remote", + "agents": "remote", + }, + id="remote", + marks=pytest.mark.remote, + ), ] def pytest_configure(config): - for mark in ["meta_reference", "ollama", "together"]: + for mark in ["meta_reference", "ollama", "together", "remote"]: config.addinivalue_line( "markers", f"{mark}: marks tests as {mark} specific", diff --git a/llama_stack/providers/tests/agents/fixtures.py b/llama_stack/providers/tests/agents/fixtures.py index c667712a7..153ade0da 100644 --- a/llama_stack/providers/tests/agents/fixtures.py +++ b/llama_stack/providers/tests/agents/fixtures.py @@ -18,7 +18,12 @@ from llama_stack.providers.impls.meta_reference.agents import ( from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2 from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig -from ..conftest import ProviderFixture +from ..conftest import ProviderFixture, remote_stack_fixture + + +@pytest.fixture(scope="session") +def agents_remote() -> ProviderFixture: + return remote_stack_fixture() @pytest.fixture(scope="session") @@ -40,7 +45,7 @@ def agents_meta_reference() -> ProviderFixture: ) -AGENTS_FIXTURES = ["meta_reference"] +AGENTS_FIXTURES = ["meta_reference", "remote"] @pytest_asyncio.fixture(scope="session") diff --git a/llama_stack/providers/tests/agents/test_agents.py b/llama_stack/providers/tests/agents/test_agents.py index 54c10a42d..5b1fe202a 100644 --- a/llama_stack/providers/tests/agents/test_agents.py +++ b/llama_stack/providers/tests/agents/test_agents.py @@ -109,7 +109,6 @@ class TestAgents: turn_response = [ chunk async for chunk in await agents_impl.create_agent_turn(**turn_request) ] - assert len(turn_response) > 0 check_event_types(turn_response) diff --git a/llama_stack/providers/tests/conftest.py b/llama_stack/providers/tests/conftest.py index 9fdf94582..11b0dcb45 100644 --- a/llama_stack/providers/tests/conftest.py +++ b/llama_stack/providers/tests/conftest.py @@ -14,6 +14,9 @@ from pydantic import BaseModel from termcolor import colored from llama_stack.distribution.datatypes import Provider +from llama_stack.providers.datatypes import RemoteProviderConfig + +from .env import get_env_or_fail class ProviderFixture(BaseModel): @@ -21,6 +24,21 @@ class ProviderFixture(BaseModel): provider_data: Optional[Dict[str, Any]] = None +def remote_stack_fixture() -> ProviderFixture: + return ProviderFixture( + providers=[ + Provider( + provider_id="remote", + provider_type="remote", + config=RemoteProviderConfig( + host=get_env_or_fail("REMOTE_STACK_HOST"), + port=int(get_env_or_fail("REMOTE_STACK_PORT")), + ).model_dump(), + ) + ], + ) + + def pytest_configure(config): config.option.tbstyle = "short" config.option.disable_warnings = True diff --git a/llama_stack/providers/tests/inference/fixtures.py b/llama_stack/providers/tests/inference/fixtures.py index 860eea4b2..896acbad8 100644 --- a/llama_stack/providers/tests/inference/fixtures.py +++ b/llama_stack/providers/tests/inference/fixtures.py @@ -18,7 +18,7 @@ from llama_stack.providers.impls.meta_reference.inference import ( MetaReferenceInferenceConfig, ) from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2 -from ..conftest import ProviderFixture +from ..conftest import ProviderFixture, remote_stack_fixture from ..env import get_env_or_fail @@ -29,6 +29,11 @@ def inference_model(request): return request.config.getoption("--inference-model", None) +@pytest.fixture(scope="session") +def inference_remote() -> ProviderFixture: + return remote_stack_fixture() + + @pytest.fixture(scope="session") def inference_meta_reference(inference_model) -> ProviderFixture: inference_model = ( @@ -104,7 +109,7 @@ def inference_together() -> ProviderFixture: ) -INFERENCE_FIXTURES = ["meta_reference", "ollama", "fireworks", "together"] +INFERENCE_FIXTURES = ["meta_reference", "ollama", "fireworks", "together", "remote"] @pytest_asyncio.fixture(scope="session") diff --git a/llama_stack/providers/tests/memory/fixtures.py b/llama_stack/providers/tests/memory/fixtures.py index 4a6642e85..adeab8476 100644 --- a/llama_stack/providers/tests/memory/fixtures.py +++ b/llama_stack/providers/tests/memory/fixtures.py @@ -15,10 +15,15 @@ from llama_stack.providers.adapters.memory.weaviate import WeaviateConfig from llama_stack.providers.impls.meta_reference.memory import FaissImplConfig from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2 -from ..conftest import ProviderFixture +from ..conftest import ProviderFixture, remote_stack_fixture from ..env import get_env_or_fail +@pytest.fixture(scope="session") +def memory_remote() -> ProviderFixture: + return remote_stack_fixture() + + @pytest.fixture(scope="session") def memory_meta_reference() -> ProviderFixture: return ProviderFixture( @@ -68,7 +73,7 @@ def memory_weaviate() -> ProviderFixture: ) -MEMORY_FIXTURES = ["meta_reference", "pgvector", "weaviate"] +MEMORY_FIXTURES = ["meta_reference", "pgvector", "weaviate", "remote"] @pytest_asyncio.fixture(scope="session") diff --git a/llama_stack/providers/tests/resolver.py b/llama_stack/providers/tests/resolver.py index 2d6805b35..16c2a32af 100644 --- a/llama_stack/providers/tests/resolver.py +++ b/llama_stack/providers/tests/resolver.py @@ -6,6 +6,7 @@ import json import os +import tempfile from datetime import datetime from typing import Any, Dict, List, Optional @@ -16,6 +17,8 @@ from llama_stack.distribution.configure import parse_and_maybe_upgrade_config from llama_stack.distribution.distribution import get_provider_registry from llama_stack.distribution.request_headers import set_request_provider_data from llama_stack.distribution.resolver import resolve_impls +from llama_stack.distribution.store import CachedDiskDistributionRegistry +from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig async def resolve_impls_for_test_v2( @@ -30,7 +33,11 @@ async def resolve_impls_for_test_v2( providers=providers, ) run_config = parse_and_maybe_upgrade_config(run_config) - impls = await resolve_impls(run_config, get_provider_registry()) + + sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db") + dist_kvstore = await kvstore_impl(SqliteKVStoreConfig(db_path=sqlite_file.name)) + dist_registry = CachedDiskDistributionRegistry(dist_kvstore) + impls = await resolve_impls(run_config, get_provider_registry(), dist_registry) if provider_data: set_request_provider_data( diff --git a/llama_stack/providers/tests/safety/conftest.py b/llama_stack/providers/tests/safety/conftest.py index c5424f8db..fb47b290d 100644 --- a/llama_stack/providers/tests/safety/conftest.py +++ b/llama_stack/providers/tests/safety/conftest.py @@ -37,11 +37,19 @@ DEFAULT_PROVIDER_COMBINATIONS = [ id="together", marks=pytest.mark.together, ), + pytest.param( + { + "inference": "remote", + "safety": "remote", + }, + id="remote", + marks=pytest.mark.remote, + ), ] def pytest_configure(config): - for mark in ["meta_reference", "ollama", "together"]: + for mark in ["meta_reference", "ollama", "together", "remote"]: config.addinivalue_line( "markers", f"{mark}: marks tests as {mark} specific", diff --git a/llama_stack/providers/tests/safety/fixtures.py b/llama_stack/providers/tests/safety/fixtures.py index 463c53d2c..74f8ef503 100644 --- a/llama_stack/providers/tests/safety/fixtures.py +++ b/llama_stack/providers/tests/safety/fixtures.py @@ -16,10 +16,15 @@ from llama_stack.providers.impls.meta_reference.safety import ( from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2 -from ..conftest import ProviderFixture +from ..conftest import ProviderFixture, remote_stack_fixture from ..env import get_env_or_fail +@pytest.fixture(scope="session") +def safety_remote() -> ProviderFixture: + return remote_stack_fixture() + + @pytest.fixture(scope="session") def safety_model(request): if hasattr(request, "param"): @@ -60,7 +65,7 @@ def safety_together() -> ProviderFixture: ) -SAFETY_FIXTURES = ["meta_reference", "together"] +SAFETY_FIXTURES = ["meta_reference", "together", "remote"] @pytest_asyncio.fixture(scope="session") diff --git a/llama_stack/providers/tests/safety/test_safety.py b/llama_stack/providers/tests/safety/test_safety.py index ddf472737..9a629e85c 100644 --- a/llama_stack/providers/tests/safety/test_safety.py +++ b/llama_stack/providers/tests/safety/test_safety.py @@ -27,7 +27,7 @@ class TestSafety: for shield in response: assert isinstance(shield, ShieldDefWithProvider) - assert shield.type in [v.value for v in ShieldType] + assert shield.shield_type in [v.value for v in ShieldType] @pytest.mark.asyncio async def test_run_shield(self, safety_stack):