Merge branch 'main' into eval_task_register

This commit is contained in:
Xi Yan 2024-11-06 21:50:09 -08:00
commit 283b5c1def
11 changed files with 32 additions and 188 deletions

View file

@ -25,15 +25,19 @@ class ProviderFixture(BaseModel):
def remote_stack_fixture() -> ProviderFixture:
if url := os.getenv("REMOTE_STACK_URL", None):
config = RemoteProviderConfig.from_url(url)
else:
config = RemoteProviderConfig(
host=get_env_or_fail("REMOTE_STACK_HOST"),
port=int(get_env_or_fail("REMOTE_STACK_PORT")),
)
return ProviderFixture(
providers=[
Provider(
provider_id="remote",
provider_type="remote",
config=RemoteProviderConfig(
host=get_env_or_fail("REMOTE_STACK_HOST"),
port=int(get_env_or_fail("REMOTE_STACK_PORT")),
).model_dump(),
config=config.model_dump(),
)
],
)

View file

@ -32,7 +32,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
pytest.param(
{
"inference": "together",
"safety": "together",
"safety": "meta_reference",
},
id="together",
marks=pytest.mark.together,

View file

@ -12,12 +12,10 @@ from llama_stack.providers.inline.meta_reference.safety import (
LlamaGuardShieldConfig,
SafetyConfig,
)
from llama_stack.providers.remote.safety.together import TogetherSafetyConfig
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
from ..conftest import ProviderFixture, remote_stack_fixture
from ..env import get_env_or_fail
@pytest.fixture(scope="session")
@ -49,23 +47,7 @@ def safety_meta_reference(safety_model) -> ProviderFixture:
)
@pytest.fixture(scope="session")
def safety_together() -> ProviderFixture:
return ProviderFixture(
providers=[
Provider(
provider_id="together",
provider_type="remote::together",
config=TogetherSafetyConfig().model_dump(),
)
],
provider_data=dict(
together_api_key=get_env_or_fail("TOGETHER_API_KEY"),
),
)
SAFETY_FIXTURES = ["meta_reference", "together", "remote"]
SAFETY_FIXTURES = ["meta_reference", "remote"]
@pytest_asyncio.fixture(scope="session")