Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-30)

Commit 6c7ea6e904 (parent 60800bc09b)

get multiple providers working for meta-reference (for inference + safety)

10 changed files with 136 additions and 95 deletions
```diff
@@ -132,7 +132,7 @@ class CommonRoutingTableImpl(RoutingTable):
             else:
                 provider_ids_str = f"provider: `{provider_ids[0]}`"
             raise ValueError(
-                f"`{routing_key}` not served by {provider_ids_str}. Make sure there is an {apiname} provider serving this {objname}."
+                f"{objname.capitalize()} `{routing_key}` not served by {provider_ids_str}. Make sure there is an {apiname} provider serving this {objname}."
             )

         objs = self.registry[routing_key]
```
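The change in this hunk is to the error text only: the message now leads with the capitalized object kind, so a routing failure says what kind of thing was being looked up. A minimal sketch of what the new message renders to, with hypothetical values:

```python
# Hypothetical values, for illustration only.
routing_key = "Llama-Guard-3-1B"
objname, apiname = "shield", "safety"
provider_ids_str = "provider: `meta-reference`"

print(
    f"{objname.capitalize()} `{routing_key}` not served by {provider_ids_str}. "
    f"Make sure there is an {apiname} provider serving this {objname}."
)
# Shield `Llama-Guard-3-1B` not served by provider: `meta-reference`. ...
```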
```diff
@@ -73,16 +73,20 @@ def pytest_addoption(parser):


 def pytest_generate_tests(metafunc):
-    if "inference_model" in metafunc.fixturenames:
-        metafunc.parametrize(
-            "inference_model",
-            [pytest.param(metafunc.config.getoption("--inference-model"), id="")],
-            indirect=True,
-        )
+    safety_model = metafunc.config.getoption("--safety-model")
     if "safety_model" in metafunc.fixturenames:
         metafunc.parametrize(
             "safety_model",
-            [pytest.param(metafunc.config.getoption("--safety-model"), id="")],
+            [pytest.param(safety_model, id="")],
+            indirect=True,
+        )
+    if "inference_model" in metafunc.fixturenames:
+        inference_model = metafunc.config.getoption("--inference-model")
+        models = list(set({inference_model, safety_model}))
+
+        metafunc.parametrize(
+            "inference_model",
+            [pytest.param(models, id="")],
             indirect=True,
         )
     if "agents_stack" in metafunc.fixturenames:
```
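The reordering matters: `--safety-model` is read up front so the `inference_model` parametrization can bundle both models into a single list, deduplicated when both flags name the same model (e.g. `pytest ... --inference-model=Llama3.1-8B-Instruct --safety-model=Llama-Guard-3-1B`). A quick sketch of the dedup step in isolation, with hypothetical model names:

```python
inference_model = "Llama3.1-8B-Instruct"
safety_model = "Llama-Guard-3-1B"
print(list(set({inference_model, safety_model})))  # distinct models -> two entries

safety_model = inference_model
print(list(set({inference_model, safety_model})))  # same model -> one entry
```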
```diff
@@ -25,16 +25,18 @@ from ..conftest import ProviderFixture
 def agents_meta_reference() -> ProviderFixture:
     sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
     return ProviderFixture(
-        provider=Provider(
-            provider_id="meta-reference",
-            provider_type="meta-reference",
-            config=MetaReferenceAgentsImplConfig(
-                # TODO: make this an in-memory store
-                persistence_store=SqliteKVStoreConfig(
-                    db_path=sqlite_file.name,
-                ),
-            ).model_dump(),
-        ),
+        providers=[
+            Provider(
+                provider_id="meta-reference",
+                provider_type="meta-reference",
+                config=MetaReferenceAgentsImplConfig(
+                    # TODO: make this an in-memory store
+                    persistence_store=SqliteKVStoreConfig(
+                        db_path=sqlite_file.name,
+                    ),
+                ).model_dump(),
+            )
+        ],
     )


```
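Structurally this hunk only wraps the single `Provider` in a `providers=[...]` list, matching the `ProviderFixture` schema change further down. Agent persistence still lands in a throwaway SQLite file; the relevant mechanics in isolation:

```python
import tempfile

# delete=False keeps the file on disk for the whole test session;
# cleanup is left to the OS temp directory (the path below is an example).
sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
print(sqlite_file.name)  # e.g. /tmp/tmpab12cd34.db
```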
```diff
@@ -49,7 +51,7 @@ async def agents_stack(request):
     provider_data = {}
     for key in ["inference", "safety", "memory", "agents"]:
         fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
-        providers[key] = [fixture.provider]
+        providers[key] = fixture.providers
         if fixture.provider_data:
             provider_data.update(fixture.provider_data)

```
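With fixtures now carrying lists, the stack assembly takes whatever list each fixture supplies instead of wrapping a single provider itself. A sketch of the resulting shape (entries are hypothetical placeholders):

```python
# What the loop above builds, roughly:
providers = {
    "inference": ["<Provider meta-reference-0>", "<Provider meta-reference-1>"],
    "safety": ["<Provider meta-reference>"],
    "memory": ["<Provider meta-reference>"],
    "agents": ["<Provider meta-reference>"],
}
```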
```diff
@@ -19,6 +19,13 @@ from llama_stack.providers.datatypes import *  # noqa: F403

 @pytest.fixture
 def common_params(inference_model):
+    # This is not entirely satisfactory. The fixture `inference_model` can correspond to
+    # multiple models when you need to run a safety model in addition to normal agent
+    # inference model. We filter off the safety model by looking for "Llama-Guard"
+    if isinstance(inference_model, list):
+        inference_model = next(m for m in inference_model if "Llama-Guard" not in m)
+        assert inference_model is not None
+
     return dict(
         model=inference_model,
         instructions="You are a helpful assistant.",
```
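Because `inference_model` can now be a list (the agent model plus a safety model), the fixture picks out the non-guard model by name before building params. The filter in isolation, with hypothetical inputs:

```python
models = ["Llama3.1-8B-Instruct", "Llama-Guard-3-1B"]
inference_model = next(m for m in models if "Llama-Guard" not in m)
print(inference_model)  # Llama3.1-8B-Instruct -- the safety model is skipped
```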
```diff
@@ -17,7 +17,7 @@ from llama_stack.distribution.datatypes import Provider


 class ProviderFixture(BaseModel):
-    provider: Provider
+    providers: List[Provider]
     provider_data: Optional[Dict[str, Any]] = None


```
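This one-line schema change is what the rest of the commit fans out from: every fixture now returns a list of providers. A self-contained sketch of the new contract (the `Provider` stub here is illustrative, not the real class):

```python
from typing import Any, Dict, List, Optional

from pydantic import BaseModel


class Provider(BaseModel):  # stub for illustration
    provider_id: str
    provider_type: str
    config: Dict[str, Any]


class ProviderFixture(BaseModel):
    providers: List[Provider]  # was: provider: Provider
    provider_data: Optional[Dict[str, Any]] = None


fixture = ProviderFixture(
    providers=[
        Provider(provider_id="meta-reference-0", provider_type="meta-reference", config={}),
        Provider(provider_id="meta-reference-1", provider_type="meta-reference", config={}),
    ]
)
print(len(fixture.providers))  # 2
```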
```diff
@@ -24,63 +24,80 @@ from ..env import get_env_or_fail

 @pytest.fixture(scope="session")
 def inference_model(request):
-    return request.config.getoption("--inference-model", None) or request.param
+    if hasattr(request, "param"):
+        return request.param
+    return request.config.getoption("--inference-model", None)


 @pytest.fixture(scope="session")
 def inference_meta_reference(inference_model) -> ProviderFixture:
-    # TODO: make this work with multiple models
+    inference_model = (
+        [inference_model] if isinstance(inference_model, str) else inference_model
+    )
+
     return ProviderFixture(
-        provider=Provider(
-            provider_id="meta-reference",
-            provider_type="meta-reference",
-            config=MetaReferenceInferenceConfig(
-                model=inference_model,
-                max_seq_len=512,
-                create_distributed_process_group=False,
-                checkpoint_dir=os.getenv("MODEL_CHECKPOINT_DIR", None),
-            ).model_dump(),
-        ),
+        providers=[
+            Provider(
+                provider_id=f"meta-reference-{i}",
+                provider_type="meta-reference",
+                config=MetaReferenceInferenceConfig(
+                    model=m,
+                    max_seq_len=4096,
+                    create_distributed_process_group=False,
+                    checkpoint_dir=os.getenv("MODEL_CHECKPOINT_DIR", None),
+                ).model_dump(),
+            )
+            for i, m in enumerate(inference_model)
+        ]
     )


 @pytest.fixture(scope="session")
 def inference_ollama(inference_model) -> ProviderFixture:
-    if inference_model == "Llama3.1-8B-Instruct":
-        pytest.skip("Ollama only support Llama3.2-3B-Instruct for testing")
+    inference_model = (
+        [inference_model] if isinstance(inference_model, str) else inference_model
+    )
+    if "Llama3.1-8B-Instruct" in inference_model:
+        pytest.skip("Ollama only supports Llama3.2-3B-Instruct for testing")

     return ProviderFixture(
-        provider=Provider(
-            provider_id="ollama",
-            provider_type="remote::ollama",
-            config=OllamaImplConfig(
-                host="localhost", port=os.getenv("OLLAMA_PORT", 11434)
-            ).model_dump(),
-        ),
+        providers=[
+            Provider(
+                provider_id="ollama",
+                provider_type="remote::ollama",
+                config=OllamaImplConfig(
+                    host="localhost", port=os.getenv("OLLAMA_PORT", 11434)
+                ).model_dump(),
+            )
+        ],
     )


 @pytest.fixture(scope="session")
 def inference_fireworks() -> ProviderFixture:
     return ProviderFixture(
-        provider=Provider(
-            provider_id="fireworks",
-            provider_type="remote::fireworks",
-            config=FireworksImplConfig(
-                api_key=get_env_or_fail("FIREWORKS_API_KEY"),
-            ).model_dump(),
-        ),
+        providers=[
+            Provider(
+                provider_id="fireworks",
+                provider_type="remote::fireworks",
+                config=FireworksImplConfig(
+                    api_key=get_env_or_fail("FIREWORKS_API_KEY"),
+                ).model_dump(),
+            )
+        ],
     )


 @pytest.fixture(scope="session")
 def inference_together() -> ProviderFixture:
     return ProviderFixture(
-        provider=Provider(
-            provider_id="together",
-            provider_type="remote::together",
-            config=TogetherImplConfig().model_dump(),
-        ),
+        providers=[
+            Provider(
+                provider_id="together",
+                provider_type="remote::together",
+                config=TogetherImplConfig().model_dump(),
+            )
+        ],
         provider_data=dict(
             together_api_key=get_env_or_fail("TOGETHER_API_KEY"),
         ),
```
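The meta-reference inference fixture is where "multiple providers" actually happens: the comprehension emits one provider per requested model, each with a distinct id so they can coexist in one run config (note `max_seq_len` also goes from 512 to 4096). The id fan-out in isolation, with a hypothetical two-model request:

```python
models = ["Llama3.1-8B-Instruct", "Llama-Guard-3-1B"]
print([f"meta-reference-{i}" for i, m in enumerate(models)])
# -> ['meta-reference-0', 'meta-reference-1']
```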
```diff
@@ -96,7 +113,7 @@ async def inference_stack(request):
     inference_fixture = request.getfixturevalue(f"inference_{fixture_name}")
     impls = await resolve_impls_for_test_v2(
         [Api.inference],
-        {"inference": [inference_fixture.provider.model_dump()]},
+        {"inference": inference_fixture.providers},
         inference_fixture.provider_data,
     )

```
```diff
@@ -22,39 +22,45 @@ from ..env import get_env_or_fail
 @pytest.fixture(scope="session")
 def memory_meta_reference() -> ProviderFixture:
     return ProviderFixture(
-        provider=Provider(
-            provider_id="meta-reference",
-            provider_type="meta-reference",
-            config=FaissImplConfig().model_dump(),
-        ),
+        providers=[
+            Provider(
+                provider_id="meta-reference",
+                provider_type="meta-reference",
+                config=FaissImplConfig().model_dump(),
+            )
+        ],
     )


 @pytest.fixture(scope="session")
 def memory_pgvector() -> ProviderFixture:
     return ProviderFixture(
-        provider=Provider(
-            provider_id="pgvector",
-            provider_type="remote::pgvector",
-            config=PGVectorConfig(
-                host=os.getenv("PGVECTOR_HOST", "localhost"),
-                port=os.getenv("PGVECTOR_PORT", 5432),
-                db=get_env_or_fail("PGVECTOR_DB"),
-                user=get_env_or_fail("PGVECTOR_USER"),
-                password=get_env_or_fail("PGVECTOR_PASSWORD"),
-            ).model_dump(),
-        ),
+        providers=[
+            Provider(
+                provider_id="pgvector",
+                provider_type="remote::pgvector",
+                config=PGVectorConfig(
+                    host=os.getenv("PGVECTOR_HOST", "localhost"),
+                    port=os.getenv("PGVECTOR_PORT", 5432),
+                    db=get_env_or_fail("PGVECTOR_DB"),
+                    user=get_env_or_fail("PGVECTOR_USER"),
+                    password=get_env_or_fail("PGVECTOR_PASSWORD"),
+                ).model_dump(),
+            )
+        ],
     )


 @pytest.fixture(scope="session")
 def memory_weaviate() -> ProviderFixture:
     return ProviderFixture(
-        provider=Provider(
-            provider_id="weaviate",
-            provider_type="remote::weaviate",
-            config=WeaviateConfig().model_dump(),
-        ),
+        providers=[
+            Provider(
+                provider_id="weaviate",
+                provider_type="remote::weaviate",
+                config=WeaviateConfig().model_dump(),
+            )
+        ],
         provider_data=dict(
             weaviate_api_key=get_env_or_fail("WEAVIATE_API_KEY"),
             weaviate_cluster_url=get_env_or_fail("WEAVIATE_CLUSTER_URL"),
```
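The memory fixtures get the same mechanical `providers=[...]` wrapping. The pgvector fixture is the one that pulls its connection settings from the environment; a sketch of what it expects:

```python
import os

# Optional, with defaults:
host = os.getenv("PGVECTOR_HOST", "localhost")
port = os.getenv("PGVECTOR_PORT", 5432)
# Required -- get_env_or_fail raises if any of these are unset:
#   PGVECTOR_DB, PGVECTOR_USER, PGVECTOR_PASSWORD
print(host, port)
```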
```diff
@@ -72,7 +78,7 @@ async def memory_stack(request):

     impls = await resolve_impls_for_test_v2(
         [Api.memory],
-        {"memory": [fixture.provider.model_dump()]},
+        {"memory": fixture.providers},
         fixture.provider_data,
     )

```
```diff
@@ -20,7 +20,7 @@ from llama_stack.distribution.resolver import resolve_impls

 async def resolve_impls_for_test_v2(
     apis: List[Api],
-    providers: Dict[str, Provider],
+    providers: Dict[str, List[Provider]],
     provider_data: Optional[Dict[str, Any]] = None,
 ):
     run_config = dict(
```
```diff
@@ -81,7 +81,6 @@ def pytest_generate_tests(metafunc):
         )

     if "safety_stack" in metafunc.fixturenames:
-        # print(f"metafunc.fixturenames: {metafunc.fixturenames}, {metafunc}")
         available_fixtures = {
             "inference": INFERENCE_FIXTURES,
             "safety": SAFETY_FIXTURES,
```
```diff
@@ -22,32 +22,38 @@ from ..env import get_env_or_fail

 @pytest.fixture(scope="session")
 def safety_model(request):
-    return request.config.getoption("--safety-model", None) or request.param
+    if hasattr(request, "param"):
+        return request.param
+    return request.config.getoption("--safety-model", None)


 @pytest.fixture(scope="session")
 def safety_meta_reference(safety_model) -> ProviderFixture:
     return ProviderFixture(
-        provider=Provider(
-            provider_id="meta-reference",
-            provider_type="meta-reference",
-            config=SafetyConfig(
-                llama_guard_shield=LlamaGuardShieldConfig(
-                    model=safety_model,
-                ),
-            ).model_dump(),
-        ),
+        providers=[
+            Provider(
+                provider_id="meta-reference",
+                provider_type="meta-reference",
+                config=SafetyConfig(
+                    llama_guard_shield=LlamaGuardShieldConfig(
+                        model=safety_model,
+                    ),
+                ).model_dump(),
+            )
+        ],
     )


 @pytest.fixture(scope="session")
 def safety_together() -> ProviderFixture:
     return ProviderFixture(
-        provider=Provider(
-            provider_id="together",
-            provider_type="remote::together",
-            config=TogetherSafetyConfig().model_dump(),
-        ),
+        providers=[
+            Provider(
+                provider_id="together",
+                provider_type="remote::together",
+                config=TogetherSafetyConfig().model_dump(),
+            )
+        ],
         provider_data=dict(
             together_api_key=get_env_or_fail("TOGETHER_API_KEY"),
         ),
```
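Same wrapping for the safety fixtures; the meta-reference one nests a Llama Guard shield config. Roughly the dict its `model_dump()` contributes, with the model name as a hypothetical example (the real dump may carry additional default fields):

```python
config = {
    "llama_guard_shield": {
        "model": "Llama-Guard-3-1B",
    },
}
```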
```diff
@@ -67,8 +73,8 @@ async def safety_stack(inference_model, safety_model, request):
     safety_fixture = request.getfixturevalue(f"safety_{fixture_dict['safety']}")

     providers = {
-        "inference": [inference_fixture.provider],
-        "safety": [safety_fixture.provider],
+        "inference": inference_fixture.providers,
+        "safety": safety_fixture.providers,
     }
     provider_data = {}
     if inference_fixture.provider_data:
```
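The safety stack now passes both fixtures' provider lists through as-is, and (in the lines just past this hunk) merges any `provider_data` they carry into one dict. The merge pattern in isolation, with hypothetical fixture data:

```python
provider_data = {}
for data in ({"together_api_key": "test-key"}, None):  # hypothetical fixture payloads
    if data:
        provider_data.update(data)
print(provider_data)  # {'together_api_key': 'test-key'}
```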