Merge branch 'main' into kill_configure

Xi Yan 2024-11-05 10:15:01 -08:00
commit 47d91b10fb
29 changed files with 119 additions and 1463 deletions


@@ -134,7 +134,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         stream = _to_async_generator()
         async for chunk in process_chat_completion_stream_response(
-            request, stream, self.formatter
+            stream, self.formatter
         ):
             yield chunk

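Note: the hunk above drops the `request` argument when handing the stream to `process_chat_completion_stream_response`. The `_to_async_generator()` pattern it relies on is just a thin async wrapper over a blocking stream; a minimal, self-contained sketch of that pattern (illustrative names, not the adapter's actual code):

import asyncio


def _blocking_stream():
    # Stand-in for a synchronous SDK call that yields completion chunks.
    yield from ["Hello", ", ", "world"]


async def _to_async_generator():
    # Wrap the synchronous iterator so it can be consumed with `async for`.
    for chunk in _blocking_stream():
        yield chunk


async def main():
    stream = _to_async_generator()
    async for chunk in stream:
        print(chunk, end="")


asyncio.run(main())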

@@ -37,7 +37,7 @@ class TogetherSafetyImpl(Safety, NeedsRequestProviderData, ShieldsProtocolPrivat
         return [
             ShieldDef(
                 identifier=ShieldType.llama_guard.value,
-                type=ShieldType.llama_guard.value,
+                shield_type=ShieldType.llama_guard.value,
                 params={},
             )
         ]


@@ -25,8 +25,8 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
         pass

     async def register_shield(self, shield: ShieldDef) -> None:
-        if shield.type != ShieldType.code_scanner.value:
-            raise ValueError(f"Unsupported safety shield type: {shield.type}")
+        if shield.shield_type != ShieldType.code_scanner.value:
+            raise ValueError(f"Unsupported safety shield type: {shield.shield_type}")

     async def run_shield(
         self,


@@ -49,7 +49,7 @@ class MetaReferenceSafetyImpl(Safety, ShieldsProtocolPrivate):
         return [
             ShieldDef(
                 identifier=shield_type,
-                type=shield_type,
+                shield_type=shield_type,
                 params={},
             )
             for shield_type in self.available_shields
@@ -92,14 +92,14 @@ class MetaReferenceSafetyImpl(Safety, ShieldsProtocolPrivate):
         return RunShieldResponse(violation=violation)

     def get_shield_impl(self, shield: ShieldDef) -> ShieldBase:
-        if shield.type == ShieldType.llama_guard.value:
+        if shield.shield_type == ShieldType.llama_guard.value:
             cfg = self.config.llama_guard_shield
             return LlamaGuardShield(
                 model=cfg.model,
                 inference_api=self.inference_api,
                 excluded_categories=cfg.excluded_categories,
             )
-        elif shield.type == ShieldType.prompt_guard.value:
+        elif shield.shield_type == ShieldType.prompt_guard.value:
             model_dir = model_local_dir(PROMPT_GUARD_MODEL)
             subtype = shield.params.get("prompt_guard_type", "injection")
             if subtype == "injection":
@@ -109,4 +109,4 @@ class MetaReferenceSafetyImpl(Safety, ShieldsProtocolPrivate):
             else:
                 raise ValueError(f"Unknown prompt guard type: {subtype}")
         else:
-            raise ValueError(f"Unknown shield type: {shield.type}")
+            raise ValueError(f"Unknown shield type: {shield.shield_type}")

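Note: across the safety providers, this diff renames the ShieldDef field `type` to `shield_type`, so any code that read `shield.type` must now read `shield.shield_type`. A minimal sketch of the shape implied by the hunks (only `identifier`, `shield_type`, and `params` are visible here; the repo's actual model may carry more fields):

from typing import Any, Dict

from pydantic import BaseModel


class ShieldDef(BaseModel):
    # Hypothetical reconstruction for illustration; only the fields visible in
    # the diff are included.
    identifier: str
    shield_type: str  # previously exposed as `type`
    params: Dict[str, Any] = {}


shield = ShieldDef(identifier="llama_guard", shield_type="llama_guard")
assert shield.shield_type == "llama_guard"  # old call sites used shield.type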

@@ -46,11 +46,21 @@ DEFAULT_PROVIDER_COMBINATIONS = [
         id="together",
         marks=pytest.mark.together,
     ),
+    pytest.param(
+        {
+            "inference": "remote",
+            "safety": "remote",
+            "memory": "remote",
+            "agents": "remote",
+        },
+        id="remote",
+        marks=pytest.mark.remote,
+    ),
 ]


 def pytest_configure(config):
-    for mark in ["meta_reference", "ollama", "together"]:
+    for mark in ["meta_reference", "ollama", "together", "remote"]:
         config.addinivalue_line(
             "markers",
             f"{mark}: marks tests as {mark} specific",

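Note: the hunk above adds a "remote" entry to DEFAULT_PROVIDER_COMBINATIONS and registers a matching marker. A small, generic illustration of how `pytest.param(..., marks=...)` plus a registered marker lets a combination be selected with `pytest -m remote` (a standalone example, not the repo's conftest):

import pytest

PROVIDER_COMBINATIONS = [
    pytest.param(
        {"inference": "remote", "safety": "remote"},
        id="remote",
        marks=pytest.mark.remote,
    ),
    pytest.param(
        {"inference": "ollama", "safety": "ollama"},
        id="ollama",
        marks=pytest.mark.ollama,
    ),
]


def pytest_configure(config):
    # In a real setup this hook lives in conftest.py; it registers the markers
    # so `pytest -m remote` selects only the remote combination, warning-free.
    for mark in ["remote", "ollama"]:
        config.addinivalue_line("markers", f"{mark}: marks tests as {mark} specific")


@pytest.mark.parametrize("providers", PROVIDER_COMBINATIONS)
def test_combination_shape(providers):
    assert set(providers) == {"inference", "safety"}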

@@ -18,7 +18,12 @@ from llama_stack.providers.impls.meta_reference.agents import (
 from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
 from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig

-from ..conftest import ProviderFixture
+from ..conftest import ProviderFixture, remote_stack_fixture


+@pytest.fixture(scope="session")
+def agents_remote() -> ProviderFixture:
+    return remote_stack_fixture()
+
+
 @pytest.fixture(scope="session")
@@ -40,7 +45,7 @@ def agents_meta_reference() -> ProviderFixture:
     )


-AGENTS_FIXTURES = ["meta_reference"]
+AGENTS_FIXTURES = ["meta_reference", "remote"]


 @pytest_asyncio.fixture(scope="session")


@@ -109,7 +109,6 @@ class TestAgents:
         turn_response = [
             chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
         ]
-
         assert len(turn_response) > 0
         check_event_types(turn_response)

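Note: the agents test collects the streamed turn into a list with an async comprehension before asserting on it. A generic, self-contained illustration of that pattern (not the agents API; the event names are placeholders):

import asyncio


async def fake_turn_stream():
    # Stand-in for the chunks an agent turn would stream back.
    for event in ["turn_start", "step_progress", "turn_complete"]:
        yield event


async def main():
    turn_response = [chunk async for chunk in fake_turn_stream()]
    assert len(turn_response) > 0
    print(turn_response)


asyncio.run(main())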

@@ -14,6 +14,9 @@ from pydantic import BaseModel
 from termcolor import colored

 from llama_stack.distribution.datatypes import Provider
+from llama_stack.providers.datatypes import RemoteProviderConfig
+
+from .env import get_env_or_fail


 class ProviderFixture(BaseModel):
@@ -21,6 +24,21 @@ class ProviderFixture(BaseModel):
     provider_data: Optional[Dict[str, Any]] = None


+def remote_stack_fixture() -> ProviderFixture:
+    return ProviderFixture(
+        providers=[
+            Provider(
+                provider_id="remote",
+                provider_type="remote",
+                config=RemoteProviderConfig(
+                    host=get_env_or_fail("REMOTE_STACK_HOST"),
+                    port=int(get_env_or_fail("REMOTE_STACK_PORT")),
+                ).model_dump(),
+            )
+        ],
+    )
+
+
 def pytest_configure(config):
     config.option.tbstyle = "short"
     config.option.disable_warnings = True

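Note: the new `remote_stack_fixture()` points every selected provider at an already-running stack, located via the `REMOTE_STACK_HOST` and `REMOTE_STACK_PORT` environment variables. A hedged sketch of the lookup these fixtures rely on (`get_env_or_fail` is assumed to simply raise when the variable is unset; the host and port below are placeholders):

import os

# Placeholders for illustration only; point these at a real running stack.
os.environ.setdefault("REMOTE_STACK_HOST", "localhost")
os.environ.setdefault("REMOTE_STACK_PORT", "5000")


def get_env_or_fail(key: str) -> str:
    # Assumed behavior of the helper imported from .env: fail loudly if unset.
    value = os.environ.get(key)
    if not value:
        raise ValueError(f"{key} must be set for remote tests")
    return value


host = get_env_or_fail("REMOTE_STACK_HOST")
port = int(get_env_or_fail("REMOTE_STACK_PORT"))
print(f"remote stack at {host}:{port}")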

@@ -18,7 +18,7 @@ from llama_stack.providers.impls.meta_reference.inference import (
     MetaReferenceInferenceConfig,
 )

 from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2

-from ..conftest import ProviderFixture
+from ..conftest import ProviderFixture, remote_stack_fixture
 from ..env import get_env_or_fail
@@ -29,6 +29,11 @@ def inference_model(request):
     return request.config.getoption("--inference-model", None)


+@pytest.fixture(scope="session")
+def inference_remote() -> ProviderFixture:
+    return remote_stack_fixture()
+
+
 @pytest.fixture(scope="session")
 def inference_meta_reference(inference_model) -> ProviderFixture:
     inference_model = (
@@ -104,7 +109,7 @@ def inference_together() -> ProviderFixture:
     )


-INFERENCE_FIXTURES = ["meta_reference", "ollama", "fireworks", "together"]
+INFERENCE_FIXTURES = ["meta_reference", "ollama", "fireworks", "together", "remote"]


 @pytest_asyncio.fixture(scope="session")


@@ -15,10 +15,15 @@ from llama_stack.providers.adapters.memory.weaviate import WeaviateConfig
 from llama_stack.providers.impls.meta_reference.memory import FaissImplConfig
 from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2

-from ..conftest import ProviderFixture
+from ..conftest import ProviderFixture, remote_stack_fixture
 from ..env import get_env_or_fail


+@pytest.fixture(scope="session")
+def memory_remote() -> ProviderFixture:
+    return remote_stack_fixture()
+
+
 @pytest.fixture(scope="session")
 def memory_meta_reference() -> ProviderFixture:
     return ProviderFixture(
@@ -68,7 +73,7 @@ def memory_weaviate() -> ProviderFixture:
     )


-MEMORY_FIXTURES = ["meta_reference", "pgvector", "weaviate"]
+MEMORY_FIXTURES = ["meta_reference", "pgvector", "weaviate", "remote"]


 @pytest_asyncio.fixture(scope="session")


@@ -6,6 +6,7 @@
 import json
 import os
+import tempfile

 from datetime import datetime
 from typing import Any, Dict, List, Optional
@@ -16,6 +17,8 @@ from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
 from llama_stack.distribution.distribution import get_provider_registry
 from llama_stack.distribution.request_headers import set_request_provider_data
 from llama_stack.distribution.resolver import resolve_impls
+from llama_stack.distribution.store import CachedDiskDistributionRegistry
+from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig


 async def resolve_impls_for_test_v2(
@@ -30,7 +33,11 @@ async def resolve_impls_for_test_v2(
         providers=providers,
     )
     run_config = parse_and_maybe_upgrade_config(run_config)
-    impls = await resolve_impls(run_config, get_provider_registry())
+
+    sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
+    dist_kvstore = await kvstore_impl(SqliteKVStoreConfig(db_path=sqlite_file.name))
+    dist_registry = CachedDiskDistributionRegistry(dist_kvstore)
+    impls = await resolve_impls(run_config, get_provider_registry(), dist_registry)

     if provider_data:
         set_request_provider_data(

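Note: the resolver now builds a distribution registry on top of a throwaway SQLite file and passes it to `resolve_impls`. The repo's `kvstore_impl` and `CachedDiskDistributionRegistry` are not reproduced here; the sketch below only illustrates the design choice of backing a per-test key-value store with a `tempfile`-created SQLite database, using the standard library directly (table and key names are made up):

import sqlite3
import tempfile

# A fresh database file per test run, so registry state never leaks between runs.
sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")

conn = sqlite3.connect(sqlite_file.name)
conn.execute("CREATE TABLE IF NOT EXISTS kvstore (key TEXT PRIMARY KEY, value TEXT)")
conn.execute(
    "INSERT OR REPLACE INTO kvstore (key, value) VALUES (?, ?)",
    ("registry::example", "{}"),
)
conn.commit()
row = conn.execute(
    "SELECT value FROM kvstore WHERE key = ?", ("registry::example",)
).fetchone()
print(row)
conn.close()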

@@ -37,11 +37,19 @@ DEFAULT_PROVIDER_COMBINATIONS = [
         id="together",
         marks=pytest.mark.together,
     ),
+    pytest.param(
+        {
+            "inference": "remote",
+            "safety": "remote",
+        },
+        id="remote",
+        marks=pytest.mark.remote,
+    ),
 ]


 def pytest_configure(config):
-    for mark in ["meta_reference", "ollama", "together"]:
+    for mark in ["meta_reference", "ollama", "together", "remote"]:
         config.addinivalue_line(
             "markers",
             f"{mark}: marks tests as {mark} specific",


@@ -16,10 +16,15 @@ from llama_stack.providers.impls.meta_reference.safety import (
 from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2

-from ..conftest import ProviderFixture
+from ..conftest import ProviderFixture, remote_stack_fixture
 from ..env import get_env_or_fail


+@pytest.fixture(scope="session")
+def safety_remote() -> ProviderFixture:
+    return remote_stack_fixture()
+
+
 @pytest.fixture(scope="session")
 def safety_model(request):
     if hasattr(request, "param"):
@@ -60,7 +65,7 @@ def safety_together() -> ProviderFixture:
     )


-SAFETY_FIXTURES = ["meta_reference", "together"]
+SAFETY_FIXTURES = ["meta_reference", "together", "remote"]


 @pytest_asyncio.fixture(scope="session")


@@ -27,7 +27,7 @@ class TestSafety:
         for shield in response:
             assert isinstance(shield, ShieldDefWithProvider)
-            assert shield.type in [v.value for v in ShieldType]
+            assert shield.shield_type in [v.value for v in ShieldType]

     @pytest.mark.asyncio
     async def test_run_shield(self, safety_stack):