Allow specifying resources in StackRunConfig (#425)

# What does this PR do? 

This PR brings back the ability to avoid forcing resource registration onto the user. Forcing registration is not just annoying but sometimes not even feasible. For example, you may have a Stack that boots up with private providers serving inference for models A and B. The user has no way of knowing which model is being served by these providers (and therefore cannot register it.)

How will this avoid users needing to do registration? In a follow-up diff, I will update the sample run.yaml files so they explicitly list the models served by each distribution. That way, when users do `llama stack build --template <...>` and run it, their distributions come up with the set of models they expect.

For self-hosted distributions, it also gives us a place to explicitly list the models that need to be served to make the stack "complete" (including safety models, for example).
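To make this concrete, a run.yaml that declares its resources up front might look roughly like the sketch below. The exact keys (`models`, `shields`, `model_id`, `shield_id`, `provider_id`) are illustrative assumptions based on this PR's intent, not the authoritative schema:

```yaml
# Hypothetical excerpt of a run.yaml after this change; field names are
# illustrative assumptions, not the final StackRunConfig schema.
providers:
  inference:
    - provider_id: ollama
      provider_type: remote::ollama
      config:
        url: http://localhost:11434

# Resources declared up front so users do not have to register them by hand
models:
  - model_id: Llama3.2-3B-Instruct
    provider_id: ollama
shields:
  - shield_id: Llama-Guard-3-1B
    provider_id: ollama
```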

## Test Plan

Started ollama locally with two lightweight models: Llama3.2-3B-Instruct
and Llama-Guard-3-1B.

Updated all the tests, including agents. Here are the tests I ran so far:

```bash
pytest -s -v -m "fireworks and llama_3b" test_text_inference.py::TestInference \
  --env FIREWORKS_API_KEY=...

pytest -s -v -m "ollama and llama_3b" test_text_inference.py::TestInference 

pytest -s -v -m ollama test_safety.py

pytest -s -v -m faiss test_memory.py

pytest -s -v -m ollama test_agents.py \
  --inference-model=Llama3.2-3B-Instruct --safety-model=Llama-Guard-3-1B
```

These test runs also surfaced a few pre-existing bugs here and there, which are fixed as part of this change.
Commit d9d271a684 (parent 8035fa1869), authored by Ashwin Bharambe on 2024-11-12 10:58:49 -08:00 and committed via GitHub.
15 changed files with 221 additions and 124 deletions


```diff
@@ -27,12 +27,7 @@ from pydantic import BaseModel, ValidationError
 from termcolor import cprint
 from typing_extensions import Annotated

-from llama_stack.distribution.distribution import (
-    builtin_automatically_routed_apis,
-    get_provider_registry,
-)
-from llama_stack.distribution.store.registry import create_dist_registry
+from llama_stack.distribution.distribution import builtin_automatically_routed_apis
 from llama_stack.providers.utils.telemetry.tracing import (
     end_trace,
@@ -42,14 +37,15 @@ from llama_stack.providers.utils.telemetry.tracing import (
 )
 from llama_stack.distribution.datatypes import *  # noqa: F403
 from llama_stack.distribution.request_headers import set_request_provider_data
-from llama_stack.distribution.resolver import InvalidProviderError, resolve_impls
+from llama_stack.distribution.resolver import InvalidProviderError
+from llama_stack.distribution.stack import construct_stack

 from .endpoints import get_all_api_endpoints


 def create_sse_event(data: Any) -> str:
     if isinstance(data, BaseModel):
-        data = data.json()
+        data = data.model_dump_json()
     else:
         data = json.dumps(data)
@@ -281,12 +277,8 @@ def main(
     app = FastAPI()

-    dist_registry, dist_kvstore = asyncio.run(create_dist_registry(config))
-
     try:
-        impls = asyncio.run(
-            resolve_impls(config, get_provider_registry(), dist_registry)
-        )
+        impls = asyncio.run(construct_stack(config))
     except InvalidProviderError:
         sys.exit(1)
```
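
To make the control-flow change easier to follow, here is a rough sketch of what the new `construct_stack()` entry point plausibly wraps, inferred only from the calls this diff removes from the server. The `register_resources_from_config` helper is a hypothetical placeholder; the real implementation in `llama_stack.distribution.stack` may differ in details:

```python
# Sketch only: inferred from the registry/resolver calls removed from the
# server code above; not the authoritative implementation.
from typing import Any, Dict

from llama_stack.distribution.datatypes import StackRunConfig
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.resolver import resolve_impls
from llama_stack.distribution.store.registry import create_dist_registry


async def construct_stack(run_config: StackRunConfig) -> Dict[Any, Any]:
    # Registry/kvstore creation used to happen directly in the server's main()
    dist_registry, _dist_kvstore = await create_dist_registry(run_config)

    # Resolve provider implementations exactly as the old server code did
    impls = await resolve_impls(run_config, get_provider_registry(), dist_registry)

    # New responsibility (hypothetical helper, commented out): register the
    # models, shields, etc. listed in the run config so users no longer have
    # to register them by hand.
    # await register_resources_from_config(run_config, impls)

    return impls
```

The net effect is that the server no longer needs to know about registries or the provider registry at all; it hands the whole run config to `construct_stack()` and gets back the resolved implementations.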