Allow specifying resources in StackRunConfig (#425)

# What does this PR do? 

This PR brings back the facility to not force registration of resources
onto the user. This is not just annoying but actually not feasible
sometimes. For example, you may have a Stack which boots up with private
providers for inference for models A and B. There is no way for the user
to actually know which model is being served by these providers now (to
be able to register it.)

How will this avoid the users needing to do registration? In a follow-up
diff, I will make sure I update the sample run.yaml files so they list
the models served by the distributions explicitly. So when users do
`llama stack build --template <...>` and run it, their distributions
come up with the right set of models they expect.

For self-hosted distributions, it also allows us to have a place to
explicit list the models that need to be served to make the "complete"
stack (including safety, e.g.)

## Test Plan

Started ollama locally with two lightweight models: Llama3.2-3B-Instruct
and Llama-Guard-3-1B.

Updated all the tests including agents. Here's the tests I ran so far:

```bash
pytest -s -v -m "fireworks and llama_3b" test_text_inference.py::TestInference \
  --env FIREWORKS_API_KEY=...

pytest -s -v -m "ollama and llama_3b" test_text_inference.py::TestInference 

pytest -s -v -m ollama test_safety.py

pytest -s -v -m faiss test_memory.py

pytest -s -v -m ollama  test_agents.py \
  --inference-model=Llama3.2-3B-Instruct --safety-model=Llama-Guard-3-1B
```

Found a few bugs here and there pre-existing that these test runs fixed.
This commit is contained in:
Ashwin Bharambe 2024-11-12 10:58:49 -08:00 committed by GitHub
parent 8035fa1869
commit d9d271a684
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 221 additions and 124 deletions

View file

@ -151,6 +151,14 @@ Configuration for the persistence store used by the distribution registry. If no
a default SQLite store will be used.""",
)
# registry of "resources" in the distribution
models: List[Model] = Field(default_factory=list)
shields: List[Shield] = Field(default_factory=list)
memory_banks: List[MemoryBank] = Field(default_factory=list)
datasets: List[Dataset] = Field(default_factory=list)
scoring_fns: List[ScoringFn] = Field(default_factory=list)
eval_tasks: List[EvalTask] = Field(default_factory=list)
class BuildConfig(BaseModel):
version: str = LLAMA_STACK_BUILD_CONFIG_VERSION

View file

@ -27,12 +27,7 @@ from pydantic import BaseModel, ValidationError
from termcolor import cprint
from typing_extensions import Annotated
from llama_stack.distribution.distribution import (
builtin_automatically_routed_apis,
get_provider_registry,
)
from llama_stack.distribution.store.registry import create_dist_registry
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.providers.utils.telemetry.tracing import (
end_trace,
@ -42,14 +37,15 @@ from llama_stack.providers.utils.telemetry.tracing import (
)
from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.resolver import InvalidProviderError, resolve_impls
from llama_stack.distribution.resolver import InvalidProviderError
from llama_stack.distribution.stack import construct_stack
from .endpoints import get_all_api_endpoints
def create_sse_event(data: Any) -> str:
if isinstance(data, BaseModel):
data = data.json()
data = data.model_dump_json()
else:
data = json.dumps(data)
@ -281,12 +277,8 @@ def main(
app = FastAPI()
dist_registry, dist_kvstore = asyncio.run(create_dist_registry(config))
try:
impls = asyncio.run(
resolve_impls(config, get_provider_registry(), dist_registry)
)
impls = asyncio.run(construct_stack(config))
except InvalidProviderError:
sys.exit(1)

View file

@ -0,0 +1,100 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from termcolor import colored
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.agents import * # noqa: F403
from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403
from llama_stack.apis.scoring import * # noqa: F403
from llama_stack.apis.scoring_functions import * # noqa: F403
from llama_stack.apis.eval import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.batch_inference import * # noqa: F403
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.apis.telemetry import * # noqa: F403
from llama_stack.apis.post_training import * # noqa: F403
from llama_stack.apis.synthetic_data_generation import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403
from llama_stack.apis.models import * # noqa: F403
from llama_stack.apis.memory_banks import * # noqa: F403
from llama_stack.apis.shields import * # noqa: F403
from llama_stack.apis.inspect import * # noqa: F403
from llama_stack.apis.eval_tasks import * # noqa: F403
from llama_stack.distribution.datatypes import StackRunConfig
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.resolver import resolve_impls
from llama_stack.distribution.store.registry import create_dist_registry
from llama_stack.providers.datatypes import Api
class LlamaStack(
MemoryBanks,
Inference,
BatchInference,
Agents,
Safety,
SyntheticDataGeneration,
Datasets,
Telemetry,
PostTraining,
Memory,
Eval,
EvalTasks,
Scoring,
ScoringFunctions,
DatasetIO,
Models,
Shields,
Inspect,
):
pass
# Produces a stack of providers for the given run config. Not all APIs may be
# asked for in the run config.
async def construct_stack(run_config: StackRunConfig) -> Dict[Api, Any]:
dist_registry, _ = await create_dist_registry(
run_config.metadata_store, run_config.image_name
)
impls = await resolve_impls(run_config, get_provider_registry(), dist_registry)
objects = [
*run_config.models,
*run_config.shields,
*run_config.memory_banks,
*run_config.datasets,
*run_config.scoring_fns,
*run_config.eval_tasks,
]
for obj in objects:
await dist_registry.register(obj)
resources = [
("models", Api.models),
("shields", Api.shields),
("memory_banks", Api.memory_banks),
("datasets", Api.datasets),
("scoring_fns", Api.scoring_functions),
("eval_tasks", Api.eval_tasks),
]
for rsrc, api in resources:
if api not in impls:
continue
method = getattr(impls[api], f"list_{api.value}")
for obj in await method():
print(
f"{rsrc.capitalize()}: {colored(obj.identifier, 'white', attrs=['bold'])} served by {colored(obj.provider_id, 'white', attrs=['bold'])}",
)
print("")
return impls

View file

@ -5,14 +5,11 @@
# the root directory of this source tree.
import json
from typing import Dict, List, Protocol
from typing import Dict, List, Optional, Protocol
import pydantic
from llama_stack.distribution.datatypes import (
RoutableObjectWithProvider,
StackRunConfig,
)
from llama_stack.distribution.datatypes import KVStoreConfig, RoutableObjectWithProvider
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.providers.utils.kvstore import (
@ -144,17 +141,16 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry):
async def create_dist_registry(
config: StackRunConfig,
metadata_store: Optional[KVStoreConfig],
image_name: str,
) -> tuple[CachedDiskDistributionRegistry, KVStore]:
# instantiate kvstore for storing and retrieving distribution metadata
if config.metadata_store:
dist_kvstore = await kvstore_impl(config.metadata_store)
if metadata_store:
dist_kvstore = await kvstore_impl(metadata_store)
else:
dist_kvstore = await kvstore_impl(
SqliteKVStoreConfig(
db_path=(
DISTRIBS_BASE_DIR / config.image_name / "kvstore.db"
).as_posix()
db_path=(DISTRIBS_BASE_DIR / image_name / "kvstore.db").as_posix()
)
)