address feedback

Dinesh Yeduguru 2024-12-11 16:24:37 -08:00
parent e167e9eb93
commit 5821ec9ef3
12 changed files with 61 additions and 76 deletions


@@ -84,24 +84,3 @@ def pytest_generate_tests(metafunc):
         ):
             fixtures = [stack.values[0]["inference"] for stack in filtered_stacks]
         metafunc.parametrize("inference_stack", fixtures, indirect=True)
-
-    if "embedding_model" in metafunc.fixturenames:
-        model = metafunc.config.getoption("--embedding-model")
-        if not model:
-            raise ValueError(
-                "No embedding model specified. Please provide a valid embedding model."
-            )
-        params = [pytest.param(model, id="")]
-        metafunc.parametrize("embedding_model", params, indirect=True)
-
-    if "embedding_stack" in metafunc.fixturenames:
-        fixtures = INFERENCE_FIXTURES
-        if filtered_stacks := get_provider_fixture_overrides(
-            metafunc.config,
-            {
-                "inference": INFERENCE_FIXTURES,
-            },
-        ):
-            fixtures = [stack.values[0]["inference"] for stack in filtered_stacks]
-        metafunc.parametrize("embedding_stack", fixtures, indirect=True)


@@ -37,13 +37,6 @@ def inference_model(request):
     return request.config.getoption("--inference-model", None)
-
-
-@pytest.fixture(scope="session")
-def embedding_model(request):
-    if hasattr(request, "param"):
-        return request.param
-    return request.config.getoption("--embedding-model", None)
 
 
 @pytest.fixture(scope="session")
 def inference_remote() -> ProviderFixture:
     return remote_stack_fixture()
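For reference, the surviving inference_model fixture now carries embedding model ids too; per the context lines above, and mirroring the shape of the removed embedding_model fixture, it reads roughly:

@pytest.fixture(scope="session")
def inference_model(request):
    # An indirect parameter (from pytest_generate_tests) takes precedence;
    # otherwise fall back to the --inference-model command-line option.
    if hasattr(request, "param"):
        return request.param
    return request.config.getoption("--inference-model", None)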
@@ -239,31 +232,21 @@ INFERENCE_FIXTURES = [
 async def inference_stack(request, inference_model):
     fixture_name = request.param
     inference_fixture = request.getfixturevalue(f"inference_{fixture_name}")
-    test_stack = await construct_stack_for_test(
-        [Api.inference],
-        {"inference": inference_fixture.providers},
-        inference_fixture.provider_data,
-        models=[ModelInput(model_id=inference_model)],
-    )
+    model_type = ModelType.llm
+    metadata = {}
+    if os.getenv("EMBEDDING_DIMENSION"):
+        model_type = ModelType.embedding_model
+        metadata["embedding_dimension"] = get_env_or_fail("EMBEDDING_DIMENSION")
 
-    return test_stack.impls[Api.inference], test_stack.impls[Api.models]
-
-
-@pytest_asyncio.fixture(scope="session")
-async def embedding_stack(request, embedding_model):
-    fixture_name = request.param
-    inference_fixture = request.getfixturevalue(f"inference_{fixture_name}")
     test_stack = await construct_stack_for_test(
         [Api.inference],
         {"inference": inference_fixture.providers},
         inference_fixture.provider_data,
         models=[
             ModelInput(
-                model_id=embedding_model,
-                model_type=ModelType.embedding_model,
-                metadata={
-                    "embedding_dimension": get_env_or_fail("EMBEDDING_DIMENSION"),
-                },
+                model_id=inference_model,
+                model_type=model_type,
+                metadata=metadata,
             )
         ],
     )
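Pieced together from the hunk above, the consolidated fixture ends up roughly as follows (the decorator and blank-line placement are inferred from the removed embedding_stack):

@pytest_asyncio.fixture(scope="session")
async def inference_stack(request, inference_model):
    fixture_name = request.param
    inference_fixture = request.getfixturevalue(f"inference_{fixture_name}")
    # Register the model as an LLM by default; exporting EMBEDDING_DIMENSION
    # flips it to an embedding model and records the dimension as metadata.
    model_type = ModelType.llm
    metadata = {}
    if os.getenv("EMBEDDING_DIMENSION"):
        model_type = ModelType.embedding_model
        metadata["embedding_dimension"] = get_env_or_fail("EMBEDDING_DIMENSION")

    test_stack = await construct_stack_for_test(
        [Api.inference],
        {"inference": inference_fixture.providers},
        inference_fixture.provider_data,
        models=[
            ModelInput(
                model_id=inference_model,
                model_type=model_type,
                metadata=metadata,
            )
        ],
    )

    return test_stack.impls[Api.inference], test_stack.impls[Api.models]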


@@ -14,15 +14,15 @@ from llama_stack.apis.inference import EmbeddingsResponse, ModelType
 class TestEmbeddings:
     @pytest.mark.asyncio
-    async def test_embeddings(self, embedding_model, embedding_stack):
-        inference_impl, models_impl = embedding_stack
-        model = await models_impl.get_model(embedding_model)
+    async def test_embeddings(self, inference_model, inference_stack):
+        inference_impl, models_impl = inference_stack
+        model = await models_impl.get_model(inference_model)
         if model.model_type != ModelType.embedding_model:
             pytest.skip("This test is only applicable for embedding models")
 
         response = await inference_impl.embeddings(
-            model_id=embedding_model,
+            model_id=inference_model,
             contents=["Hello, world!"],
         )
         assert isinstance(response, EmbeddingsResponse)
@@ -35,9 +35,9 @@ class TestEmbeddings:
         )
 
     @pytest.mark.asyncio
-    async def test_batch_embeddings(self, embedding_model, embedding_stack):
-        inference_impl, models_impl = embedding_stack
-        model = await models_impl.get_model(embedding_model)
+    async def test_batch_embeddings(self, inference_model, inference_stack):
+        inference_impl, models_impl = inference_stack
+        model = await models_impl.get_model(inference_model)
         if model.model_type != ModelType.embedding_model:
             pytest.skip("This test is only applicable for embedding models")
@@ -45,7 +45,7 @@ class TestEmbeddings:
 
         texts = ["Hello, world!", "This is a test", "Testing embeddings"]
         response = await inference_impl.embeddings(
-            model_id=embedding_model,
+            model_id=inference_model,
             contents=texts,
         )
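With embedding_stack retired, the embedding tests run through the standard inference options: pass the embedding model via --inference-model and export EMBEDDING_DIMENSION so the fixture registers it with ModelType.embedding_model; tests that receive an LLM instead skip via the model_type check above. A hypothetical invocation (the model id and test path are illustrative, not taken from this commit):

EMBEDDING_DIMENSION=384 pytest llama_stack/providers/tests/inference/test_embeddings.py --inference-model=all-MiniLM-L6-v2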