add support for remote providers in tests

This commit is contained in:
Ashwin Bharambe 2024-11-04 19:57:40 -08:00
parent 0763a0b85f
commit 7cf4c905f3
11 changed files with 79 additions and 15 deletions

View file

@ -83,6 +83,7 @@ def create_api_client_class(protocol, additional_protocol) -> Type:
j = response.json() j = response.json()
if j is None: if j is None:
return None return None
# print(f"({protocol.__name__}) Returning {j}, type {return_type}")
return parse_obj_as(return_type, j) return parse_obj_as(return_type, j)
async def _call_streaming(self, method_name: str, *args, **kwargs) -> Any: async def _call_streaming(self, method_name: str, *args, **kwargs) -> Any:
@ -102,14 +103,15 @@ def create_api_client_class(protocol, additional_protocol) -> Type:
if line.startswith("data:"): if line.startswith("data:"):
data = line[len("data: ") :] data = line[len("data: ") :]
try: try:
data = json.loads(data)
if "error" in data: if "error" in data:
cprint(data, "red") cprint(data, "red")
continue continue
yield parse_obj_as(return_type, json.loads(data)) yield parse_obj_as(return_type, data)
except Exception as e: except Exception as e:
print(data)
print(f"Error with parsing or validation: {e}") print(f"Error with parsing or validation: {e}")
print(data)
def httpx_request_params(self, method_name: str, *args, **kwargs) -> dict: def httpx_request_params(self, method_name: str, *args, **kwargs) -> dict:
webmethod, sig = self.routes[method_name] webmethod, sig = self.routes[method_name]

View file

@ -46,11 +46,21 @@ DEFAULT_PROVIDER_COMBINATIONS = [
id="together", id="together",
marks=pytest.mark.together, marks=pytest.mark.together,
), ),
pytest.param(
{
"inference": "remote",
"safety": "remote",
"memory": "remote",
"agents": "remote",
},
id="remote",
marks=pytest.mark.remote,
),
] ]
def pytest_configure(config): def pytest_configure(config):
for mark in ["meta_reference", "ollama", "together"]: for mark in ["meta_reference", "ollama", "together", "remote"]:
config.addinivalue_line( config.addinivalue_line(
"markers", "markers",
f"{mark}: marks tests as {mark} specific", f"{mark}: marks tests as {mark} specific",

View file

@ -18,7 +18,12 @@ from llama_stack.providers.impls.meta_reference.agents import (
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2 from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from ..conftest import ProviderFixture from ..conftest import ProviderFixture, remote_stack_fixture
@pytest.fixture(scope="session")
def agents_remote() -> ProviderFixture:
return remote_stack_fixture()
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
@ -40,7 +45,7 @@ def agents_meta_reference() -> ProviderFixture:
) )
AGENTS_FIXTURES = ["meta_reference"] AGENTS_FIXTURES = ["meta_reference", "remote"]
@pytest_asyncio.fixture(scope="session") @pytest_asyncio.fixture(scope="session")

View file

@ -109,7 +109,6 @@ class TestAgents:
turn_response = [ turn_response = [
chunk async for chunk in await agents_impl.create_agent_turn(**turn_request) chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
] ]
assert len(turn_response) > 0 assert len(turn_response) > 0
check_event_types(turn_response) check_event_types(turn_response)

View file

@ -14,6 +14,9 @@ from pydantic import BaseModel
from termcolor import colored from termcolor import colored
from llama_stack.distribution.datatypes import Provider from llama_stack.distribution.datatypes import Provider
from llama_stack.providers.datatypes import RemoteProviderConfig
from .env import get_env_or_fail
class ProviderFixture(BaseModel): class ProviderFixture(BaseModel):
@ -21,6 +24,21 @@ class ProviderFixture(BaseModel):
provider_data: Optional[Dict[str, Any]] = None provider_data: Optional[Dict[str, Any]] = None
def remote_stack_fixture() -> ProviderFixture:
return ProviderFixture(
providers=[
Provider(
provider_id="remote",
provider_type="remote",
config=RemoteProviderConfig(
host=get_env_or_fail("REMOTE_STACK_HOST"),
port=int(get_env_or_fail("REMOTE_STACK_PORT")),
).model_dump(),
)
],
)
def pytest_configure(config): def pytest_configure(config):
config.option.tbstyle = "short" config.option.tbstyle = "short"
config.option.disable_warnings = True config.option.disable_warnings = True

View file

@ -18,7 +18,7 @@ from llama_stack.providers.impls.meta_reference.inference import (
MetaReferenceInferenceConfig, MetaReferenceInferenceConfig,
) )
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2 from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
from ..conftest import ProviderFixture from ..conftest import ProviderFixture, remote_stack_fixture
from ..env import get_env_or_fail from ..env import get_env_or_fail
@ -29,6 +29,11 @@ def inference_model(request):
return request.config.getoption("--inference-model", None) return request.config.getoption("--inference-model", None)
@pytest.fixture(scope="session")
def inference_remote() -> ProviderFixture:
return remote_stack_fixture()
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def inference_meta_reference(inference_model) -> ProviderFixture: def inference_meta_reference(inference_model) -> ProviderFixture:
inference_model = ( inference_model = (
@ -104,7 +109,7 @@ def inference_together() -> ProviderFixture:
) )
INFERENCE_FIXTURES = ["meta_reference", "ollama", "fireworks", "together"] INFERENCE_FIXTURES = ["meta_reference", "ollama", "fireworks", "together", "remote"]
@pytest_asyncio.fixture(scope="session") @pytest_asyncio.fixture(scope="session")

