get multiple providers working for meta-reference (for inference + safety)

Ashwin Bharambe 2024-11-04 16:33:42 -08:00
parent 60800bc09b
commit 6c7ea6e904
10 changed files with 136 additions and 95 deletions

View file

@@ -132,7 +132,7 @@ class CommonRoutingTableImpl(RoutingTable):
             else:
                 provider_ids_str = f"provider: `{provider_ids[0]}`"
             raise ValueError(
-                f"`{routing_key}` not served by {provider_ids_str}. Make sure there is an {apiname} provider serving this {objname}."
+                f"{objname.capitalize()} `{routing_key}` not served by {provider_ids_str}. Make sure there is an {apiname} provider serving this {objname}."
             )
         objs = self.registry[routing_key]
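
For illustration, the reworded error now leads with the capitalized object type. A standalone sketch with made-up values (these names are hypothetical, not taken from the commit):

objname, routing_key = "model", "Llama3.1-8B-Instruct"
apiname = "inference"
provider_ids_str = "provider: `meta-reference-0`"
print(
    f"{objname.capitalize()} `{routing_key}` not served by {provider_ids_str}. "
    f"Make sure there is an {apiname} provider serving this {objname}."
)
# Prints: Model `Llama3.1-8B-Instruct` not served by provider: `meta-reference-0`. ...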

View file

@@ -73,16 +73,20 @@ def pytest_addoption(parser):
 def pytest_generate_tests(metafunc):
-    if "inference_model" in metafunc.fixturenames:
-        metafunc.parametrize(
-            "inference_model",
-            [pytest.param(metafunc.config.getoption("--inference-model"), id="")],
-            indirect=True,
-        )
+    safety_model = metafunc.config.getoption("--safety-model")
     if "safety_model" in metafunc.fixturenames:
         metafunc.parametrize(
             "safety_model",
-            [pytest.param(metafunc.config.getoption("--safety-model"), id="")],
+            [pytest.param(safety_model, id="")],
             indirect=True,
         )
+    if "inference_model" in metafunc.fixturenames:
+        inference_model = metafunc.config.getoption("--inference-model")
+        models = list(set({inference_model, safety_model}))
+        metafunc.parametrize(
+            "inference_model",
+            [pytest.param(models, id="")],
+            indirect=True,
+        )
     if "agents_stack" in metafunc.fixturenames:

View file

@@ -25,16 +25,18 @@ from ..conftest import ProviderFixture
 def agents_meta_reference() -> ProviderFixture:
     sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
     return ProviderFixture(
-        provider=Provider(
-            provider_id="meta-reference",
-            provider_type="meta-reference",
-            config=MetaReferenceAgentsImplConfig(
-                # TODO: make this an in-memory store
-                persistence_store=SqliteKVStoreConfig(
-                    db_path=sqlite_file.name,
-                ),
-            ).model_dump(),
-        ),
+        providers=[
+            Provider(
+                provider_id="meta-reference",
+                provider_type="meta-reference",
+                config=MetaReferenceAgentsImplConfig(
+                    # TODO: make this an in-memory store
+                    persistence_store=SqliteKVStoreConfig(
+                        db_path=sqlite_file.name,
+                    ),
+                ).model_dump(),
+            )
+        ],
     )
@@ -49,7 +51,7 @@ async def agents_stack(request):
     provider_data = {}
     for key in ["inference", "safety", "memory", "agents"]:
         fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
-        providers[key] = [fixture.provider]
+        providers[key] = fixture.providers
         if fixture.provider_data:
             provider_data.update(fixture.provider_data)
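
With `ProviderFixture.providers` a list, stack assembly maps each API key straight to whatever list the fixture supplies instead of wrapping a lone provider. A simplified, self-contained sketch of that aggregation (the dataclasses below are stand-ins, not the real Provider/ProviderFixture classes):

from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class Provider:  # stand-in for llama_stack's Provider
    provider_id: str
    provider_type: str


@dataclass
class ProviderFixture:  # stand-in for the test fixture class
    providers: List[Provider]
    provider_data: Optional[Dict] = None


fixtures = {
    "inference": ProviderFixture(
        providers=[
            Provider("meta-reference-0", "meta-reference"),
            Provider("meta-reference-1", "meta-reference"),
        ]
    ),
    "safety": ProviderFixture(providers=[Provider("meta-reference", "meta-reference")]),
}

providers: Dict[str, List[Provider]] = {}
provider_data: Dict = {}
for key, fixture in fixtures.items():
    providers[key] = fixture.providers  # the whole list, not [fixture.provider]
    if fixture.provider_data:
        provider_data.update(fixture.provider_data)

assert len(providers["inference"]) == 2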

View file

@@ -19,6 +19,13 @@ from llama_stack.providers.datatypes import * # noqa: F403
 @pytest.fixture
 def common_params(inference_model):
+    # This is not entirely satisfactory. The fixture `inference_model` can correspond to
+    # multiple models when you need to run a safety model in addition to normal agent
+    # inference model. We filter off the safety model by looking for "Llama-Guard"
+    if isinstance(inference_model, list):
+        inference_model = next(m for m in inference_model if "Llama-Guard" not in m)
+        assert inference_model is not None
+
     return dict(
         model=inference_model,
         instructions="You are a helpful assistant.",

View file

@@ -17,7 +17,7 @@ from llama_stack.distribution.datatypes import Provider
 class ProviderFixture(BaseModel):
-    provider: Provider
+    providers: List[Provider]
     provider_data: Optional[Dict[str, Any]] = None
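
Every fixture now agrees on one shape: a list of providers, even in the single-provider case. A minimal construction sketch (the empty config dict is a placeholder; the real fixtures pass a config's `model_dump()`):

fixture = ProviderFixture(
    providers=[
        Provider(
            provider_id="meta-reference",
            provider_type="meta-reference",
            config={},  # placeholder config
        )
    ],
)
assert isinstance(fixture.providers, list)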

View file

@@ -24,63 +24,80 @@ from ..env import get_env_or_fail
 @pytest.fixture(scope="session")
 def inference_model(request):
-    return request.config.getoption("--inference-model", None) or request.param
+    if hasattr(request, "param"):
+        return request.param
+    return request.config.getoption("--inference-model", None)


 @pytest.fixture(scope="session")
 def inference_meta_reference(inference_model) -> ProviderFixture:
-    # TODO: make this work with multiple models
+    inference_model = (
+        [inference_model] if isinstance(inference_model, str) else inference_model
+    )
+
     return ProviderFixture(
-        provider=Provider(
-            provider_id="meta-reference",
-            provider_type="meta-reference",
-            config=MetaReferenceInferenceConfig(
-                model=inference_model,
-                max_seq_len=512,
-                create_distributed_process_group=False,
-                checkpoint_dir=os.getenv("MODEL_CHECKPOINT_DIR", None),
-            ).model_dump(),
-        ),
+        providers=[
+            Provider(
+                provider_id=f"meta-reference-{i}",
+                provider_type="meta-reference",
+                config=MetaReferenceInferenceConfig(
+                    model=m,
+                    max_seq_len=4096,
+                    create_distributed_process_group=False,
+                    checkpoint_dir=os.getenv("MODEL_CHECKPOINT_DIR", None),
+                ).model_dump(),
+            )
+            for i, m in enumerate(inference_model)
+        ],
     )


 @pytest.fixture(scope="session")
 def inference_ollama(inference_model) -> ProviderFixture:
-    if inference_model == "Llama3.1-8B-Instruct":
-        pytest.skip("Ollama only support Llama3.2-3B-Instruct for testing")
+    inference_model = (
+        [inference_model] if isinstance(inference_model, str) else inference_model
+    )
+    if "Llama3.1-8B-Instruct" in inference_model:
+        pytest.skip("Ollama only supports Llama3.2-3B-Instruct for testing")
+
     return ProviderFixture(
-        provider=Provider(
-            provider_id="ollama",
-            provider_type="remote::ollama",
-            config=OllamaImplConfig(
-                host="localhost", port=os.getenv("OLLAMA_PORT", 11434)
-            ).model_dump(),
-        ),
+        providers=[
+            Provider(
+                provider_id="ollama",
+                provider_type="remote::ollama",
+                config=OllamaImplConfig(
+                    host="localhost", port=os.getenv("OLLAMA_PORT", 11434)
+                ).model_dump(),
+            )
+        ],
     )


 @pytest.fixture(scope="session")
 def inference_fireworks() -> ProviderFixture:
     return ProviderFixture(
-        provider=Provider(
-            provider_id="fireworks",
-            provider_type="remote::fireworks",
-            config=FireworksImplConfig(
-                api_key=get_env_or_fail("FIREWORKS_API_KEY"),
-            ).model_dump(),
-        ),
+        providers=[
+            Provider(
+                provider_id="fireworks",
+                provider_type="remote::fireworks",
+                config=FireworksImplConfig(
+                    api_key=get_env_or_fail("FIREWORKS_API_KEY"),
+                ).model_dump(),
+            )
+        ],
     )


 @pytest.fixture(scope="session")
 def inference_together() -> ProviderFixture:
     return ProviderFixture(
-        provider=Provider(
-            provider_id="together",
-            provider_type="remote::together",
-            config=TogetherImplConfig().model_dump(),
-        ),
+        providers=[
+            Provider(
+                provider_id="together",
+                provider_type="remote::together",
+                config=TogetherImplConfig().model_dump(),
+            )
+        ],
         provider_data=dict(
             together_api_key=get_env_or_fail("TOGETHER_API_KEY"),
         ),
@@ -96,7 +113,7 @@ async def inference_stack(request):
     inference_fixture = request.getfixturevalue(f"inference_{fixture_name}")
     impls = await resolve_impls_for_test_v2(
         [Api.inference],
-        {"inference": [inference_fixture.provider.model_dump()]},
+        {"inference": inference_fixture.providers},
         inference_fixture.provider_data,
     )
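
The comprehension in `inference_meta_reference` gives each model its own uniquely numbered `provider_id`, which is what lets several meta-reference providers coexist behind one provider type. With hypothetical models:

inference_model = ["Llama3.1-8B-Instruct", "Llama-Guard-3-1B"]  # hypothetical
provider_ids = [f"meta-reference-{i}" for i, _ in enumerate(inference_model)]
assert provider_ids == ["meta-reference-0", "meta-reference-1"]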

View file

@@ -22,39 +22,45 @@ from ..env import get_env_or_fail
 @pytest.fixture(scope="session")
 def memory_meta_reference() -> ProviderFixture:
     return ProviderFixture(
-        provider=Provider(
-            provider_id="meta-reference",
-            provider_type="meta-reference",
-            config=FaissImplConfig().model_dump(),
-        ),
+        providers=[
+            Provider(
+                provider_id="meta-reference",
+                provider_type="meta-reference",
+                config=FaissImplConfig().model_dump(),
+            )
+        ],
     )


 @pytest.fixture(scope="session")
 def memory_pgvector() -> ProviderFixture:
     return ProviderFixture(
-        provider=Provider(
-            provider_id="pgvector",
-            provider_type="remote::pgvector",
-            config=PGVectorConfig(
-                host=os.getenv("PGVECTOR_HOST", "localhost"),
-                port=os.getenv("PGVECTOR_PORT", 5432),
-                db=get_env_or_fail("PGVECTOR_DB"),
-                user=get_env_or_fail("PGVECTOR_USER"),
-                password=get_env_or_fail("PGVECTOR_PASSWORD"),
-            ).model_dump(),
-        ),
+        providers=[
+            Provider(
+                provider_id="pgvector",
+                provider_type="remote::pgvector",
+                config=PGVectorConfig(
+                    host=os.getenv("PGVECTOR_HOST", "localhost"),
+                    port=os.getenv("PGVECTOR_PORT", 5432),
+                    db=get_env_or_fail("PGVECTOR_DB"),
+                    user=get_env_or_fail("PGVECTOR_USER"),
+                    password=get_env_or_fail("PGVECTOR_PASSWORD"),
+                ).model_dump(),
+            )
+        ],
     )


 @pytest.fixture(scope="session")
 def memory_weaviate() -> ProviderFixture:
     return ProviderFixture(
-        provider=Provider(
-            provider_id="weaviate",
-            provider_type="remote::weaviate",
-            config=WeaviateConfig().model_dump(),
-        ),
+        providers=[
+            Provider(
+                provider_id="weaviate",
+                provider_type="remote::weaviate",
+                config=WeaviateConfig().model_dump(),
+            )
+        ],
         provider_data=dict(
             weaviate_api_key=get_env_or_fail("WEAVIATE_API_KEY"),
             weaviate_cluster_url=get_env_or_fail("WEAVIATE_CLUSTER_URL"),
@@ -72,7 +78,7 @@ async def memory_stack(request):
     impls = await resolve_impls_for_test_v2(
         [Api.memory],
-        {"memory": [fixture.provider.model_dump()]},
+        {"memory": fixture.providers},
         fixture.provider_data,
     )
)

View file

@@ -20,7 +20,7 @@ from llama_stack.distribution.resolver import resolve_impls
 async def resolve_impls_for_test_v2(
     apis: List[Api],
-    providers: Dict[str, Provider],
+    providers: Dict[str, List[Provider]],
     provider_data: Optional[Dict[str, Any]] = None,
 ):
     run_config = dict(
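
Under the new signature each API maps to a list of `Provider` objects rather than a single pre-dumped dict, so a combined inference + safety call looks roughly like this (the fixture variables are illustrative, mirroring the call sites above):

impls = await resolve_impls_for_test_v2(
    [Api.inference, Api.safety],
    {
        "inference": inference_fixture.providers,
        "safety": safety_fixture.providers,
    },
    provider_data,
)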

View file

@@ -81,7 +81,6 @@ def pytest_generate_tests(metafunc):
         )
     if "safety_stack" in metafunc.fixturenames:
-        # print(f"metafunc.fixturenames: {metafunc.fixturenames}, {metafunc}")
         available_fixtures = {
             "inference": INFERENCE_FIXTURES,
             "safety": SAFETY_FIXTURES,

View file

@@ -22,32 +22,38 @@ from ..env import get_env_or_fail
 @pytest.fixture(scope="session")
 def safety_model(request):
-    return request.config.getoption("--safety-model", None) or request.param
+    if hasattr(request, "param"):
+        return request.param
+    return request.config.getoption("--safety-model", None)


 @pytest.fixture(scope="session")
 def safety_meta_reference(safety_model) -> ProviderFixture:
     return ProviderFixture(
-        provider=Provider(
-            provider_id="meta-reference",
-            provider_type="meta-reference",
-            config=SafetyConfig(
-                llama_guard_shield=LlamaGuardShieldConfig(
-                    model=safety_model,
-                ),
-            ).model_dump(),
-        ),
+        providers=[
+            Provider(
+                provider_id="meta-reference",
+                provider_type="meta-reference",
+                config=SafetyConfig(
+                    llama_guard_shield=LlamaGuardShieldConfig(
+                        model=safety_model,
+                    ),
+                ).model_dump(),
+            )
+        ],
     )


 @pytest.fixture(scope="session")
 def safety_together() -> ProviderFixture:
     return ProviderFixture(
-        provider=Provider(
-            provider_id="together",
-            provider_type="remote::together",
-            config=TogetherSafetyConfig().model_dump(),
-        ),
+        providers=[
+            Provider(
+                provider_id="together",
+                provider_type="remote::together",
+                config=TogetherSafetyConfig().model_dump(),
+            )
+        ],
         provider_data=dict(
             together_api_key=get_env_or_fail("TOGETHER_API_KEY"),
         ),
@@ -67,8 +73,8 @@ async def safety_stack(inference_model, safety_model, request):
     safety_fixture = request.getfixturevalue(f"safety_{fixture_dict['safety']}")
     providers = {
-        "inference": [inference_fixture.provider],
-        "safety": [safety_fixture.provider],
+        "inference": inference_fixture.providers,
+        "safety": safety_fixture.providers,
     }
     provider_data = {}
     if inference_fixture.provider_data: