Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-07-30 07:39:38 +00:00
get multiple providers working for meta-reference (for inference + safety)

parent 60800bc09b
commit 6c7ea6e904
10 changed files with 136 additions and 95 deletions

In short: `ProviderFixture.provider` (a single provider) becomes `ProviderFixture.providers` (a list), so a test stack can spin up one meta-reference inference provider per model, e.g. the agent model and a Llama-Guard safety model side by side, with the routing error message, test resolver, and stack fixtures updated to match.
@@ -132,7 +132,7 @@ class CommonRoutingTableImpl(RoutingTable):
             else:
                 provider_ids_str = f"provider: `{provider_ids[0]}`"
             raise ValueError(
-                f"`{routing_key}` not served by {provider_ids_str}. Make sure there is an {apiname} provider serving this {objname}."
+                f"{objname.capitalize()} `{routing_key}` not served by {provider_ids_str}. Make sure there is an {apiname} provider serving this {objname}."
             )

         objs = self.registry[routing_key]
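Note: the only behavioral change here is the error text, which now leads with the object type. A quick standalone rendering of the new variant (the sample values below are invented for illustration):

objname, routing_key, apiname = "model", "Llama3.1-8B-Instruct", "inference"
provider_ids_str = "provider: `meta-reference`"

# Before: `Llama3.1-8B-Instruct` not served by provider: `meta-reference`. ...
# After:  Model `Llama3.1-8B-Instruct` not served by provider: `meta-reference`. ...
print(
    f"{objname.capitalize()} `{routing_key}` not served by {provider_ids_str}. "
    f"Make sure there is an {apiname} provider serving this {objname}."
)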
@@ -73,16 +73,20 @@ def pytest_addoption(parser):


 def pytest_generate_tests(metafunc):
-    if "inference_model" in metafunc.fixturenames:
-        metafunc.parametrize(
-            "inference_model",
-            [pytest.param(metafunc.config.getoption("--inference-model"), id="")],
-            indirect=True,
-        )
+    safety_model = metafunc.config.getoption("--safety-model")
     if "safety_model" in metafunc.fixturenames:
         metafunc.parametrize(
             "safety_model",
-            [pytest.param(metafunc.config.getoption("--safety-model"), id="")],
+            [pytest.param(safety_model, id="")],
             indirect=True,
         )
+    if "inference_model" in metafunc.fixturenames:
+        inference_model = metafunc.config.getoption("--inference-model")
+        models = list(set({inference_model, safety_model}))
+
+        metafunc.parametrize(
+            "inference_model",
+            [pytest.param(models, id="")],
+            indirect=True,
+        )
     if "agents_stack" in metafunc.fixturenames:
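Note: the inference parametrization now collapses the agent and safety models into one deduplicated list, so a model shared by both roles is loaded only once. A standalone sketch of that dedup step (model names are sample values; note that `set({...})` wraps a set literal that is already a set, and that set iteration order is not guaranteed):

inference_model = "Llama3.1-8B-Instruct"
safety_model = "Llama-Guard-3-1B"

# Distinct models yield two entries; identical models collapse to one.
models = list(set({inference_model, safety_model}))
assert sorted(models) == ["Llama-Guard-3-1B", "Llama3.1-8B-Instruct"]
assert list(set({safety_model, safety_model})) == ["Llama-Guard-3-1B"]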
@@ -25,16 +25,18 @@ from ..conftest import ProviderFixture
 def agents_meta_reference() -> ProviderFixture:
     sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
     return ProviderFixture(
-        provider=Provider(
-            provider_id="meta-reference",
-            provider_type="meta-reference",
-            config=MetaReferenceAgentsImplConfig(
-                # TODO: make this an in-memory store
-                persistence_store=SqliteKVStoreConfig(
-                    db_path=sqlite_file.name,
-                ),
-            ).model_dump(),
-        ),
+        providers=[
+            Provider(
+                provider_id="meta-reference",
+                provider_type="meta-reference",
+                config=MetaReferenceAgentsImplConfig(
+                    # TODO: make this an in-memory store
+                    persistence_store=SqliteKVStoreConfig(
+                        db_path=sqlite_file.name,
+                    ),
+                ).model_dump(),
+            )
+        ],
     )


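Note on the persistence store kept here: `delete=False` means the temporary SQLite file survives the handle going out of scope, which is what lets the agents provider reopen it by path; nothing in the fixture removes it afterwards, which is presumably why the TODO about an in-memory store exists. A minimal illustration:

import os
import tempfile

# delete=False keeps the file on disk after the handle is closed, so other
# code can reopen it by name; the caller is responsible for removing it.
sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
sqlite_file.close()
assert os.path.exists(sqlite_file.name)
os.unlink(sqlite_file.name)  # manual cleanup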
@@ -49,7 +51,7 @@ async def agents_stack(request):
     provider_data = {}
     for key in ["inference", "safety", "memory", "agents"]:
         fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
-        providers[key] = [fixture.provider]
+        providers[key] = fixture.providers
         if fixture.provider_data:
             provider_data.update(fixture.provider_data)

@@ -19,6 +19,13 @@ from llama_stack.providers.datatypes import * # noqa: F403

 @pytest.fixture
 def common_params(inference_model):
+    # This is not entirely satisfactory. The fixture `inference_model` can correspond to
+    # multiple models when you need to run a safety model in addition to normal agent
+    # inference model. We filter off the safety model by looking for "Llama-Guard"
+    if isinstance(inference_model, list):
+        inference_model = next(m for m in inference_model if "Llama-Guard" not in m)
+        assert inference_model is not None
+
     return dict(
         model=inference_model,
         instructions="You are a helpful assistant.",
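One subtlety in the filter above: `next()` without a default raises `StopIteration` when every entry contains "Llama-Guard", so the `assert inference_model is not None` that follows can never actually observe `None`; the assert would only do real work with an explicit `None` default. A standalone sketch (model names are sample values):

inference_model = ["Llama3.1-8B-Instruct", "Llama-Guard-3-1B"]

# With a default of None, a missing match surfaces as None instead of a
# StopIteration, which is what the assert appears to be guarding against.
model = next((m for m in inference_model if "Llama-Guard" not in m), None)
assert model == "Llama3.1-8B-Instruct"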
@@ -17,7 +17,7 @@ from llama_stack.distribution.datatypes import Provider


 class ProviderFixture(BaseModel):
-    provider: Provider
+    providers: List[Provider]
     provider_data: Optional[Dict[str, Any]] = None


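This one-line schema change is the heart of the commit: every fixture now yields a list of providers, which is what lets a single test stack serve several models at once. A self-contained sketch of the new shape (the `Provider` stub below is a simplified stand-in for the real datatype):

from typing import Any, Dict, List, Optional

from pydantic import BaseModel


class Provider(BaseModel):  # simplified stand-in for the real datatype
    provider_id: str
    provider_type: str
    config: Dict[str, Any] = {}


class ProviderFixture(BaseModel):
    providers: List[Provider]  # was: provider: Provider
    provider_data: Optional[Dict[str, Any]] = None


# A fixture can now carry one provider per model:
fixture = ProviderFixture(
    providers=[
        Provider(provider_id="meta-reference-0", provider_type="meta-reference"),
        Provider(provider_id="meta-reference-1", provider_type="meta-reference"),
    ]
)
assert len(fixture.providers) == 2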
@@ -24,63 +24,80 @@ from ..env import get_env_or_fail

 @pytest.fixture(scope="session")
 def inference_model(request):
-    return request.config.getoption("--inference-model", None) or request.param
+    if hasattr(request, "param"):
+        return request.param
+    return request.config.getoption("--inference-model", None)


 @pytest.fixture(scope="session")
 def inference_meta_reference(inference_model) -> ProviderFixture:
-    # TODO: make this work with multiple models
+    inference_model = (
+        [inference_model] if isinstance(inference_model, str) else inference_model
+    )
+
     return ProviderFixture(
-        provider=Provider(
-            provider_id="meta-reference",
-            provider_type="meta-reference",
-            config=MetaReferenceInferenceConfig(
-                model=inference_model,
-                max_seq_len=512,
-                create_distributed_process_group=False,
-                checkpoint_dir=os.getenv("MODEL_CHECKPOINT_DIR", None),
-            ).model_dump(),
-        ),
+        providers=[
+            Provider(
+                provider_id=f"meta-reference-{i}",
+                provider_type="meta-reference",
+                config=MetaReferenceInferenceConfig(
+                    model=m,
+                    max_seq_len=4096,
+                    create_distributed_process_group=False,
+                    checkpoint_dir=os.getenv("MODEL_CHECKPOINT_DIR", None),
+                ).model_dump(),
+            )
+            for i, m in enumerate(inference_model)
+        ]
     )


 @pytest.fixture(scope="session")
 def inference_ollama(inference_model) -> ProviderFixture:
-    if inference_model == "Llama3.1-8B-Instruct":
-        pytest.skip("Ollama only support Llama3.2-3B-Instruct for testing")
+    inference_model = (
+        [inference_model] if isinstance(inference_model, str) else inference_model
+    )
+    if "Llama3.1-8B-Instruct" in inference_model:
+        pytest.skip("Ollama only supports Llama3.2-3B-Instruct for testing")
+
     return ProviderFixture(
-        provider=Provider(
-            provider_id="ollama",
-            provider_type="remote::ollama",
-            config=OllamaImplConfig(
-                host="localhost", port=os.getenv("OLLAMA_PORT", 11434)
-            ).model_dump(),
-        ),
+        providers=[
+            Provider(
+                provider_id="ollama",
+                provider_type="remote::ollama",
+                config=OllamaImplConfig(
+                    host="localhost", port=os.getenv("OLLAMA_PORT", 11434)
+                ).model_dump(),
+            )
+        ],
     )


 @pytest.fixture(scope="session")
 def inference_fireworks() -> ProviderFixture:
     return ProviderFixture(
-        provider=Provider(
-            provider_id="fireworks",
-            provider_type="remote::fireworks",
-            config=FireworksImplConfig(
-                api_key=get_env_or_fail("FIREWORKS_API_KEY"),
-            ).model_dump(),
-        ),
+        providers=[
+            Provider(
+                provider_id="fireworks",
+                provider_type="remote::fireworks",
+                config=FireworksImplConfig(
+                    api_key=get_env_or_fail("FIREWORKS_API_KEY"),
+                ).model_dump(),
+            )
+        ],
     )


 @pytest.fixture(scope="session")
 def inference_together() -> ProviderFixture:
     return ProviderFixture(
-        provider=Provider(
-            provider_id="together",
-            provider_type="remote::together",
-            config=TogetherImplConfig().model_dump(),
-        ),
+        providers=[
+            Provider(
+                provider_id="together",
+                provider_type="remote::together",
+                config=TogetherImplConfig().model_dump(),
+            )
+        ],
         provider_data=dict(
             together_api_key=get_env_or_fail("TOGETHER_API_KEY"),
         ),
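The meta-reference fixture is where multiple models actually materialize: the input is normalized to a list and each model gets its own provider with a unique id, so the routing layer can send agent traffic and Llama-Guard traffic to different providers. A standalone sketch of that pattern (model names are sample values):

def providers_for(inference_model):
    # Normalize a bare string to a one-element list, then mint one
    # provider id per model, mirroring the fixture above.
    models = [inference_model] if isinstance(inference_model, str) else inference_model
    return [f"meta-reference-{i}" for i, m in enumerate(models)]


assert providers_for("Llama3.1-8B-Instruct") == ["meta-reference-0"]
assert providers_for(["Llama3.1-8B-Instruct", "Llama-Guard-3-1B"]) == [
    "meta-reference-0",
    "meta-reference-1",
]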
@@ -96,7 +113,7 @@ async def inference_stack(request):
     inference_fixture = request.getfixturevalue(f"inference_{fixture_name}")
     impls = await resolve_impls_for_test_v2(
         [Api.inference],
-        {"inference": [inference_fixture.provider.model_dump()]},
+        {"inference": inference_fixture.providers},
         inference_fixture.provider_data,
     )

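Call sites like this one no longer dump a single provider to a dict and wrap it in a list; they hand the fixture's `Provider` objects straight to the resolver, implying the resolver now handles serialization itself. A sketch of the difference, reusing the simplified `Provider` stub from above:

p = Provider(provider_id="meta-reference-0", provider_type="meta-reference")

# Old call sites serialized eagerly at each call site; new ones pass the
# model objects through and let resolve_impls_for_test_v2 deal with them.
dumped = p.model_dump()
assert dumped == {
    "provider_id": "meta-reference-0",
    "provider_type": "meta-reference",
    "config": {},
}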
@@ -22,39 +22,45 @@ from ..env import get_env_or_fail
 @pytest.fixture(scope="session")
 def memory_meta_reference() -> ProviderFixture:
     return ProviderFixture(
-        provider=Provider(
-            provider_id="meta-reference",
-            provider_type="meta-reference",
-            config=FaissImplConfig().model_dump(),
-        ),
+        providers=[
+            Provider(
+                provider_id="meta-reference",
+                provider_type="meta-reference",
+                config=FaissImplConfig().model_dump(),
+            )
+        ],
     )


 @pytest.fixture(scope="session")
 def memory_pgvector() -> ProviderFixture:
     return ProviderFixture(
-        provider=Provider(
-            provider_id="pgvector",
-            provider_type="remote::pgvector",
-            config=PGVectorConfig(
-                host=os.getenv("PGVECTOR_HOST", "localhost"),
-                port=os.getenv("PGVECTOR_PORT", 5432),
-                db=get_env_or_fail("PGVECTOR_DB"),
-                user=get_env_or_fail("PGVECTOR_USER"),
-                password=get_env_or_fail("PGVECTOR_PASSWORD"),
-            ).model_dump(),
-        ),
+        providers=[
+            Provider(
+                provider_id="pgvector",
+                provider_type="remote::pgvector",
+                config=PGVectorConfig(
+                    host=os.getenv("PGVECTOR_HOST", "localhost"),
+                    port=os.getenv("PGVECTOR_PORT", 5432),
+                    db=get_env_or_fail("PGVECTOR_DB"),
+                    user=get_env_or_fail("PGVECTOR_USER"),
+                    password=get_env_or_fail("PGVECTOR_PASSWORD"),
+                ).model_dump(),
+            )
+        ],
     )


 @pytest.fixture(scope="session")
 def memory_weaviate() -> ProviderFixture:
     return ProviderFixture(
-        provider=Provider(
-            provider_id="weaviate",
-            provider_type="remote::weaviate",
-            config=WeaviateConfig().model_dump(),
-        ),
+        providers=[
+            Provider(
+                provider_id="weaviate",
+                provider_type="remote::weaviate",
+                config=WeaviateConfig().model_dump(),
+            )
+        ],
         provider_data=dict(
             weaviate_api_key=get_env_or_fail("WEAVIATE_API_KEY"),
             weaviate_cluster_url=get_env_or_fail("WEAVIATE_CLUSTER_URL"),
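The remote memory fixtures lean on environment configuration; `get_env_or_fail` (imported at the top of the hunk) evidently fetches a variable or aborts the run. A plausible minimal implementation, for orientation only, not the repository's actual code:

import os


def get_env_or_fail(key: str) -> str:
    # Assumed behavior: return the variable's value, or fail loudly if unset.
    value = os.environ.get(key)
    if not value:
        raise ValueError(f"Environment variable {key} must be set")
    return value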
@@ -72,7 +78,7 @@ async def memory_stack(request):

     impls = await resolve_impls_for_test_v2(
         [Api.memory],
-        {"memory": [fixture.provider.model_dump()]},
+        {"memory": fixture.providers},
         fixture.provider_data,
     )

@@ -20,7 +20,7 @@ from llama_stack.distribution.resolver import resolve_impls

 async def resolve_impls_for_test_v2(
     apis: List[Api],
-    providers: Dict[str, Provider],
+    providers: Dict[str, List[Provider]],
     provider_data: Optional[Dict[str, Any]] = None,
 ):
     run_config = dict(
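The annotation now matches what the call sites actually pass: a list of providers per API name (the old `Dict[str, Provider]` was already inaccurate, since callers wrapped providers in lists). A type-only sketch with hypothetical stand-ins for `Api` and `Provider`:

from typing import Any, Dict, List, Optional

Api = str                  # stand-in for the real Api enum
Provider = Dict[str, Any]  # stand-in for the real Provider model


async def resolve_impls_for_test_v2(
    apis: List[Api],
    providers: Dict[str, List[Provider]],  # each API maps to a provider list
    provider_data: Optional[Dict[str, Any]] = None,
):
    ...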
@@ -81,7 +81,6 @@ def pytest_generate_tests(metafunc):
         )

     if "safety_stack" in metafunc.fixturenames:
-        # print(f"metafunc.fixturenames: {metafunc.fixturenames}, {metafunc}")
         available_fixtures = {
             "inference": INFERENCE_FIXTURES,
             "safety": SAFETY_FIXTURES,
@@ -22,32 +22,38 @@ from ..env import get_env_or_fail

 @pytest.fixture(scope="session")
 def safety_model(request):
-    return request.config.getoption("--safety-model", None) or request.param
+    if hasattr(request, "param"):
+        return request.param
+    return request.config.getoption("--safety-model", None)


 @pytest.fixture(scope="session")
 def safety_meta_reference(safety_model) -> ProviderFixture:
     return ProviderFixture(
-        provider=Provider(
-            provider_id="meta-reference",
-            provider_type="meta-reference",
-            config=SafetyConfig(
-                llama_guard_shield=LlamaGuardShieldConfig(
-                    model=safety_model,
-                ),
-            ).model_dump(),
-        ),
+        providers=[
+            Provider(
+                provider_id="meta-reference",
+                provider_type="meta-reference",
+                config=SafetyConfig(
+                    llama_guard_shield=LlamaGuardShieldConfig(
+                        model=safety_model,
+                    ),
+                ).model_dump(),
+            )
+        ],
     )


 @pytest.fixture(scope="session")
 def safety_together() -> ProviderFixture:
     return ProviderFixture(
-        provider=Provider(
-            provider_id="together",
-            provider_type="remote::together",
-            config=TogetherSafetyConfig().model_dump(),
-        ),
+        providers=[
+            Provider(
+                provider_id="together",
+                provider_type="remote::together",
+                config=TogetherSafetyConfig().model_dump(),
+            )
+        ],
         provider_data=dict(
             together_api_key=get_env_or_fail("TOGETHER_API_KEY"),
         ),
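Both `safety_model` and the matching `inference_model` fixture now give an explicit parametrization (`request.param`) precedence over the CLI option. The old expression did the opposite, and also raised `AttributeError` when the fixture was used without parametrization and without the flag. A sketch of the new behavior (the `FakeRequest` class is a test double, not a pytest API):

from types import SimpleNamespace


class FakeRequest:
    """Minimal test double for pytest's `request` fixture object."""

    def __init__(self, option=None, **kwargs):
        self.config = SimpleNamespace(
            getoption=lambda name, default=None: option or default
        )
        if "param" in kwargs:
            self.param = kwargs["param"]


def safety_model(request):
    # New logic: explicit parametrization wins over the CLI flag.
    if hasattr(request, "param"):
        return request.param
    return request.config.getoption("--safety-model", None)


assert safety_model(FakeRequest(option="Llama-Guard-3-1B")) == "Llama-Guard-3-1B"
assert safety_model(FakeRequest(option="cli", param="parametrized")) == "parametrized"
assert safety_model(FakeRequest()) is None  # the old `... or request.param` raised here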
@@ -67,8 +73,8 @@ async def safety_stack(inference_model, safety_model, request):
     safety_fixture = request.getfixturevalue(f"safety_{fixture_dict['safety']}")

     providers = {
-        "inference": [inference_fixture.provider],
-        "safety": [safety_fixture.provider],
+        "inference": inference_fixture.providers,
+        "safety": safety_fixture.providers,
     }
     provider_data = {}
     if inference_fixture.provider_data:
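With all the pieces in place, a meta-reference safety test can now wire the Llama-Guard model to its own inference provider alongside the agent model, which is the point of the commit. Roughly what the assembled mapping can look like (sample ids; plain dicts stand in for Provider objects):

providers = {
    "inference": [
        {"provider_id": "meta-reference-0", "provider_type": "meta-reference"},  # agent model
        {"provider_id": "meta-reference-1", "provider_type": "meta-reference"},  # Llama-Guard
    ],
    "safety": [
        {"provider_id": "meta-reference", "provider_type": "meta-reference"},
    ],
}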