View file

@ -15,10 +15,15 @@ from llama_stack.providers.adapters.memory.weaviate import WeaviateConfig
from llama_stack.providers.impls.meta_reference.memory import FaissImplConfig from llama_stack.providers.impls.meta_reference.memory import FaissImplConfig
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2 from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
from ..conftest import ProviderFixture from ..conftest import ProviderFixture, remote_stack_fixture
from ..env import get_env_or_fail from ..env import get_env_or_fail
@pytest.fixture(scope="session")
def memory_remote() -> ProviderFixture:
return remote_stack_fixture()
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def memory_meta_reference() -> ProviderFixture: def memory_meta_reference() -> ProviderFixture:
return ProviderFixture( return ProviderFixture(
@ -68,7 +73,7 @@ def memory_weaviate() -> ProviderFixture:
) )
MEMORY_FIXTURES = ["meta_reference", "pgvector", "weaviate"] MEMORY_FIXTURES = ["meta_reference", "pgvector", "weaviate", "remote"]
@pytest_asyncio.fixture(scope="session") @pytest_asyncio.fixture(scope="session")

View file

@ -6,6 +6,7 @@
import json import json
import os import os
import tempfile
from datetime import datetime from datetime import datetime
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
@ -16,6 +17,8 @@ from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.distribution import get_provider_registry from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.request_headers import set_request_provider_data from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.resolver import resolve_impls from llama_stack.distribution.resolver import resolve_impls
from llama_stack.distribution.store import CachedDiskDistributionRegistry
from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig
async def resolve_impls_for_test_v2( async def resolve_impls_for_test_v2(
@ -30,7 +33,11 @@ async def resolve_impls_for_test_v2(
providers=providers, providers=providers,
) )
run_config = parse_and_maybe_upgrade_config(run_config) run_config = parse_and_maybe_upgrade_config(run_config)
impls = await resolve_impls(run_config, get_provider_registry())
sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
dist_kvstore = await kvstore_impl(SqliteKVStoreConfig(db_path=sqlite_file.name))
dist_registry = CachedDiskDistributionRegistry(dist_kvstore)
impls = await resolve_impls(run_config, get_provider_registry(), dist_registry)
if provider_data: if provider_data:
set_request_provider_data( set_request_provider_data(

View file

@ -37,11 +37,19 @@ DEFAULT_PROVIDER_COMBINATIONS = [
id="together", id="together",
marks=pytest.mark.together, marks=pytest.mark.together,
), ),
pytest.param(
{
"inference": "remote",
"safety": "remote",
},
id="remote",
marks=pytest.mark.remote,
),
] ]
def pytest_configure(config): def pytest_configure(config):
for mark in ["meta_reference", "ollama", "together"]: for mark in ["meta_reference", "ollama", "together", "remote"]:
config.addinivalue_line( config.addinivalue_line(
"markers", "markers",
f"{mark}: marks tests as {mark} specific", f"{mark}: marks tests as {mark} specific",

View file

@ -16,10 +16,15 @@ from llama_stack.providers.impls.meta_reference.safety import (
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2 from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
from ..conftest import ProviderFixture from ..conftest import ProviderFixture, remote_stack_fixture
from ..env import get_env_or_fail from ..env import get_env_or_fail
@pytest.fixture(scope="session")
def safety_remote() -> ProviderFixture:
return remote_stack_fixture()
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def safety_model(request): def safety_model(request):
if hasattr(request, "param"): if hasattr(request, "param"):
@ -60,7 +65,7 @@ def safety_together() -> ProviderFixture:
) )
SAFETY_FIXTURES = ["meta_reference", "together"] SAFETY_FIXTURES = ["meta_reference", "together", "remote"]
@pytest_asyncio.fixture(scope="session") @pytest_asyncio.fixture(scope="session")

View file

@ -27,7 +27,7 @@ class TestSafety:
for shield in response: for shield in response:
assert isinstance(shield, ShieldDefWithProvider) assert isinstance(shield, ShieldDefWithProvider)
assert shield.type in [v.value for v in ShieldType] assert shield.shield_type in [v.value for v in ShieldType]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_run_shield(self, safety_stack): async def test_run_shield(self, safety_stack):