Allow specifying resources in StackRunConfig (#425)

# What does this PR do? 

This PR brings back the ability to avoid forcing registration of resources
onto the user. Forced registration is not just annoying but sometimes
infeasible. For example, you may have a Stack which boots up with private
inference providers for models A and B. The user has no way of knowing
which model these providers are actually serving (and thus no way to
register it.)

How will this spare users from doing registration? In a follow-up diff, I
will update the sample run.yaml files so they explicitly list the models
served by each distribution. When users run
`llama stack build --template <...>` and start the result, their
distributions come up with the set of models they expect.

For self-hosted distributions, it also gives us a place to explicitly list
the models that must be served to make the stack "complete" (including
safety models, e.g.); see the sketch below.
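
As a rough sketch of what this enables, here is how a run config could pre-declare its resources. The field names follow the `StackRunConfig` and resource types in this diff; the provider IDs, `image_name`, and the empty `providers` dict are placeholders, not a real distribution:

```python
from datetime import datetime

from llama_stack.apis.models import Model
from llama_stack.apis.shields import Shield, ShieldType
from llama_stack.distribution.datatypes import Api, StackRunConfig

# Placeholder values for illustration; a real run config also carries the
# per-API provider configuration, elided here.
run_config = StackRunConfig(
    built_at=datetime.now(),
    image_name="my-local-stack",
    apis=[Api.inference, Api.safety],
    providers={},
    models=[
        Model(
            identifier="Llama3.2-3B-Instruct",
            provider_id="ollama",  # hypothetical provider ID
            provider_resource_id="Llama3.2-3B-Instruct",
        )
    ],
    shields=[
        Shield(
            identifier="llama_guard",
            shield_type=ShieldType.llama_guard,
            params={},
            provider_id="ollama",  # hypothetical provider ID
            provider_resource_id="llama_guard",
        )
    ],
)
```

The follow-up run.yaml changes would carry the same lists in YAML form, so that `construct_stack` (added below) can register them on boot.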

## Test Plan

Started ollama locally with two lightweight models: Llama3.2-3B-Instruct
and Llama-Guard-3-1B.

Updated all the tests, including agents. Here are the tests I have run so far:

```bash
pytest -s -v -m "fireworks and llama_3b" test_text_inference.py::TestInference \
  --env FIREWORKS_API_KEY=...

pytest -s -v -m "ollama and llama_3b" test_text_inference.py::TestInference 

pytest -s -v -m ollama test_safety.py

pytest -s -v -m faiss test_memory.py

pytest -s -v -m ollama test_agents.py \
  --inference-model=Llama3.2-3B-Instruct --safety-model=Llama-Guard-3-1B
```

These test runs also surfaced a few pre-existing bugs, which are fixed in this commit.
Commit d9d271a684 (parent 8035fa1869) by Ashwin Bharambe, 2024-11-12 10:58:49 -08:00, committed by GitHub.
15 changed files with 221 additions and 124 deletions


@@ -31,48 +31,7 @@ from .strong_typing.schema import json_schema_type
schema_utils.json_schema_type = json_schema_type
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.agents import * # noqa: F403
from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403
from llama_stack.apis.scoring import * # noqa: F403
from llama_stack.apis.scoring_functions import * # noqa: F403
from llama_stack.apis.eval import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.batch_inference import * # noqa: F403
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.apis.telemetry import * # noqa: F403
from llama_stack.apis.post_training import * # noqa: F403
from llama_stack.apis.synthetic_data_generation import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403
from llama_stack.apis.models import * # noqa: F403
from llama_stack.apis.memory_banks import * # noqa: F403
from llama_stack.apis.shields import * # noqa: F403
from llama_stack.apis.inspect import * # noqa: F403
from llama_stack.apis.eval_tasks import * # noqa: F403
class LlamaStack(
MemoryBanks,
Inference,
BatchInference,
Agents,
Safety,
SyntheticDataGeneration,
Datasets,
Telemetry,
PostTraining,
Memory,
Eval,
EvalTasks,
Scoring,
ScoringFunctions,
DatasetIO,
Models,
Shields,
Inspect,
):
pass
from llama_stack.distribution.stack import LlamaStack
# TODO: this should be fixed in the generator itself so it reads appropriate annotations


@@ -22,6 +22,9 @@ class ResourceType(Enum):
class Resource(BaseModel):
"""Base class for all Llama Stack resources"""
# TODO: I think we need to move these into the child classes
# and make them `model_id`, `shield_id`, etc. because otherwise
# the config file has these confusing generic names in there
identifier: str = Field(
description="Unique identifier for this resource in llama stack"
)


@@ -151,6 +151,14 @@ Configuration for the persistence store used by the distribution registry. If no
a default SQLite store will be used.""",
)
# registry of "resources" in the distribution
models: List[Model] = Field(default_factory=list)
shields: List[Shield] = Field(default_factory=list)
memory_banks: List[MemoryBank] = Field(default_factory=list)
datasets: List[Dataset] = Field(default_factory=list)
scoring_fns: List[ScoringFn] = Field(default_factory=list)
eval_tasks: List[EvalTask] = Field(default_factory=list)
class BuildConfig(BaseModel):
version: str = LLAMA_STACK_BUILD_CONFIG_VERSION


@@ -27,12 +27,7 @@ from pydantic import BaseModel, ValidationError
from termcolor import cprint
from typing_extensions import Annotated
from llama_stack.distribution.distribution import (
builtin_automatically_routed_apis,
get_provider_registry,
)
from llama_stack.distribution.store.registry import create_dist_registry
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.providers.utils.telemetry.tracing import (
end_trace,
@@ -42,14 +37,15 @@ from llama_stack.providers.utils.telemetry.tracing import (
)
from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.resolver import InvalidProviderError, resolve_impls
from llama_stack.distribution.resolver import InvalidProviderError
from llama_stack.distribution.stack import construct_stack
from .endpoints import get_all_api_endpoints
def create_sse_event(data: Any) -> str:
if isinstance(data, BaseModel):
data = data.json()
data = data.model_dump_json()
else:
data = json.dumps(data)
@@ -281,12 +277,8 @@ def main(
app = FastAPI()
dist_registry, dist_kvstore = asyncio.run(create_dist_registry(config))
try:
impls = asyncio.run(
resolve_impls(config, get_provider_registry(), dist_registry)
)
impls = asyncio.run(construct_stack(config))
except InvalidProviderError:
sys.exit(1)


@@ -0,0 +1,100 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from termcolor import colored
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.agents import * # noqa: F403
from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403
from llama_stack.apis.scoring import * # noqa: F403
from llama_stack.apis.scoring_functions import * # noqa: F403
from llama_stack.apis.eval import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.batch_inference import * # noqa: F403
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.apis.telemetry import * # noqa: F403
from llama_stack.apis.post_training import * # noqa: F403
from llama_stack.apis.synthetic_data_generation import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403
from llama_stack.apis.models import * # noqa: F403
from llama_stack.apis.memory_banks import * # noqa: F403
from llama_stack.apis.shields import * # noqa: F403
from llama_stack.apis.inspect import * # noqa: F403
from llama_stack.apis.eval_tasks import * # noqa: F403
from llama_stack.distribution.datatypes import StackRunConfig
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.resolver import resolve_impls
from llama_stack.distribution.store.registry import create_dist_registry
from llama_stack.providers.datatypes import Api
class LlamaStack(
MemoryBanks,
Inference,
BatchInference,
Agents,
Safety,
SyntheticDataGeneration,
Datasets,
Telemetry,
PostTraining,
Memory,
Eval,
EvalTasks,
Scoring,
ScoringFunctions,
DatasetIO,
Models,
Shields,
Inspect,
):
pass
# Produces a stack of providers for the given run config. Not all APIs may be
# asked for in the run config.
async def construct_stack(run_config: StackRunConfig) -> Dict[Api, Any]:
dist_registry, _ = await create_dist_registry(
run_config.metadata_store, run_config.image_name
)
impls = await resolve_impls(run_config, get_provider_registry(), dist_registry)
objects = [
*run_config.models,
*run_config.shields,
*run_config.memory_banks,
*run_config.datasets,
*run_config.scoring_fns,
*run_config.eval_tasks,
]
for obj in objects:
await dist_registry.register(obj)
resources = [
("models", Api.models),
("shields", Api.shields),
("memory_banks", Api.memory_banks),
("datasets", Api.datasets),
("scoring_fns", Api.scoring_functions),
("eval_tasks", Api.eval_tasks),
]
for rsrc, api in resources:
if api not in impls:
continue
method = getattr(impls[api], f"list_{api.value}")
for obj in await method():
print(
f"{rsrc.capitalize()}: {colored(obj.identifier, 'white', attrs=['bold'])} served by {colored(obj.provider_id, 'white', attrs=['bold'])}",
)
print("")
return impls
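
The registration pass above leans on a naming convention: the implementation behind each resource API exposes a `list_<api>` method (`list_models`, `list_shields`, and so on), which `construct_stack` looks up with `getattr`. A minimal standalone sketch of that dispatch pattern, using toy stand-ins rather than the real impls:

```python
import asyncio

# Toy stand-ins for the real API implementations.
class ModelsImpl:
    async def list_models(self):
        return ["Llama3.2-3B-Instruct"]

class ShieldsImpl:
    async def list_shields(self):
        return ["llama_guard"]

impls = {"models": ModelsImpl(), "shields": ShieldsImpl()}

async def show_registered() -> None:
    for api, impl in impls.items():
        # Same convention as construct_stack: look up list_<api> by name.
        method = getattr(impl, f"list_{api}")
        for identifier in await method():
            print(f"{api.capitalize()}: {identifier}")

asyncio.run(show_registered())
```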


@@ -5,14 +5,11 @@
# the root directory of this source tree.
import json
from typing import Dict, List, Protocol
from typing import Dict, List, Optional, Protocol
import pydantic
from llama_stack.distribution.datatypes import (
RoutableObjectWithProvider,
StackRunConfig,
)
from llama_stack.distribution.datatypes import KVStoreConfig, RoutableObjectWithProvider
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.providers.utils.kvstore import (
@@ -144,17 +141,16 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry):
async def create_dist_registry(
config: StackRunConfig,
metadata_store: Optional[KVStoreConfig],
image_name: str,
) -> tuple[CachedDiskDistributionRegistry, KVStore]:
# instantiate kvstore for storing and retrieving distribution metadata
if config.metadata_store:
dist_kvstore = await kvstore_impl(config.metadata_store)
if metadata_store:
dist_kvstore = await kvstore_impl(metadata_store)
else:
dist_kvstore = await kvstore_impl(
SqliteKVStoreConfig(
db_path=(
DISTRIBS_BASE_DIR / config.image_name / "kvstore.db"
).as_posix()
db_path=(DISTRIBS_BASE_DIR / image_name / "kvstore.db").as_posix()
)
)
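
For clarity, a minimal sketch of the new call shape (the old signature took the whole `StackRunConfig`; the image name below is a placeholder):

```python
import asyncio

from llama_stack.distribution.store.registry import create_dist_registry

async def main() -> None:
    # Passing metadata_store=None falls back to a SQLite kvstore under
    # DISTRIBS_BASE_DIR/<image_name>/kvstore.db, per the branch above.
    dist_registry, dist_kvstore = await create_dist_registry(
        metadata_store=None,
        image_name="my-local-stack",  # placeholder
    )

asyncio.run(main())
```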


@@ -641,12 +641,13 @@ class ChatAgent(ShieldRunnerMixin):
if session_info.memory_bank_id is None:
bank_id = f"memory_bank_{session_id}"
memory_bank = VectorMemoryBank(
identifier=bank_id,
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
await self.memory_banks_api.register_memory_bank(
memory_bank_id=bank_id,
params=VectorMemoryBankParams(
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
),
)
await self.memory_banks_api.register_memory_bank(memory_bank)
await self.storage.add_memory_bank_to_session(session_id, bank_id)
else:
bank_id = session_info.memory_bank_id


@@ -19,7 +19,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
{
"inference": "meta_reference",
"safety": "llama_guard",
"memory": "meta_reference",
"memory": "faiss",
"agents": "meta_reference",
},
id="meta_reference",
@@ -29,7 +29,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
{
"inference": "ollama",
"safety": "llama_guard",
"memory": "meta_reference",
"memory": "faiss",
"agents": "meta_reference",
},
id="ollama",
@@ -40,7 +40,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
"inference": "together",
"safety": "llama_guard",
# make this work with Weaviate which is what the together distro supports
"memory": "meta_reference",
"memory": "faiss",
"agents": "meta_reference",
},
id="together",


@@ -9,6 +9,7 @@ import tempfile
import pytest
import pytest_asyncio
from llama_stack.apis.models import Model
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.inline.agents.meta_reference import (
@@ -17,8 +18,18 @@ from llama_stack.providers.inline.agents.meta_reference import (
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from ..conftest import ProviderFixture, remote_stack_fixture
from ..safety.fixtures import get_shield_to_register
def pick_inference_model(inference_model):
# This is not entirely satisfactory. The fixture `inference_model` can correspond to
# multiple models when you need to run a safety model in addition to normal agent
# inference model. We filter off the safety model by looking for "Llama-Guard"
if isinstance(inference_model, list):
inference_model = next(m for m in inference_model if "Llama-Guard" not in m)
assert inference_model is not None
return inference_model
@pytest.fixture(scope="session")
@@ -49,7 +60,7 @@ AGENTS_FIXTURES = ["meta_reference", "remote"]
@pytest_asyncio.fixture(scope="session")
async def agents_stack(request):
async def agents_stack(request, inference_model, safety_model):
fixture_dict = request.param
providers = {}
@@ -60,9 +71,28 @@ async def agents_stack(request):
if fixture.provider_data:
provider_data.update(fixture.provider_data)
inf_provider_id = providers["inference"][0].provider_id
safety_provider_id = providers["safety"][0].provider_id
shield = get_shield_to_register(
providers["safety"][0].provider_type, safety_provider_id, safety_model
)
inference_models = (
inference_model if isinstance(inference_model, list) else [inference_model]
)
impls = await resolve_impls_for_test_v2(
[Api.agents, Api.inference, Api.safety, Api.memory],
providers,
provider_data,
models=[
Model(
identifier=model,
provider_id=inf_provider_id,
provider_resource_id=model,
)
for model in inference_models
],
shields=[shield],
)
return impls[Api.agents], impls[Api.memory]


@@ -16,15 +16,12 @@ from llama_stack.providers.datatypes import * # noqa: F403
# pytest -v -s llama_stack/providers/tests/agents/test_agents.py
# -m "meta_reference"
from .fixtures import pick_inference_model
@pytest.fixture
def common_params(inference_model):
# This is not entirely satisfactory. The fixture `inference_model` can correspond to
# multiple models when you need to run a safety model in addition to normal agent
# inference model. We filter off the safety model by looking for "Llama-Guard"
if isinstance(inference_model, list):
inference_model = next(m for m in inference_model if "Llama-Guard" not in m)
assert inference_model is not None
inference_model = pick_inference_model(inference_model)
return dict(
model=inference_model,


@@ -9,6 +9,8 @@ import os
import pytest
import pytest_asyncio
from llama_stack.apis.models import Model
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.inline.inference.meta_reference import (
MetaReferenceInferenceConfig,
@@ -159,13 +161,13 @@ async def inference_stack(request, inference_model):
[Api.inference],
{"inference": inference_fixture.providers},
inference_fixture.provider_data,
)
provider_id = inference_fixture.providers[0].provider_id
print(f"Registering model {inference_model} with provider {provider_id}")
await impls[Api.models].register_model(
model_id=inference_model,
provider_id=provider_id,
models=[
Model(
identifier=inference_model,
provider_resource_id=inference_model,
provider_id=inference_fixture.providers[0].provider_id,
)
],
)
return (impls[Api.inference], impls[Api.models])


@@ -26,13 +26,13 @@ def memory_remote() -> ProviderFixture:
@pytest.fixture(scope="session")
def memory_meta_reference() -> ProviderFixture:
def memory_faiss() -> ProviderFixture:
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
return ProviderFixture(
providers=[
Provider(
provider_id="meta-reference",
provider_type="meta-reference",
provider_id="faiss",
provider_type="inline::faiss",
config=FaissImplConfig(
kvstore=SqliteKVStoreConfig(db_path=temp_file.name).model_dump(),
).model_dump(),
@@ -93,7 +93,7 @@ def memory_chroma() -> ProviderFixture:
)
MEMORY_FIXTURES = ["meta_reference", "pgvector", "weaviate", "remote", "chroma"]
MEMORY_FIXTURES = ["faiss", "pgvector", "weaviate", "remote", "chroma"]
@pytest_asyncio.fixture(scope="session")


@@ -44,7 +44,6 @@ def sample_documents():
async def register_memory_bank(banks_impl: MemoryBanks):
return await banks_impl.register_memory_bank(
memory_bank_id="test_bank",
params=VectorMemoryBankParams(
@@ -71,7 +70,7 @@ class TestMemory:
# but so far we don't have an unregister API unfortunately, so be careful
_, banks_impl = memory_stack
bank = await banks_impl.register_memory_bank(
await banks_impl.register_memory_bank(
memory_bank_id="test_bank_no_provider",
params=VectorMemoryBankParams(
embedding_model="all-MiniLM-L6-v2",


@@ -17,29 +17,38 @@ from llama_stack.distribution.build import print_pip_install_help
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.resolver import resolve_impls
from llama_stack.distribution.store import CachedDiskDistributionRegistry
from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig
from llama_stack.distribution.stack import construct_stack
from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig
async def resolve_impls_for_test_v2(
apis: List[Api],
providers: Dict[str, List[Provider]],
provider_data: Optional[Dict[str, Any]] = None,
models: Optional[List[Model]] = None,
shields: Optional[List[Shield]] = None,
memory_banks: Optional[List[MemoryBank]] = None,
datasets: Optional[List[Dataset]] = None,
scoring_fns: Optional[List[ScoringFn]] = None,
eval_tasks: Optional[List[EvalTask]] = None,
):
sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
run_config = dict(
built_at=datetime.now(),
image_name="test-fixture",
apis=apis,
providers=providers,
metadata_store=SqliteKVStoreConfig(db_path=sqlite_file.name),
models=models or [],
shields=shields or [],
memory_banks=memory_banks or [],
datasets=datasets or [],
scoring_fns=scoring_fns or [],
eval_tasks=eval_tasks or [],
)
run_config = parse_and_maybe_upgrade_config(run_config)
sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
dist_kvstore = await kvstore_impl(SqliteKVStoreConfig(db_path=sqlite_file.name))
dist_registry = CachedDiskDistributionRegistry(dist_kvstore)
try:
impls = await resolve_impls(run_config, get_provider_registry(), dist_registry)
impls = await construct_stack(run_config)
except ModuleNotFoundError as e:
print_pip_install_help(providers)
raise e


@@ -7,7 +7,9 @@
import pytest
import pytest_asyncio
from llama_stack.apis.shields import ShieldType
from llama_stack.apis.models import Model
from llama_stack.apis.shields import Shield, ShieldType
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.inline.safety.llama_guard import LlamaGuardConfig
@@ -96,32 +98,29 @@ async def safety_stack(inference_model, safety_model, request):
if safety_fixture.provider_data:
provider_data.update(safety_fixture.provider_data)
shield_provider_type = safety_fixture.providers[0].provider_type
shield = get_shield_to_register(
shield_provider_type, safety_fixture.providers[0].provider_id, safety_model
)
impls = await resolve_impls_for_test_v2(
[Api.safety, Api.shields, Api.inference],
providers,
provider_data,
models=[
Model(
identifier=inference_model,
provider_id=inference_fixture.providers[0].provider_id,
provider_resource_id=inference_model,
)
],
shields=[shield],
)
safety_impl = impls[Api.safety]
shields_impl = impls[Api.shields]
# Register the appropriate shield based on provider type
provider_type = safety_fixture.providers[0].provider_type
shield = await create_and_register_shield(provider_type, safety_model, shields_impl)
provider_id = inference_fixture.providers[0].provider_id
print(f"Registering model {inference_model} with provider {provider_id}")
await impls[Api.models].register_model(
model_id=inference_model,
provider_id=provider_id,
)
return safety_impl, shields_impl, shield
return impls[Api.safety], impls[Api.shields], shield
async def create_and_register_shield(
provider_type: str, safety_model: str, shields_impl
):
def get_shield_to_register(provider_type: str, provider_id: str, safety_model: str):
shield_config = {}
shield_type = ShieldType.llama_guard
identifier = "llama_guard"
@@ -134,8 +133,10 @@ async def create_and_register_shield(
shield_config["guardrailVersion"] = get_env_or_fail("BEDROCK_GUARDRAIL_VERSION")
shield_type = ShieldType.generic_content_shield
return await shields_impl.register_shield(
shield_id=identifier,
return Shield(
identifier=identifier,
shield_type=shield_type,
params=shield_config,
provider_id=provider_id,
provider_resource_id=identifier,
)