get multiple providers working for meta-reference (for inference + safety)

Ashwin Bharambe 2024-11-04 16:33:42 -08:00
parent 60800bc09b
commit 6c7ea6e904
10 changed files with 136 additions and 95 deletions

@@ -132,7 +132,7 @@ class CommonRoutingTableImpl(RoutingTable):
             else:
                 provider_ids_str = f"provider: `{provider_ids[0]}`"
             raise ValueError(
-                f"`{routing_key}` not served by {provider_ids_str}. Make sure there is an {apiname} provider serving this {objname}."
+                f"{objname.capitalize()} `{routing_key}` not served by {provider_ids_str}. Make sure there is an {apiname} provider serving this {objname}."
             )
         objs = self.registry[routing_key]
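
For illustration: with hypothetical values objname="model", routing_key="Llama3.2-3B-Instruct", and a single ollama provider, the reworded message would render roughly as

    ValueError: Model `Llama3.2-3B-Instruct` not served by provider: `ollama`. Make sure there is an inference provider serving this model.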

@@ -73,16 +73,20 @@ def pytest_addoption(parser):
 def pytest_generate_tests(metafunc):
-    if "inference_model" in metafunc.fixturenames:
-        metafunc.parametrize(
-            "inference_model",
-            [pytest.param(metafunc.config.getoption("--inference-model"), id="")],
-            indirect=True,
-        )
+    safety_model = metafunc.config.getoption("--safety-model")
     if "safety_model" in metafunc.fixturenames:
         metafunc.parametrize(
             "safety_model",
-            [pytest.param(metafunc.config.getoption("--safety-model"), id="")],
+            [pytest.param(safety_model, id="")],
+            indirect=True,
+        )
+    if "inference_model" in metafunc.fixturenames:
+        inference_model = metafunc.config.getoption("--inference-model")
+        models = list(set({inference_model, safety_model}))
+        metafunc.parametrize(
+            "inference_model",
+            [pytest.param(models, id="")],
             indirect=True,
         )
     if "agents_stack" in metafunc.fixturenames:

@@ -25,16 +25,18 @@ from ..conftest import ProviderFixture
 def agents_meta_reference() -> ProviderFixture:
     sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
     return ProviderFixture(
-        provider=Provider(
-            provider_id="meta-reference",
-            provider_type="meta-reference",
-            config=MetaReferenceAgentsImplConfig(
-                # TODO: make this an in-memory store
-                persistence_store=SqliteKVStoreConfig(
-                    db_path=sqlite_file.name,
-                ),
-            ).model_dump(),
-        ),
+        providers=[
+            Provider(
+                provider_id="meta-reference",
+                provider_type="meta-reference",
+                config=MetaReferenceAgentsImplConfig(
+                    # TODO: make this an in-memory store
+                    persistence_store=SqliteKVStoreConfig(
+                        db_path=sqlite_file.name,
+                    ),
+                ).model_dump(),
+            )
+        ],
     )

@@ -49,7 +51,7 @@ async def agents_stack(request):
     provider_data = {}
     for key in ["inference", "safety", "memory", "agents"]:
         fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
-        providers[key] = [fixture.provider]
+        providers[key] = fixture.providers
         if fixture.provider_data:
             provider_data.update(fixture.provider_data)
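
After this loop, providers maps each API to the full list its fixture contributed, roughly (a sketch; ids hypothetical, configs elided):

    providers = {
        "inference": [Provider(provider_id="meta-reference-0", ...), Provider(provider_id="meta-reference-1", ...)],
        "safety": [Provider(provider_id="meta-reference", ...)],
        "memory": [Provider(provider_id="meta-reference", ...)],
        "agents": [Provider(provider_id="meta-reference", ...)],
    }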

@@ -19,6 +19,13 @@ from llama_stack.providers.datatypes import *  # noqa: F403
 @pytest.fixture
 def common_params(inference_model):
+    # This is not entirely satisfactory. The fixture `inference_model` can correspond
+    # to multiple models when you need to run a safety model in addition to the normal
+    # agent inference model. We filter out the safety model by looking for "Llama-Guard".
+    if isinstance(inference_model, list):
+        inference_model = next(m for m in inference_model if "Llama-Guard" not in m)
+        assert inference_model is not None
+
     return dict(
         model=inference_model,
         instructions="You are a helpful assistant.",

@@ -17,7 +17,7 @@ from llama_stack.distribution.datatypes import Provider
 class ProviderFixture(BaseModel):
-    provider: Provider
+    providers: List[Provider]
     provider_data: Optional[Dict[str, Any]] = None
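
Every fixture now wraps its Provider in a list, even in the single-provider case; a minimal construction sketch (config elided):

    fixture = ProviderFixture(
        providers=[
            Provider(
                provider_id="meta-reference",
                provider_type="meta-reference",
                config={},  # each real fixture supplies its config via .model_dump()
            )
        ],
    )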

@@ -24,63 +24,80 @@ from ..env import get_env_or_fail
 @pytest.fixture(scope="session")
 def inference_model(request):
-    return request.config.getoption("--inference-model", None) or request.param
+    if hasattr(request, "param"):
+        return request.param
+    return request.config.getoption("--inference-model", None)


 @pytest.fixture(scope="session")
 def inference_meta_reference(inference_model) -> ProviderFixture:
-    # TODO: make this work with multiple models
+    inference_model = (
+        [inference_model] if isinstance(inference_model, str) else inference_model
+    )
+
     return ProviderFixture(
-        provider=Provider(
-            provider_id="meta-reference",
-            provider_type="meta-reference",
-            config=MetaReferenceInferenceConfig(
-                model=inference_model,
-                max_seq_len=512,
-                create_distributed_process_group=False,
-                checkpoint_dir=os.getenv("MODEL_CHECKPOINT_DIR", None),
-            ).model_dump(),
-        ),
+        providers=[
+            Provider(
+                provider_id=f"meta-reference-{i}",
+                provider_type="meta-reference",
+                config=MetaReferenceInferenceConfig(
+                    model=m,
+                    max_seq_len=4096,
+                    create_distributed_process_group=False,
+                    checkpoint_dir=os.getenv("MODEL_CHECKPOINT_DIR", None),
+                ).model_dump(),
+            )
+            for i, m in enumerate(inference_model)
+        ]
     )


 @pytest.fixture(scope="session")
 def inference_ollama(inference_model) -> ProviderFixture:
-    if inference_model == "Llama3.1-8B-Instruct":
-        pytest.skip("Ollama only support Llama3.2-3B-Instruct for testing")
+    inference_model = (
+        [inference_model] if isinstance(inference_model, str) else inference_model
+    )
+    if "Llama3.1-8B-Instruct" in inference_model:
+        pytest.skip("Ollama only supports Llama3.2-3B-Instruct for testing")
     return ProviderFixture(
-        provider=Provider(
-            provider_id="ollama",
-            provider_type="remote::ollama",
-            config=OllamaImplConfig(
-                host="localhost", port=os.getenv("OLLAMA_PORT", 11434)
-            ).model_dump(),
-        ),
+        providers=[
+            Provider(
+                provider_id="ollama",
+                provider_type="remote::ollama",
+                config=OllamaImplConfig(
+                    host="localhost", port=os.getenv("OLLAMA_PORT", 11434)
+                ).model_dump(),
+            )
+        ],
     )


 @pytest.fixture(scope="session")
 def inference_fireworks() -> ProviderFixture:
     return ProviderFixture(
-        provider=Provider(
-            provider_id="fireworks",
-            provider_type="remote::fireworks",
-            config=FireworksImplConfig(
-                api_key=get_env_or_fail("FIREWORKS_API_KEY"),
-            ).model_dump(),
-        ),
+        providers=[
+            Provider(
+                provider_id="fireworks",
+                provider_type="remote::fireworks",
+                config=FireworksImplConfig(
+                    api_key=get_env_or_fail("FIREWORKS_API_KEY"),
+                ).model_dump(),
+            )
+        ],
     )


 @pytest.fixture(scope="session")
 def inference_together() -> ProviderFixture:
     return ProviderFixture(
-        provider=Provider(
-            provider_id="together",
-            provider_type="remote::together",
-            config=TogetherImplConfig().model_dump(),
-        ),
+        providers=[
+            Provider(
+                provider_id="together",
+                provider_type="remote::together",
+                config=TogetherImplConfig().model_dump(),
+            )
+        ],
         provider_data=dict(
             together_api_key=get_env_or_fail("TOGETHER_API_KEY"),
         ),

@@ -96,7 +113,7 @@ async def inference_stack(request):
     inference_fixture = request.getfixturevalue(f"inference_{fixture_name}")
     impls = await resolve_impls_for_test_v2(
         [Api.inference],
-        {"inference": [inference_fixture.provider.model_dump()]},
+        {"inference": inference_fixture.providers},
         inference_fixture.provider_data,
     )
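
The enumeration in inference_meta_reference is what allows two models to co-reside: each gets its own provider instance with a distinct id, which the routing table from the first hunk can then resolve per model. For a hypothetical pair of models:

    # inference_model == ["Llama3.1-8B-Instruct", "Llama-Guard-3-1B"] yields
    #   Provider(provider_id="meta-reference-0", ...)  # serves Llama3.1-8B-Instruct
    #   Provider(provider_id="meta-reference-1", ...)  # serves Llama-Guard-3-1B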

@@ -22,39 +22,45 @@ from ..env import get_env_or_fail
 @pytest.fixture(scope="session")
 def memory_meta_reference() -> ProviderFixture:
     return ProviderFixture(
-        provider=Provider(
-            provider_id="meta-reference",
-            provider_type="meta-reference",
-            config=FaissImplConfig().model_dump(),
-        ),
+        providers=[
+            Provider(
+                provider_id="meta-reference",
+                provider_type="meta-reference",
+                config=FaissImplConfig().model_dump(),
+            )
+        ],
     )


 @pytest.fixture(scope="session")
 def memory_pgvector() -> ProviderFixture:
     return ProviderFixture(
-        provider=Provider(
-            provider_id="pgvector",
-            provider_type="remote::pgvector",
-            config=PGVectorConfig(
-                host=os.getenv("PGVECTOR_HOST", "localhost"),
-                port=os.getenv("PGVECTOR_PORT", 5432),
-                db=get_env_or_fail("PGVECTOR_DB"),
-                user=get_env_or_fail("PGVECTOR_USER"),
-                password=get_env_or_fail("PGVECTOR_PASSWORD"),
-            ).model_dump(),
-        ),
+        providers=[
+            Provider(
+                provider_id="pgvector",
+                provider_type="remote::pgvector",
+                config=PGVectorConfig(
+                    host=os.getenv("PGVECTOR_HOST", "localhost"),
+                    port=os.getenv("PGVECTOR_PORT", 5432),
+                    db=get_env_or_fail("PGVECTOR_DB"),
+                    user=get_env_or_fail("PGVECTOR_USER"),
+                    password=get_env_or_fail("PGVECTOR_PASSWORD"),
+                ).model_dump(),
+            )
+        ],
     )


 @pytest.fixture(scope="session")
 def memory_weaviate() -> ProviderFixture:
     return ProviderFixture(
-        provider=Provider(
-            provider_id="weaviate",
-            provider_type="remote::weaviate",
-            config=WeaviateConfig().model_dump(),
-        ),
+        providers=[
+            Provider(
+                provider_id="weaviate",
+                provider_type="remote::weaviate",
+                config=WeaviateConfig().model_dump(),
+            )
+        ],
         provider_data=dict(
             weaviate_api_key=get_env_or_fail("WEAVIATE_API_KEY"),
             weaviate_cluster_url=get_env_or_fail("WEAVIATE_CLUSTER_URL"),

@@ -72,7 +78,7 @@ async def memory_stack(request):
     impls = await resolve_impls_for_test_v2(
         [Api.memory],
-        {"memory": [fixture.provider.model_dump()]},
+        {"memory": fixture.providers},
         fixture.provider_data,
     )

@@ -20,7 +20,7 @@ from llama_stack.distribution.resolver import resolve_impls
 async def resolve_impls_for_test_v2(
     apis: List[Api],
-    providers: Dict[str, Provider],
+    providers: Dict[str, List[Provider]],
     provider_data: Optional[Dict[str, Any]] = None,
 ):
     run_config = dict(
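
A sketch of a call under the widened signature (ids hypothetical, configs elided); note that callers now pass Provider objects directly instead of pre-dumped dicts, matching the inference_stack and memory_stack changes above:

    impls = await resolve_impls_for_test_v2(
        [Api.inference],
        {
            "inference": [
                Provider(provider_id="meta-reference-0", provider_type="meta-reference", config={}),
                Provider(provider_id="meta-reference-1", provider_type="meta-reference", config={}),
            ]
        },
    )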

@@ -81,7 +81,6 @@ def pytest_generate_tests(metafunc):
         )
     if "safety_stack" in metafunc.fixturenames:
-        # print(f"metafunc.fixturenames: {metafunc.fixturenames}, {metafunc}")
         available_fixtures = {
             "inference": INFERENCE_FIXTURES,
             "safety": SAFETY_FIXTURES,

@@ -22,32 +22,38 @@ from ..env import get_env_or_fail
 @pytest.fixture(scope="session")
 def safety_model(request):
-    return request.config.getoption("--safety-model", None) or request.param
+    if hasattr(request, "param"):
+        return request.param
+    return request.config.getoption("--safety-model", None)


 @pytest.fixture(scope="session")
 def safety_meta_reference(safety_model) -> ProviderFixture:
     return ProviderFixture(
-        provider=Provider(
-            provider_id="meta-reference",
-            provider_type="meta-reference",
-            config=SafetyConfig(
-                llama_guard_shield=LlamaGuardShieldConfig(
-                    model=safety_model,
-                ),
-            ).model_dump(),
-        ),
+        providers=[
+            Provider(
+                provider_id="meta-reference",
+                provider_type="meta-reference",
+                config=SafetyConfig(
+                    llama_guard_shield=LlamaGuardShieldConfig(
+                        model=safety_model,
+                    ),
+                ).model_dump(),
+            )
+        ],
     )


 @pytest.fixture(scope="session")
 def safety_together() -> ProviderFixture:
     return ProviderFixture(
-        provider=Provider(
-            provider_id="together",
-            provider_type="remote::together",
-            config=TogetherSafetyConfig().model_dump(),
-        ),
+        providers=[
+            Provider(
+                provider_id="together",
+                provider_type="remote::together",
+                config=TogetherSafetyConfig().model_dump(),
+            )
+        ],
         provider_data=dict(
             together_api_key=get_env_or_fail("TOGETHER_API_KEY"),
         ),

@@ -67,8 +73,8 @@ async def safety_stack(inference_model, safety_model, request):
     safety_fixture = request.getfixturevalue(f"safety_{fixture_dict['safety']}")
     providers = {
-        "inference": [inference_fixture.provider],
-        "safety": [safety_fixture.provider],
+        "inference": inference_fixture.providers,
+        "safety": safety_fixture.providers,
     }
     provider_data = {}
     if inference_fixture.provider_data:
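
Taken together, a single meta-reference stack can now host the agent's inference model and a Llama-Guard safety model side by side; a hypothetical invocation (test path assumed):

    pytest -s -v llama_stack/providers/tests/safety/ \
      --inference-model=Llama3.1-8B-Instruct \
      --safety-model=Llama-Guard-3-1B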