llama-stack/llama_stack/providers/tests/inference/fixtures.py
Ashwin Bharambe d9d271a684
Allow specifying resources in StackRunConfig (#425)
# What does this PR do? 

This PR brings back the facility to not force registration of resources
onto the user. This is not just annoying but actually not feasible
sometimes. For example, you may have a Stack which boots up with private
providers for inference for models A and B. There is currently no way for
the user to know which models these providers serve (and therefore no way
to register them).

How does this spare users from having to do registration? In a follow-up
diff, I will update the sample run.yaml files so they explicitly list the
models served by each distribution. Then, when users run
`llama stack build --template <...>` and start the stack, their
distributions come up with the set of models they expect.

For self-hosted distributions, it also gives us a place to explicitly
list the models that need to be served to make up the "complete" stack
(including safety, for example).

## Test Plan

Started ollama locally with two lightweight models: Llama3.2-3B-Instruct
and Llama-Guard-3-1B.

Updated all the tests, including agents. Here are the tests I ran so far:

```bash
pytest -s -v -m "fireworks and llama_3b" test_text_inference.py::TestInference \
  --env FIREWORKS_API_KEY=...

pytest -s -v -m "ollama and llama_3b" test_text_inference.py::TestInference 

pytest -s -v -m ollama test_safety.py

pytest -s -v -m faiss test_memory.py

pytest -s -v -m ollama  test_agents.py \
  --inference-model=Llama3.2-3B-Instruct --safety-model=Llama-Guard-3-1B
```

These test runs surfaced a few pre-existing bugs here and there, which are fixed in this PR as well.
2024-11-12 10:58:49 -08:00

173 lines
5 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import pytest
import pytest_asyncio
from llama_stack.apis.models import Model
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.inline.inference.meta_reference import (
MetaReferenceInferenceConfig,
)
from llama_stack.providers.remote.inference.bedrock import BedrockConfig
from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig
from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
from llama_stack.providers.remote.inference.together import TogetherImplConfig
from llama_stack.providers.remote.inference.vllm import VLLMInferenceAdapterConfig
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
from ..conftest import ProviderFixture, remote_stack_fixture
from ..env import get_env_or_fail
@pytest.fixture(scope="session")
def inference_model(request):
    """Return the model name(s) the inference tests should target.

    A value supplied through parametrization (``request.param``) takes
    precedence; otherwise the ``--inference-model`` command-line option is
    read, with ``None`` passed to ``getoption`` as the fallback default.
    """
    if not hasattr(request, "param"):
        return request.config.getoption("--inference-model", None)
    return request.param
@pytest.fixture(scope="session")
def inference_remote() -> ProviderFixture:
    """Provider fixture for the ``remote`` entry of ``INFERENCE_FIXTURES``.

    Delegates entirely to the shared ``remote_stack_fixture`` helper.
    """
    remote_fixture = remote_stack_fixture()
    return remote_fixture
@pytest.fixture(scope="session")
def inference_meta_reference(inference_model) -> ProviderFixture:
    """Build one inline ``meta-reference`` provider per requested model.

    ``inference_model`` may be a single model name or an iterable of names.
    The provider serving the ``i``-th model gets the id
    ``meta-reference-{i}``. Each config uses ``max_seq_len=4096``, disables
    the distributed process group, and reads its checkpoint directory from
    ``$MODEL_CHECKPOINT_DIR`` (``None`` when unset).
    """
    model_names = inference_model
    if isinstance(model_names, str):
        model_names = [model_names]

    def _provider_for(index, model_name):
        # Each provider is configured for exactly one model.
        config = MetaReferenceInferenceConfig(
            model=model_name,
            max_seq_len=4096,
            create_distributed_process_group=False,
            checkpoint_dir=os.getenv("MODEL_CHECKPOINT_DIR"),
        )
        return Provider(
            provider_id=f"meta-reference-{index}",
            provider_type="meta-reference",
            config=config.model_dump(),
        )

    return ProviderFixture(
        providers=[
            _provider_for(index, name) for index, name in enumerate(model_names)
        ]
    )
@pytest.fixture(scope="session")
def inference_ollama(inference_model) -> ProviderFixture:
    """Single ``remote::ollama`` provider on ``localhost``.

    The port is read from ``$OLLAMA_PORT`` (default 11434) and converted to
    ``int`` explicitly, so the config receives the same type whether or not
    the variable is set. The test is skipped when ``Llama3.1-8B-Instruct``
    is among the requested models.

    ``inference_model`` may be a single name, an iterable of names, or
    ``None``.
    """
    if isinstance(inference_model, str):
        requested = [inference_model]
    else:
        # Tolerate ``None`` (no model requested): the membership test below
        # would otherwise raise TypeError.
        requested = inference_model or []

    if "Llama3.1-8B-Instruct" in requested:
        pytest.skip("Ollama only supports Llama3.2-3B-Instruct for testing")

    return ProviderFixture(
        providers=[
            Provider(
                provider_id="ollama",
                provider_type="remote::ollama",
                config=OllamaImplConfig(
                    host="localhost",
                    # os.getenv returns a str when set; normalize to int so
                    # both code paths hand the config the same type.
                    port=int(os.getenv("OLLAMA_PORT", "11434")),
                ).model_dump(),
            )
        ],
    )
@pytest.fixture(scope="session")
def inference_vllm_remote() -> ProviderFixture:
    """Single ``remote::vllm`` provider pointed at ``$VLLM_URL``.

    The URL is read through ``get_env_or_fail``, which presumably aborts the
    test when the variable is missing.
    """
    vllm_config = VLLMInferenceAdapterConfig(url=get_env_or_fail("VLLM_URL"))
    vllm_provider = Provider(
        provider_id="remote::vllm",
        provider_type="remote::vllm",
        config=vllm_config.model_dump(),
    )
    return ProviderFixture(providers=[vllm_provider])
@pytest.fixture(scope="session")
def inference_fireworks() -> ProviderFixture:
    """Single ``remote::fireworks`` provider using ``$FIREWORKS_API_KEY``.

    The key is embedded directly in the provider config.
    """
    api_key = get_env_or_fail("FIREWORKS_API_KEY")
    fireworks_provider = Provider(
        provider_id="fireworks",
        provider_type="remote::fireworks",
        config=FireworksImplConfig(api_key=api_key).model_dump(),
    )
    return ProviderFixture(providers=[fireworks_provider])
@pytest.fixture(scope="session")
def inference_together() -> ProviderFixture:
    """Single ``remote::together`` provider with a default config.

    Unlike the Fireworks fixture, the API key (``$TOGETHER_API_KEY``) is not
    placed in the provider config. It is passed through ``provider_data``
    instead.
    """
    together_provider = Provider(
        provider_id="together",
        provider_type="remote::together",
        config=TogetherImplConfig().model_dump(),
    )
    api_key = get_env_or_fail("TOGETHER_API_KEY")
    return ProviderFixture(
        providers=[together_provider],
        provider_data={"together_api_key": api_key},
    )
@pytest.fixture(scope="session")
def inference_bedrock() -> ProviderFixture:
    """Single ``remote::bedrock`` provider built from ``BedrockConfig`` defaults."""
    bedrock_config = BedrockConfig()
    bedrock_provider = Provider(
        provider_id="bedrock",
        provider_type="remote::bedrock",
        config=bedrock_config.model_dump(),
    )
    return ProviderFixture(providers=[bedrock_provider])
# Suffixes of the ``inference_<name>`` provider fixtures in this module.
# ``inference_stack`` resolves ``inference_{request.param}`` from this list.
# The order is kept stable because it determines parametrization order.
INFERENCE_FIXTURES = """
    meta_reference
    ollama
    fireworks
    together
    vllm_remote
    remote
    bedrock
""".split()
@pytest_asyncio.fixture(scope="session")
async def inference_stack(request, inference_model):
    """Resolve an inference stack for the provider fixture named by ``request.param``.

    ``request.param`` names one of the ``inference_<name>`` fixtures. That
    fixture supplies the providers and the provider data. Each requested
    model is registered up front as a ``Model`` resource.

    ``inference_model`` may be a single model name, a list of names (the
    same shape ``inference_meta_reference`` accepts, which creates one
    provider per model), or ``None``. Models are bound to providers as
    follows:

    * When the fixture yields exactly one provider per model, model ``i``
      is bound to provider ``i``.
    * Otherwise every model is bound to the first provider.
    * When no model is given, nothing is pre-registered.

    Returns:
        A ``(inference_impl, models_impl)`` tuple taken from the resolved impls.
    """
    fixture_name = request.param
    inference_fixture = request.getfixturevalue(f"inference_{fixture_name}")
    providers = inference_fixture.providers

    # Normalize to a list of ids. Passing a list or None straight into
    # Model(identifier=...) would produce an invalid identifier.
    if isinstance(inference_model, str):
        model_ids = [inference_model]
    else:
        model_ids = list(inference_model or [])

    if len(providers) == len(model_ids):
        # One provider per model (e.g. meta-reference-{i} serves model i).
        provider_ids = [provider.provider_id for provider in providers]
    else:
        provider_ids = [providers[0].provider_id] * len(model_ids)

    models = [
        Model(
            identifier=model_id,
            provider_resource_id=model_id,
            provider_id=provider_id,
        )
        for model_id, provider_id in zip(model_ids, provider_ids)
    ]

    impls = await resolve_impls_for_test_v2(
        [Api.inference],
        {"inference": providers},
        inference_fixture.provider_data,
        models=models,
    )
    return (impls[Api.inference], impls[Api.models])