mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 10:54:19 +00:00
add support for remote providers in tests
This commit is contained in:
parent
0763a0b85f
commit
7cf4c905f3
11 changed files with 79 additions and 15 deletions
|
@ -83,6 +83,7 @@ def create_api_client_class(protocol, additional_protocol) -> Type:
|
||||||
j = response.json()
|
j = response.json()
|
||||||
if j is None:
|
if j is None:
|
||||||
return None
|
return None
|
||||||
|
# print(f"({protocol.__name__}) Returning {j}, type {return_type}")
|
||||||
return parse_obj_as(return_type, j)
|
return parse_obj_as(return_type, j)
|
||||||
|
|
||||||
async def _call_streaming(self, method_name: str, *args, **kwargs) -> Any:
|
async def _call_streaming(self, method_name: str, *args, **kwargs) -> Any:
|
||||||
|
@ -102,14 +103,15 @@ def create_api_client_class(protocol, additional_protocol) -> Type:
|
||||||
if line.startswith("data:"):
|
if line.startswith("data:"):
|
||||||
data = line[len("data: ") :]
|
data = line[len("data: ") :]
|
||||||
try:
|
try:
|
||||||
|
data = json.loads(data)
|
||||||
if "error" in data:
|
if "error" in data:
|
||||||
cprint(data, "red")
|
cprint(data, "red")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
yield parse_obj_as(return_type, json.loads(data))
|
yield parse_obj_as(return_type, data)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(data)
|
|
||||||
print(f"Error with parsing or validation: {e}")
|
print(f"Error with parsing or validation: {e}")
|
||||||
|
print(data)
|
||||||
|
|
||||||
def httpx_request_params(self, method_name: str, *args, **kwargs) -> dict:
|
def httpx_request_params(self, method_name: str, *args, **kwargs) -> dict:
|
||||||
webmethod, sig = self.routes[method_name]
|
webmethod, sig = self.routes[method_name]
|
||||||
|
|
|
@ -46,11 +46,21 @@ DEFAULT_PROVIDER_COMBINATIONS = [
|
||||||
id="together",
|
id="together",
|
||||||
marks=pytest.mark.together,
|
marks=pytest.mark.together,
|
||||||
),
|
),
|
||||||
|
pytest.param(
|
||||||
|
{
|
||||||
|
"inference": "remote",
|
||||||
|
"safety": "remote",
|
||||||
|
"memory": "remote",
|
||||||
|
"agents": "remote",
|
||||||
|
},
|
||||||
|
id="remote",
|
||||||
|
marks=pytest.mark.remote,
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def pytest_configure(config):
|
def pytest_configure(config):
|
||||||
for mark in ["meta_reference", "ollama", "together"]:
|
for mark in ["meta_reference", "ollama", "together", "remote"]:
|
||||||
config.addinivalue_line(
|
config.addinivalue_line(
|
||||||
"markers",
|
"markers",
|
||||||
f"{mark}: marks tests as {mark} specific",
|
f"{mark}: marks tests as {mark} specific",
|
||||||
|
|
|
@ -18,7 +18,12 @@ from llama_stack.providers.impls.meta_reference.agents import (
|
||||||
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
|
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
|
||||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||||
|
|
||||||
from ..conftest import ProviderFixture
|
from ..conftest import ProviderFixture, remote_stack_fixture
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def agents_remote() -> ProviderFixture:
|
||||||
|
return remote_stack_fixture()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
|
@ -40,7 +45,7 @@ def agents_meta_reference() -> ProviderFixture:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
AGENTS_FIXTURES = ["meta_reference"]
|
AGENTS_FIXTURES = ["meta_reference", "remote"]
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture(scope="session")
|
@pytest_asyncio.fixture(scope="session")
|
||||||
|
|
|
@ -109,7 +109,6 @@ class TestAgents:
|
||||||
turn_response = [
|
turn_response = [
|
||||||
chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
|
chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
|
||||||
]
|
]
|
||||||
|
|
||||||
assert len(turn_response) > 0
|
assert len(turn_response) > 0
|
||||||
check_event_types(turn_response)
|
check_event_types(turn_response)
|
||||||
|
|
||||||
|
|
|
@ -14,6 +14,9 @@ from pydantic import BaseModel
|
||||||
from termcolor import colored
|
from termcolor import colored
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import Provider
|
from llama_stack.distribution.datatypes import Provider
|
||||||
|
from llama_stack.providers.datatypes import RemoteProviderConfig
|
||||||
|
|
||||||
|
from .env import get_env_or_fail
|
||||||
|
|
||||||
|
|
||||||
class ProviderFixture(BaseModel):
|
class ProviderFixture(BaseModel):
|
||||||
|
@ -21,6 +24,21 @@ class ProviderFixture(BaseModel):
|
||||||
provider_data: Optional[Dict[str, Any]] = None
|
provider_data: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
|
|
||||||
|
def remote_stack_fixture() -> ProviderFixture:
|
||||||
|
return ProviderFixture(
|
||||||
|
providers=[
|
||||||
|
Provider(
|
||||||
|
provider_id="remote",
|
||||||
|
provider_type="remote",
|
||||||
|
config=RemoteProviderConfig(
|
||||||
|
host=get_env_or_fail("REMOTE_STACK_HOST"),
|
||||||
|
port=int(get_env_or_fail("REMOTE_STACK_PORT")),
|
||||||
|
).model_dump(),
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def pytest_configure(config):
|
def pytest_configure(config):
|
||||||
config.option.tbstyle = "short"
|
config.option.tbstyle = "short"
|
||||||
config.option.disable_warnings = True
|
config.option.disable_warnings = True
|
||||||
|
|
|
@ -18,7 +18,7 @@ from llama_stack.providers.impls.meta_reference.inference import (
|
||||||
MetaReferenceInferenceConfig,
|
MetaReferenceInferenceConfig,
|
||||||
)
|
)
|
||||||
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
|
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
|
||||||
from ..conftest import ProviderFixture
|
from ..conftest import ProviderFixture, remote_stack_fixture
|
||||||
from ..env import get_env_or_fail
|
from ..env import get_env_or_fail
|
||||||
|
|
||||||
|
|
||||||
|
@ -29,6 +29,11 @@ def inference_model(request):
|
||||||
return request.config.getoption("--inference-model", None)
|
return request.config.getoption("--inference-model", None)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def inference_remote() -> ProviderFixture:
|
||||||
|
return remote_stack_fixture()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def inference_meta_reference(inference_model) -> ProviderFixture:
|
def inference_meta_reference(inference_model) -> ProviderFixture:
|
||||||
inference_model = (
|
inference_model = (
|
||||||
|
@ -104,7 +109,7 @@ def inference_together() -> ProviderFixture:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
INFERENCE_FIXTURES = ["meta_reference", "ollama", "fireworks", "together"]
|
INFERENCE_FIXTURES = ["meta_reference", "ollama", "fireworks", "together", "remote"]
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture(scope="session")
|
@pytest_asyncio.fixture(scope="session")
|
||||||
|
|
|
@ -15,10 +15,15 @@ from llama_stack.providers.adapters.memory.weaviate import WeaviateConfig
|
||||||
from llama_stack.providers.impls.meta_reference.memory import FaissImplConfig
|
from llama_stack.providers.impls.meta_reference.memory import FaissImplConfig
|
||||||
|
|
||||||
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
|
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
|
||||||
from ..conftest import ProviderFixture
|
from ..conftest import ProviderFixture, remote_stack_fixture
|
||||||
from ..env import get_env_or_fail
|
from ..env import get_env_or_fail
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def memory_remote() -> ProviderFixture:
|
||||||
|
return remote_stack_fixture()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def memory_meta_reference() -> ProviderFixture:
|
def memory_meta_reference() -> ProviderFixture:
|
||||||
return ProviderFixture(
|
return ProviderFixture(
|
||||||
|
@ -68,7 +73,7 @@ def memory_weaviate() -> ProviderFixture:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
MEMORY_FIXTURES = ["meta_reference", "pgvector", "weaviate"]
|
MEMORY_FIXTURES = ["meta_reference", "pgvector", "weaviate", "remote"]
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture(scope="session")
|
@pytest_asyncio.fixture(scope="session")
|
||||||
|
|
|
@ -6,6 +6,7 @@
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
import tempfile
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
@ -16,6 +17,8 @@ from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
|
||||||
from llama_stack.distribution.distribution import get_provider_registry
|
from llama_stack.distribution.distribution import get_provider_registry
|
||||||
from llama_stack.distribution.request_headers import set_request_provider_data
|
from llama_stack.distribution.request_headers import set_request_provider_data
|
||||||
from llama_stack.distribution.resolver import resolve_impls
|
from llama_stack.distribution.resolver import resolve_impls
|
||||||
|
from llama_stack.distribution.store import CachedDiskDistributionRegistry
|
||||||
|
from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig
|
||||||
|
|
||||||
|
|
||||||
async def resolve_impls_for_test_v2(
|
async def resolve_impls_for_test_v2(
|
||||||
|
@ -30,7 +33,11 @@ async def resolve_impls_for_test_v2(
|
||||||
providers=providers,
|
providers=providers,
|
||||||
)
|
)
|
||||||
run_config = parse_and_maybe_upgrade_config(run_config)
|
run_config = parse_and_maybe_upgrade_config(run_config)
|
||||||
impls = await resolve_impls(run_config, get_provider_registry())
|
|
||||||
|
sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
|
||||||
|
dist_kvstore = await kvstore_impl(SqliteKVStoreConfig(db_path=sqlite_file.name))
|
||||||
|
dist_registry = CachedDiskDistributionRegistry(dist_kvstore)
|
||||||
|
impls = await resolve_impls(run_config, get_provider_registry(), dist_registry)
|
||||||
|
|
||||||
if provider_data:
|
if provider_data:
|
||||||
set_request_provider_data(
|
set_request_provider_data(
|
||||||
|
|
|
@ -37,11 +37,19 @@ DEFAULT_PROVIDER_COMBINATIONS = [
|
||||||
id="together",
|
id="together",
|
||||||
marks=pytest.mark.together,
|
marks=pytest.mark.together,
|
||||||
),
|
),
|
||||||
|
pytest.param(
|
||||||
|
{
|
||||||
|
"inference": "remote",
|
||||||
|
"safety": "remote",
|
||||||
|
},
|
||||||
|
id="remote",
|
||||||
|
marks=pytest.mark.remote,
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def pytest_configure(config):
|
def pytest_configure(config):
|
||||||
for mark in ["meta_reference", "ollama", "together"]:
|
for mark in ["meta_reference", "ollama", "together", "remote"]:
|
||||||
config.addinivalue_line(
|
config.addinivalue_line(
|
||||||
"markers",
|
"markers",
|
||||||
f"{mark}: marks tests as {mark} specific",
|
f"{mark}: marks tests as {mark} specific",
|
||||||
|
|
|
@ -16,10 +16,15 @@ from llama_stack.providers.impls.meta_reference.safety import (
|
||||||
|
|
||||||
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
|
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
|
||||||
|
|
||||||
from ..conftest import ProviderFixture
|
from ..conftest import ProviderFixture, remote_stack_fixture
|
||||||
from ..env import get_env_or_fail
|
from ..env import get_env_or_fail
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def safety_remote() -> ProviderFixture:
|
||||||
|
return remote_stack_fixture()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def safety_model(request):
|
def safety_model(request):
|
||||||
if hasattr(request, "param"):
|
if hasattr(request, "param"):
|
||||||
|
@ -60,7 +65,7 @@ def safety_together() -> ProviderFixture:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
SAFETY_FIXTURES = ["meta_reference", "together"]
|
SAFETY_FIXTURES = ["meta_reference", "together", "remote"]
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture(scope="session")
|
@pytest_asyncio.fixture(scope="session")
|
||||||
|
|
|
@ -27,7 +27,7 @@ class TestSafety:
|
||||||
|
|
||||||
for shield in response:
|
for shield in response:
|
||||||
assert isinstance(shield, ShieldDefWithProvider)
|
assert isinstance(shield, ShieldDefWithProvider)
|
||||||
assert shield.type in [v.value for v in ShieldType]
|
assert shield.shield_type in [v.value for v in ShieldType]
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_run_shield(self, safety_stack):
|
async def test_run_shield(self, safety_stack):
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue