Fix agents tests

This commit is contained in:
Ashwin Bharambe 2024-11-11 22:31:09 -08:00
parent 38257a9cbe
commit abe1cc6303
7 changed files with 51 additions and 24 deletions

View file

@ -9,6 +9,7 @@ import tempfile
import pytest
import pytest_asyncio
from llama_stack.apis.models import Model
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.inline.agents.meta_reference import (
@ -17,8 +18,18 @@ from llama_stack.providers.inline.agents.meta_reference import (
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from ..conftest import ProviderFixture, remote_stack_fixture
from ..safety.fixtures import get_shield_to_register
def pick_inference_model(inference_model):
# This is not entirely satisfactory. The fixture `inference_model` can correspond to
# multiple models when you need to run a safety model in addition to normal agent
# inference model. We filter off the safety model by looking for "Llama-Guard"
if isinstance(inference_model, list):
inference_model = next(m for m in inference_model if "Llama-Guard" not in m)
assert inference_model is not None
return inference_model
@pytest.fixture(scope="session")
@ -49,7 +60,7 @@ AGENTS_FIXTURES = ["meta_reference", "remote"]
@pytest_asyncio.fixture(scope="session")
async def agents_stack(request):
async def agents_stack(request, inference_model, safety_model):
fixture_dict = request.param
providers = {}
@ -60,9 +71,28 @@ async def agents_stack(request):
if fixture.provider_data:
provider_data.update(fixture.provider_data)
inf_provider_id = providers["inference"][0].provider_id
safety_provider_id = providers["safety"][0].provider_id
shield = get_shield_to_register(
providers["safety"][0].provider_type, safety_provider_id, safety_model
)
inference_models = (
inference_model if isinstance(inference_model, list) else [inference_model]
)
impls = await resolve_impls_for_test_v2(
[Api.agents, Api.inference, Api.safety, Api.memory],
providers,
provider_data,
models=[
Model(
identifier=model,
provider_id=inf_provider_id,
provider_resource_id=model,
)
for model in inference_models
],
shields=[shield],
)
return impls[Api.agents], impls[Api.memory]