Merge branch 'main' into feat/litellm_sambanova_usage

This commit is contained in:
Jorge Piedrahita Ortiz 2025-05-06 09:56:22 -05:00 committed by GitHub
commit 21125f725f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 301 additions and 182 deletions

26
.github/actions/setup-ollama/action.yml vendored Normal file
View file

@ -0,0 +1,26 @@
name: Setup Ollama
description: Install and start Ollama, then pull the requested models
inputs:
  models:
    description: Comma-separated list of models to pull
    default: "llama3.2:3b-instruct-fp16,all-minilm:latest"
runs:
  using: "composite"
  steps:
    - name: Install and start Ollama
      shell: bash
      run: |
        # the ollama installer also starts the ollama service
        curl -fsSL https://ollama.com/install.sh | sh

    # Do NOT cache models - pulling the cache is actually slower than just pulling the model.
    # It takes ~45 seconds to pull the models from the cache and unpack it, but only 30 seconds to
    # pull them directly.
    # Maybe this is because the cache is being pulled at the same time by all the matrix jobs?
    - name: Pull requested models
      if: inputs.models != ''
      shell: bash
      # Hand the input to the script through the environment instead of expanding
      # `${{ inputs.models }}` inside the script text: an expression expanded into the
      # script is substituted before bash parses it, so a crafted value could run
      # arbitrary commands. The variable is deliberately not named OLLAMA_MODELS,
      # which Ollama itself reads as its model storage directory.
      env:
        MODELS_TO_PULL: ${{ inputs.models }}
      run: |
        # Split the comma-separated list on commas (and any surrounding whitespace).
        for model in $(printf '%s\n' "$MODELS_TO_PULL" | tr ',' ' '); do
          ollama pull "$model"
        done

View file

@ -38,19 +38,8 @@ jobs:
python-version: "3.10"
activate-environment: true
- name: Install and start Ollama
run: |
# the ollama installer also starts the ollama service
curl -fsSL https://ollama.com/install.sh | sh
# Do NOT cache models - pulling the cache is actually slower than just pulling the model.
# It takes ~45 seconds to pull the models from the cache and unpack it, but only 30 seconds to
# pull them directly.
# Maybe this is because the cache is being pulled at the same time by all the matrix jobs?
- name: Pull Ollama models (instruct and embed)
run: |
ollama pull llama3.2:3b-instruct-fp16
ollama pull all-minilm:latest
- name: Setup ollama
uses: ./.github/actions/setup-ollama
- name: Set Up Environment and Install Dependencies
run: |

View file

@ -55,6 +55,7 @@ Here's a list of known external providers that you can use with Llama Stack:
| KubeFlow Training | Train models with KubeFlow | Post Training | Remote | [llama-stack-provider-kft](https://github.com/opendatahub-io/llama-stack-provider-kft) |
| KubeFlow Pipelines | Train models with KubeFlow Pipelines | Post Training | Remote | [llama-stack-provider-kfp-trainer](https://github.com/opendatahub-io/llama-stack-provider-kfp-trainer) |
| RamaLama | Inference models with RamaLama | Inference | Remote | [ramalama-stack](https://github.com/containers/ramalama-stack) |
| TrustyAI LM-Eval | Evaluate models with TrustyAI LM-Eval | Eval | Remote | [llama-stack-provider-lmeval](https://github.com/trustyai-explainability/llama-stack-provider-lmeval) |
### Remote Provider Specification

View file

@ -27,5 +27,81 @@ You can install Milvus using pymilvus:
```bash
pip install pymilvus
```
## Configuration
In Llama Stack, Milvus can be configured in two ways:
- **Inline (Local) Configuration** - Uses Milvus-Lite for local storage
- **Remote Configuration** - Connects to a remote Milvus server
### Inline (Local) Configuration
The simplest method is local configuration, which requires setting `db_path`, a path for locally storing Milvus-Lite files:
```yaml
vector_io:
- provider_id: milvus
provider_type: inline::milvus
config:
db_path: ~/.llama/distributions/together/milvus_store.db
```
### Remote Configuration
Remote configuration is suitable for larger data storage requirements:
#### Standard Remote Connection
```yaml
vector_io:
- provider_id: milvus
provider_type: remote::milvus
config:
uri: "http://<host>:<port>"
token: "<user>:<password>"
```
#### TLS-Enabled Remote Connection (One-way TLS)
For connections to Milvus instances with one-way TLS enabled:
```yaml
vector_io:
- provider_id: milvus
provider_type: remote::milvus
config:
uri: "https://<host>:<port>"
token: "<user>:<password>"
secure: True
server_pem_path: "/path/to/server.pem"
```
#### Mutual TLS (mTLS) Remote Connection
For connections to Milvus instances with mutual TLS (mTLS) enabled:
```yaml
vector_io:
- provider_id: milvus
provider_type: remote::milvus
config:
uri: "https://<host>:<port>"
token: "<user>:<password>"
secure: True
ca_pem_path: "/path/to/ca.pem"
client_pem_path: "/path/to/client.pem"
client_key_path: "/path/to/client.key"
```
#### Key Parameters for TLS Configuration
- **`secure`**: Enables TLS encryption when set to `true`. Defaults to `false`.
- **`server_pem_path`**: Path to the **server certificate** for verifying the server's identity (used in one-way TLS).
- **`ca_pem_path`**: Path to the **Certificate Authority (CA) certificate** for validating the server certificate (required for mTLS).
- **`client_pem_path`**: Path to the **client certificate** file (required for mTLS).
- **`client_key_path`**: Path to the **client private key** file (required for mTLS).
## Documentation
See the [Milvus documentation](https://milvus.io/docs/install-overview.md) for more details about Milvus in general.
For more details on TLS configuration, refer to the [TLS setup guide](https://milvus.io/docs/tls.md).

View file

@ -6,7 +6,7 @@
from typing import Any
from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict
from llama_stack.schema_utils import json_schema_type
@ -17,6 +17,8 @@ class MilvusVectorIOConfig(BaseModel):
token: str | None = None
consistency_level: str = "Strong"
model_config = ConfigDict(extra="allow")
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
return {"uri": "${env.MILVUS_ENDPOINT}", "token": "${env.MILVUS_TOKEN}"}

View file

@ -87,6 +87,7 @@ test = [
"mcp",
"datasets",
"autoevals",
"transformers",
]
docs = [
"sphinx-autobuild",

View file

@ -31,7 +31,6 @@ def data_url_from_file(file_path: str) -> str:
return data_url
@pytest.mark.skip(reason="flaky. Couldn't find 'llamastack/simpleqa' on the Hugging Face Hub")
@pytest.mark.parametrize(
"purpose, source, provider_id, limit",
[

5
tests/unit/__init__.py Normal file
View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

9
tests/unit/conftest.py Normal file
View file

@ -0,0 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# We need to import the fixtures here so that pytest can find them
# but ruff doesn't think they are used and removes the import. "noqa: F401" prevents them from being removed
from .fixtures import cached_disk_dist_registry, disk_dist_registry, sqlite_kvstore # noqa: F401

View file

@ -26,20 +26,6 @@ from llama_stack.distribution.routers.routing_tables import (
ToolGroupsRoutingTable,
VectorDBsRoutingTable,
)
from llama_stack.distribution.store.registry import CachedDiskDistributionRegistry
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from llama_stack.providers.utils.kvstore.sqlite import SqliteKVStoreImpl
@pytest.fixture
async def dist_registry(tmp_path):
db_path = tmp_path / "test_kv.db"
kvstore_config = SqliteKVStoreConfig(db_path=db_path.as_posix())
kvstore = SqliteKVStoreImpl(kvstore_config)
await kvstore.initialize()
registry = CachedDiskDistributionRegistry(kvstore)
await registry.initialize()
yield registry
class Impl:
@ -136,8 +122,8 @@ class ToolGroupsImpl(Impl):
@pytest.mark.asyncio
async def test_models_routing_table(dist_registry):
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, dist_registry)
async def test_models_routing_table(cached_disk_dist_registry):
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry)
await table.initialize()
# Register multiple models and verify listing
@ -178,8 +164,8 @@ async def test_models_routing_table(dist_registry):
@pytest.mark.asyncio
async def test_shields_routing_table(dist_registry):
table = ShieldsRoutingTable({"test_provider": SafetyImpl()}, dist_registry)
async def test_shields_routing_table(cached_disk_dist_registry):
table = ShieldsRoutingTable({"test_provider": SafetyImpl()}, cached_disk_dist_registry)
await table.initialize()
# Register multiple shields and verify listing
@ -194,11 +180,11 @@ async def test_shields_routing_table(dist_registry):
@pytest.mark.asyncio
async def test_vectordbs_routing_table(dist_registry):
table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, dist_registry)
async def test_vectordbs_routing_table(cached_disk_dist_registry):
table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry)
await table.initialize()
m_table = ModelsRoutingTable({"test_providere": InferenceImpl()}, dist_registry)
m_table = ModelsRoutingTable({"test_providere": InferenceImpl()}, cached_disk_dist_registry)
await m_table.initialize()
await m_table.register_model(
model_id="test-model",
@ -224,8 +210,8 @@ async def test_vectordbs_routing_table(dist_registry):
assert len(vector_dbs.data) == 0
async def test_datasets_routing_table(dist_registry):
table = DatasetsRoutingTable({"localfs": DatasetsImpl()}, dist_registry)
async def test_datasets_routing_table(cached_disk_dist_registry):
table = DatasetsRoutingTable({"localfs": DatasetsImpl()}, cached_disk_dist_registry)
await table.initialize()
# Register multiple datasets and verify listing
@ -250,8 +236,8 @@ async def test_datasets_routing_table(dist_registry):
@pytest.mark.asyncio
async def test_scoring_functions_routing_table(dist_registry):
table = ScoringFunctionsRoutingTable({"test_provider": ScoringFunctionsImpl()}, dist_registry)
async def test_scoring_functions_routing_table(cached_disk_dist_registry):
table = ScoringFunctionsRoutingTable({"test_provider": ScoringFunctionsImpl()}, cached_disk_dist_registry)
await table.initialize()
# Register multiple scoring functions and verify listing
@ -276,8 +262,8 @@ async def test_scoring_functions_routing_table(dist_registry):
@pytest.mark.asyncio
async def test_benchmarks_routing_table(dist_registry):
table = BenchmarksRoutingTable({"test_provider": BenchmarksImpl()}, dist_registry)
async def test_benchmarks_routing_table(cached_disk_dist_registry):
table = BenchmarksRoutingTable({"test_provider": BenchmarksImpl()}, cached_disk_dist_registry)
await table.initialize()
# Register multiple benchmarks and verify listing
@ -294,8 +280,8 @@ async def test_benchmarks_routing_table(dist_registry):
@pytest.mark.asyncio
async def test_tool_groups_routing_table(dist_registry):
table = ToolGroupsRoutingTable({"test_provider": ToolGroupsImpl()}, dist_registry)
async def test_tool_groups_routing_table(cached_disk_dist_registry):
table = ToolGroupsRoutingTable({"test_provider": ToolGroupsImpl()}, cached_disk_dist_registry)
await table.initialize()
# Register multiple tool groups and verify listing

34
tests/unit/fixtures.py Normal file
View file

@ -0,0 +1,34 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from llama_stack.distribution.store.registry import CachedDiskDistributionRegistry, DiskDistributionRegistry
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from llama_stack.providers.utils.kvstore.sqlite import SqliteKVStoreImpl
@pytest.fixture(scope="function")
async def sqlite_kvstore(tmp_path):
    """Yield an initialized SQLite KV store backed by ``test_kv.db`` under ``tmp_path``."""
    store_config = SqliteKVStoreConfig(db_path=(tmp_path / "test_kv.db").as_posix())
    store = SqliteKVStoreImpl(store_config)
    await store.initialize()
    yield store
@pytest.fixture(scope="function")
async def disk_dist_registry(sqlite_kvstore):
    """Yield an initialized ``DiskDistributionRegistry`` built on the ``sqlite_kvstore`` fixture."""
    dist_registry = DiskDistributionRegistry(sqlite_kvstore)
    await dist_registry.initialize()
    yield dist_registry
@pytest.fixture(scope="function")
async def cached_disk_dist_registry(sqlite_kvstore):
    """Yield an initialized ``CachedDiskDistributionRegistry`` built on the ``sqlite_kvstore`` fixture."""
    cached_dist_registry = CachedDiskDistributionRegistry(sqlite_kvstore)
    await cached_dist_registry.initialize()
    yield cached_dist_registry

View file

@ -4,9 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import shutil
import tempfile
import uuid
from datetime import datetime
from unittest.mock import patch
@ -17,20 +14,12 @@ from llama_stack.apis.agents import Turn
from llama_stack.apis.inference import CompletionMessage, StopReason
from llama_stack.distribution.datatypes import AccessAttributes
from llama_stack.providers.inline.agents.meta_reference.persistence import AgentPersistence, AgentSessionInfo
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from llama_stack.providers.utils.kvstore.sqlite import SqliteKVStoreImpl
@pytest.fixture
async def test_setup():
temp_dir = tempfile.mkdtemp()
db_path = os.path.join(temp_dir, "test_persistence_access_control.db")
kvstore_config = SqliteKVStoreConfig(db_path=db_path)
kvstore = SqliteKVStoreImpl(kvstore_config)
await kvstore.initialize()
agent_persistence = AgentPersistence(agent_id="test_agent", kvstore=kvstore)
async def test_setup(sqlite_kvstore):
agent_persistence = AgentPersistence(agent_id="test_agent", kvstore=sqlite_kvstore)
yield agent_persistence
shutil.rmtree(temp_dir)
@pytest.mark.asyncio

View file

@ -4,10 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import pytest
import pytest_asyncio
from llama_stack.apis.inference import Model
from llama_stack.apis.vector_dbs import VectorDB
@ -20,28 +18,6 @@ from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
@pytest.fixture
def config():
config = SqliteKVStoreConfig(db_path="/tmp/test_registry.db")
if os.path.exists(config.db_path):
os.remove(config.db_path)
return config
@pytest_asyncio.fixture(scope="function")
async def registry(config):
registry = DiskDistributionRegistry(await kvstore_impl(config))
await registry.initialize()
return registry
@pytest_asyncio.fixture(scope="function")
async def cached_registry(config):
registry = CachedDiskDistributionRegistry(await kvstore_impl(config))
await registry.initialize()
return registry
@pytest.fixture
def sample_vector_db():
return VectorDB(
@ -63,41 +39,42 @@ def sample_model():
@pytest.mark.asyncio
async def test_registry_initialization(registry):
async def test_registry_initialization(disk_dist_registry):
# Test empty registry
result = await registry.get("nonexistent", "nonexistent")
result = await disk_dist_registry.get("nonexistent", "nonexistent")
assert result is None
@pytest.mark.asyncio
async def test_basic_registration(registry, sample_vector_db, sample_model):
async def test_basic_registration(disk_dist_registry, sample_vector_db, sample_model):
print(f"Registering {sample_vector_db}")
await registry.register(sample_vector_db)
await disk_dist_registry.register(sample_vector_db)
print(f"Registering {sample_model}")
await registry.register(sample_model)
await disk_dist_registry.register(sample_model)
print("Getting vector_db")
result_vector_db = await registry.get("vector_db", "test_vector_db")
result_vector_db = await disk_dist_registry.get("vector_db", "test_vector_db")
assert result_vector_db is not None
assert result_vector_db.identifier == sample_vector_db.identifier
assert result_vector_db.embedding_model == sample_vector_db.embedding_model
assert result_vector_db.provider_id == sample_vector_db.provider_id
result_model = await registry.get("model", "test_model")
result_model = await disk_dist_registry.get("model", "test_model")
assert result_model is not None
assert result_model.identifier == sample_model.identifier
assert result_model.provider_id == sample_model.provider_id
@pytest.mark.asyncio
async def test_cached_registry_initialization(config, sample_vector_db, sample_model):
async def test_cached_registry_initialization(sqlite_kvstore, sample_vector_db, sample_model):
# First populate the disk registry
disk_registry = DiskDistributionRegistry(await kvstore_impl(config))
disk_registry = DiskDistributionRegistry(sqlite_kvstore)
await disk_registry.initialize()
await disk_registry.register(sample_vector_db)
await disk_registry.register(sample_model)
# Test cached version loads from disk
cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config))
db_path = sqlite_kvstore.db_path
cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(SqliteKVStoreConfig(db_path=db_path)))
await cached_registry.initialize()
result_vector_db = await cached_registry.get("vector_db", "test_vector_db")
@ -109,10 +86,7 @@ async def test_cached_registry_initialization(config, sample_vector_db, sample_m
@pytest.mark.asyncio
async def test_cached_registry_updates(config):
cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config))
await cached_registry.initialize()
async def test_cached_registry_updates(cached_disk_dist_registry):
new_vector_db = VectorDB(
identifier="test_vector_db_2",
embedding_model="all-MiniLM-L6-v2",
@ -120,16 +94,17 @@ async def test_cached_registry_updates(config):
provider_resource_id="test_vector_db_2",
provider_id="baz",
)
await cached_registry.register(new_vector_db)
await cached_disk_dist_registry.register(new_vector_db)
# Verify in cache
result_vector_db = await cached_registry.get("vector_db", "test_vector_db_2")
result_vector_db = await cached_disk_dist_registry.get("vector_db", "test_vector_db_2")
assert result_vector_db is not None
assert result_vector_db.identifier == new_vector_db.identifier
assert result_vector_db.provider_id == new_vector_db.provider_id
# Verify persisted to disk
new_registry = DiskDistributionRegistry(await kvstore_impl(config))
db_path = cached_disk_dist_registry.kvstore.db_path
new_registry = DiskDistributionRegistry(await kvstore_impl(SqliteKVStoreConfig(db_path=db_path)))
await new_registry.initialize()
result_vector_db = await new_registry.get("vector_db", "test_vector_db_2")
assert result_vector_db is not None
@ -138,10 +113,7 @@ async def test_cached_registry_updates(config):
@pytest.mark.asyncio
async def test_duplicate_provider_registration(config):
cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config))
await cached_registry.initialize()
async def test_duplicate_provider_registration(cached_disk_dist_registry):
original_vector_db = VectorDB(
identifier="test_vector_db_2",
embedding_model="all-MiniLM-L6-v2",
@ -149,7 +121,7 @@ async def test_duplicate_provider_registration(config):
provider_resource_id="test_vector_db_2",
provider_id="baz",
)
await cached_registry.register(original_vector_db)
await cached_disk_dist_registry.register(original_vector_db)
duplicate_vector_db = VectorDB(
identifier="test_vector_db_2",
@ -158,18 +130,16 @@ async def test_duplicate_provider_registration(config):
provider_resource_id="test_vector_db_2",
provider_id="baz", # Same provider_id
)
await cached_registry.register(duplicate_vector_db)
await cached_disk_dist_registry.register(duplicate_vector_db)
result = await cached_registry.get("vector_db", "test_vector_db_2")
result = await cached_disk_dist_registry.get("vector_db", "test_vector_db_2")
assert result is not None
assert result.embedding_model == original_vector_db.embedding_model # Original values preserved
@pytest.mark.asyncio
async def test_get_all_objects(config):
cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config))
await cached_registry.initialize()
async def test_get_all_objects(cached_disk_dist_registry):
# Create multiple test banks
# Create multiple test banks
test_vector_dbs = [
VectorDB(
@ -184,10 +154,10 @@ async def test_get_all_objects(config):
# Register all vector_dbs
for vector_db in test_vector_dbs:
await cached_registry.register(vector_db)
await cached_disk_dist_registry.register(vector_db)
# Test get_all retrieval
all_results = await cached_registry.get_all()
all_results = await cached_disk_dist_registry.get_all()
assert len(all_results) == 3
# Verify each vector_db was stored correctly
@ -201,9 +171,7 @@ async def test_get_all_objects(config):
@pytest.mark.asyncio
async def test_parse_registry_values_error_handling(config):
kvstore = await kvstore_impl(config)
async def test_parse_registry_values_error_handling(sqlite_kvstore):
valid_db = VectorDB(
identifier="valid_vector_db",
embedding_model="all-MiniLM-L6-v2",
@ -212,16 +180,18 @@ async def test_parse_registry_values_error_handling(config):
provider_id="test-provider",
)
await kvstore.set(KEY_FORMAT.format(type="vector_db", identifier="valid_vector_db"), valid_db.model_dump_json())
await sqlite_kvstore.set(
KEY_FORMAT.format(type="vector_db", identifier="valid_vector_db"), valid_db.model_dump_json()
)
await kvstore.set(KEY_FORMAT.format(type="vector_db", identifier="corrupted_json"), "{not valid json")
await sqlite_kvstore.set(KEY_FORMAT.format(type="vector_db", identifier="corrupted_json"), "{not valid json")
await kvstore.set(
await sqlite_kvstore.set(
KEY_FORMAT.format(type="vector_db", identifier="missing_fields"),
'{"type": "vector_db", "identifier": "missing_fields"}',
)
test_registry = DiskDistributionRegistry(kvstore)
test_registry = DiskDistributionRegistry(sqlite_kvstore)
await test_registry.initialize()
# Get all objects, which should only return the valid one
@ -240,9 +210,7 @@ async def test_parse_registry_values_error_handling(config):
@pytest.mark.asyncio
async def test_cached_registry_error_handling(config):
kvstore = await kvstore_impl(config)
async def test_cached_registry_error_handling(sqlite_kvstore):
valid_db = VectorDB(
identifier="valid_cached_db",
embedding_model="all-MiniLM-L6-v2",
@ -251,14 +219,16 @@ async def test_cached_registry_error_handling(config):
provider_id="test-provider",
)
await kvstore.set(KEY_FORMAT.format(type="vector_db", identifier="valid_cached_db"), valid_db.model_dump_json())
await sqlite_kvstore.set(
KEY_FORMAT.format(type="vector_db", identifier="valid_cached_db"), valid_db.model_dump_json()
)
await kvstore.set(
await sqlite_kvstore.set(
KEY_FORMAT.format(type="vector_db", identifier="invalid_cached_db"),
'{"type": "vector_db", "identifier": "invalid_cached_db", "embedding_model": 12345}', # Should be string
)
cached_registry = CachedDiskDistributionRegistry(kvstore)
cached_registry = CachedDiskDistributionRegistry(sqlite_kvstore)
await cached_registry.initialize()
all_objects = await cached_registry.get_all()

View file

@ -4,9 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import shutil
import tempfile
import pytest
@ -14,30 +11,10 @@ from llama_stack.apis.models import ModelType
from llama_stack.distribution.datatypes import ModelWithACL
from llama_stack.distribution.server.auth_providers import AccessAttributes
from llama_stack.distribution.store.registry import CachedDiskDistributionRegistry
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from llama_stack.providers.utils.kvstore.sqlite import SqliteKVStoreImpl
@pytest.fixture(scope="function")
async def kvstore():
temp_dir = tempfile.mkdtemp()
db_path = os.path.join(temp_dir, "test_registry_acl.db")
kvstore_config = SqliteKVStoreConfig(db_path=db_path)
kvstore = SqliteKVStoreImpl(kvstore_config)
await kvstore.initialize()
yield kvstore
shutil.rmtree(temp_dir)
@pytest.fixture(scope="function")
async def registry(kvstore):
registry = CachedDiskDistributionRegistry(kvstore)
await registry.initialize()
return registry
@pytest.mark.asyncio
async def test_registry_cache_with_acl(registry):
async def test_registry_cache_with_acl(cached_disk_dist_registry):
model = ModelWithACL(
identifier="model-acl",
provider_id="test-provider",
@ -46,30 +23,30 @@ async def test_registry_cache_with_acl(registry):
access_attributes=AccessAttributes(roles=["admin"], teams=["ai-team"]),
)
success = await registry.register(model)
success = await cached_disk_dist_registry.register(model)
assert success
cached_model = registry.get_cached("model", "model-acl")
cached_model = cached_disk_dist_registry.get_cached("model", "model-acl")
assert cached_model is not None
assert cached_model.identifier == "model-acl"
assert cached_model.access_attributes.roles == ["admin"]
assert cached_model.access_attributes.teams == ["ai-team"]
fetched_model = await registry.get("model", "model-acl")
fetched_model = await cached_disk_dist_registry.get("model", "model-acl")
assert fetched_model is not None
assert fetched_model.identifier == "model-acl"
assert fetched_model.access_attributes.roles == ["admin"]
model.access_attributes = AccessAttributes(roles=["admin", "user"], projects=["project-x"])
await registry.update(model)
await cached_disk_dist_registry.update(model)
updated_cached = registry.get_cached("model", "model-acl")
updated_cached = cached_disk_dist_registry.get_cached("model", "model-acl")
assert updated_cached is not None
assert updated_cached.access_attributes.roles == ["admin", "user"]
assert updated_cached.access_attributes.projects == ["project-x"]
assert updated_cached.access_attributes.teams is None
new_registry = CachedDiskDistributionRegistry(registry.kvstore)
new_registry = CachedDiskDistributionRegistry(cached_disk_dist_registry.kvstore)
await new_registry.initialize()
new_model = await new_registry.get("model", "model-acl")
@ -81,7 +58,7 @@ async def test_registry_cache_with_acl(registry):
@pytest.mark.asyncio
async def test_registry_empty_acl(registry):
async def test_registry_empty_acl(cached_disk_dist_registry):
model = ModelWithACL(
identifier="model-empty-acl",
provider_id="test-provider",
@ -90,9 +67,9 @@ async def test_registry_empty_acl(registry):
access_attributes=AccessAttributes(),
)
await registry.register(model)
await cached_disk_dist_registry.register(model)
cached_model = registry.get_cached("model", "model-empty-acl")
cached_model = cached_disk_dist_registry.get_cached("model", "model-empty-acl")
assert cached_model is not None
assert cached_model.access_attributes is not None
assert cached_model.access_attributes.roles is None
@ -100,7 +77,7 @@ async def test_registry_empty_acl(registry):
assert cached_model.access_attributes.projects is None
assert cached_model.access_attributes.namespaces is None
all_models = await registry.get_all()
all_models = await cached_disk_dist_registry.get_all()
assert len(all_models) == 1
model = ModelWithACL(
@ -110,18 +87,18 @@ async def test_registry_empty_acl(registry):
model_type=ModelType.llm,
)
await registry.register(model)
await cached_disk_dist_registry.register(model)
cached_model = registry.get_cached("model", "model-no-acl")
cached_model = cached_disk_dist_registry.get_cached("model", "model-no-acl")
assert cached_model is not None
assert cached_model.access_attributes is None
all_models = await registry.get_all()
all_models = await cached_disk_dist_registry.get_all()
assert len(all_models) == 2
@pytest.mark.asyncio
async def test_registry_serialization(registry):
async def test_registry_serialization(cached_disk_dist_registry):
attributes = AccessAttributes(
roles=["admin", "researcher"],
teams=["ai-team", "ml-team"],
@ -137,9 +114,9 @@ async def test_registry_serialization(registry):
access_attributes=attributes,
)
await registry.register(model)
await cached_disk_dist_registry.register(model)
new_registry = CachedDiskDistributionRegistry(registry.kvstore)
new_registry = CachedDiskDistributionRegistry(cached_disk_dist_registry.kvstore)
await new_registry.initialize()
loaded_model = await new_registry.get("model", "model-serialize")

View file

@ -4,9 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import shutil
import tempfile
from unittest.mock import MagicMock, Mock, patch
import pytest
@ -15,9 +12,6 @@ from llama_stack.apis.datatypes import Api
from llama_stack.apis.models import ModelType
from llama_stack.distribution.datatypes import AccessAttributes, ModelWithACL
from llama_stack.distribution.routers.routing_tables import ModelsRoutingTable
from llama_stack.distribution.store.registry import CachedDiskDistributionRegistry
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from llama_stack.providers.utils.kvstore.sqlite import SqliteKVStoreImpl
class AsyncMock(MagicMock):
@ -30,25 +24,16 @@ def _return_model(model):
@pytest.fixture
async def test_setup():
temp_dir = tempfile.mkdtemp()
db_path = os.path.join(temp_dir, "test_access_control.db")
kvstore_config = SqliteKVStoreConfig(db_path=db_path)
kvstore = SqliteKVStoreImpl(kvstore_config)
await kvstore.initialize()
registry = CachedDiskDistributionRegistry(kvstore)
await registry.initialize()
async def test_setup(cached_disk_dist_registry):
mock_inference = Mock()
mock_inference.__provider_spec__ = MagicMock()
mock_inference.__provider_spec__.api = Api.inference
mock_inference.register_model = AsyncMock(side_effect=_return_model)
routing_table = ModelsRoutingTable(
impls_by_provider_id={"test_provider": mock_inference},
dist_registry=registry,
dist_registry=cached_disk_dist_registry,
)
yield registry, routing_table
shutil.rmtree(temp_dir)
yield cached_disk_dist_registry, routing_table
@pytest.mark.asyncio

70
uv.lock generated
View file

@ -1493,6 +1493,7 @@ test = [
{ name = "torch", version = "2.6.0+cpu", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "sys_platform != 'darwin'" },
{ name = "torchvision", version = "0.21.0", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" },
{ name = "torchvision", version = "0.21.0+cpu", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "transformers" },
]
ui = [
{ name = "llama-stack-client" },
@ -1581,6 +1582,7 @@ requires-dist = [
{ name = "tomli", marker = "extra == 'docs'" },
{ name = "torch", marker = "extra == 'test'", specifier = ">=2.6.0", index = "https://download.pytorch.org/whl/cpu" },
{ name = "torchvision", marker = "extra == 'test'", specifier = ">=0.21.0", index = "https://download.pytorch.org/whl/cpu" },
{ name = "transformers", marker = "extra == 'test'" },
{ name = "types-requests", marker = "extra == 'dev'" },
{ name = "types-setuptools", marker = "extra == 'dev'" },
{ name = "uvicorn", marker = "extra == 'dev'" },
@ -3417,6 +3419,28 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/e8/a8/d71f44b93e3aa86ae232af1f2126ca7b95c0f515ec135462b3e1f351441c/ruff-0.9.6-py3-none-win_arm64.whl", hash = "sha256:0e2bb706a2be7ddfea4a4af918562fdc1bcb16df255e5fa595bbd800ce322a5a", size = 10177499, upload-time = "2025-02-10T12:59:42.989Z" },
]
[[package]]
name = "safetensors"
version = "0.5.3"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/71/7e/2d5d6ee7b40c0682315367ec7475693d110f512922d582fef1bd4a63adc3/safetensors-0.5.3.tar.gz", hash = "sha256:b6b0d6ecacec39a4fdd99cc19f4576f5219ce858e6fd8dbe7609df0b8dc56965", size = 67210 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/18/ae/88f6c49dbd0cc4da0e08610019a3c78a7d390879a919411a410a1876d03a/safetensors-0.5.3-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:bd20eb133db8ed15b40110b7c00c6df51655a2998132193de2f75f72d99c7073", size = 436917 },
{ url = "https://files.pythonhosted.org/packages/b8/3b/11f1b4a2f5d2ab7da34ecc062b0bc301f2be024d110a6466726bec8c055c/safetensors-0.5.3-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:21d01c14ff6c415c485616b8b0bf961c46b3b343ca59110d38d744e577f9cce7", size = 418419 },
{ url = "https://files.pythonhosted.org/packages/5d/9a/add3e6fef267658075c5a41573c26d42d80c935cdc992384dfae435feaef/safetensors-0.5.3-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:11bce6164887cd491ca75c2326a113ba934be596e22b28b1742ce27b1d076467", size = 459493 },
{ url = "https://files.pythonhosted.org/packages/df/5c/bf2cae92222513cc23b3ff85c4a1bb2811a2c3583ac0f8e8d502751de934/safetensors-0.5.3-cp38-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:4a243be3590bc3301c821da7a18d87224ef35cbd3e5f5727e4e0728b8172411e", size = 472400 },
{ url = "https://files.pythonhosted.org/packages/58/11/7456afb740bd45782d0f4c8e8e1bb9e572f1bf82899fb6ace58af47b4282/safetensors-0.5.3-cp38-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8bd84b12b1670a6f8e50f01e28156422a2bc07fb16fc4e98bded13039d688a0d", size = 522891 },
{ url = "https://files.pythonhosted.org/packages/57/3d/fe73a9d2ace487e7285f6e157afee2383bd1ddb911b7cb44a55cf812eae3/safetensors-0.5.3-cp38-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:391ac8cab7c829452175f871fcaf414aa1e292b5448bd02620f675a7f3e7abb9", size = 537694 },
{ url = "https://files.pythonhosted.org/packages/a6/f8/dae3421624fcc87a89d42e1898a798bc7ff72c61f38973a65d60df8f124c/safetensors-0.5.3-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cead1fa41fc54b1e61089fa57452e8834f798cb1dc7a09ba3524f1eb08e0317a", size = 471642 },
{ url = "https://files.pythonhosted.org/packages/ce/20/1fbe16f9b815f6c5a672f5b760951e20e17e43f67f231428f871909a37f6/safetensors-0.5.3-cp38-abi3-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:1077f3e94182d72618357b04b5ced540ceb71c8a813d3319f1aba448e68a770d", size = 502241 },
{ url = "https://files.pythonhosted.org/packages/5f/18/8e108846b506487aa4629fe4116b27db65c3dde922de2c8e0cc1133f3f29/safetensors-0.5.3-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:799021e78287bac619c7b3f3606730a22da4cda27759ddf55d37c8db7511c74b", size = 638001 },
{ url = "https://files.pythonhosted.org/packages/82/5a/c116111d8291af6c8c8a8b40628fe833b9db97d8141c2a82359d14d9e078/safetensors-0.5.3-cp38-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:df26da01aaac504334644e1b7642fa000bfec820e7cef83aeac4e355e03195ff", size = 734013 },
{ url = "https://files.pythonhosted.org/packages/7d/ff/41fcc4d3b7de837963622e8610d998710705bbde9a8a17221d85e5d0baad/safetensors-0.5.3-cp38-abi3-musllinux_1_2_i686.whl", hash = "sha256:32c3ef2d7af8b9f52ff685ed0bc43913cdcde135089ae322ee576de93eae5135", size = 670687 },
{ url = "https://files.pythonhosted.org/packages/40/ad/2b113098e69c985a3d8fbda4b902778eae4a35b7d5188859b4a63d30c161/safetensors-0.5.3-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:37f1521be045e56fc2b54c606d4455573e717b2d887c579ee1dbba5f868ece04", size = 643147 },
{ url = "https://files.pythonhosted.org/packages/0a/0c/95aeb51d4246bd9a3242d3d8349c1112b4ee7611a4b40f0c5c93b05f001d/safetensors-0.5.3-cp38-abi3-win32.whl", hash = "sha256:cfc0ec0846dcf6763b0ed3d1846ff36008c6e7290683b61616c4b040f6a54ace", size = 296677 },
{ url = "https://files.pythonhosted.org/packages/69/e2/b011c38e5394c4c18fb5500778a55ec43ad6106126e74723ffaee246f56e/safetensors-0.5.3-cp38-abi3-win_amd64.whl", hash = "sha256:836cbbc320b47e80acd40e44c8682db0e8ad7123209f69b093def21ec7cafd11", size = 308878 },
]
[[package]]
name = "setuptools"
version = "75.8.0"
@ -3833,6 +3857,31 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/de/a8/8f499c179ec900783ffe133e9aab10044481679bb9aad78436d239eee716/tiktoken-0.9.0-cp313-cp313-win_amd64.whl", hash = "sha256:5ea0edb6f83dc56d794723286215918c1cde03712cbbafa0348b33448faf5b95", size = 894669, upload-time = "2025-02-14T06:02:47.341Z" },
]
[[package]]
name = "tokenizers"
version = "0.21.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "huggingface-hub" },
]
sdist = { url = "https://files.pythonhosted.org/packages/92/76/5ac0c97f1117b91b7eb7323dcd61af80d72f790b4df71249a7850c195f30/tokenizers-0.21.1.tar.gz", hash = "sha256:a1bb04dc5b448985f86ecd4b05407f5a8d97cb2c0532199b2a302a604a0165ab", size = 343256 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/a5/1f/328aee25f9115bf04262e8b4e5a2050b7b7cf44b59c74e982db7270c7f30/tokenizers-0.21.1-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:e78e413e9e668ad790a29456e677d9d3aa50a9ad311a40905d6861ba7692cf41", size = 2780767 },
{ url = "https://files.pythonhosted.org/packages/ae/1a/4526797f3719b0287853f12c5ad563a9be09d446c44ac784cdd7c50f76ab/tokenizers-0.21.1-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:cd51cd0a91ecc801633829fcd1fda9cf8682ed3477c6243b9a095539de4aecf3", size = 2650555 },
{ url = "https://files.pythonhosted.org/packages/4d/7a/a209b29f971a9fdc1da86f917fe4524564924db50d13f0724feed37b2a4d/tokenizers-0.21.1-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:28da6b72d4fb14ee200a1bd386ff74ade8992d7f725f2bde2c495a9a98cf4d9f", size = 2937541 },
{ url = "https://files.pythonhosted.org/packages/3c/1e/b788b50ffc6191e0b1fc2b0d49df8cff16fe415302e5ceb89f619d12c5bc/tokenizers-0.21.1-cp39-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:34d8cfde551c9916cb92014e040806122295a6800914bab5865deb85623931cf", size = 2819058 },
{ url = "https://files.pythonhosted.org/packages/36/aa/3626dfa09a0ecc5b57a8c58eeaeb7dd7ca9a37ad9dd681edab5acd55764c/tokenizers-0.21.1-cp39-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:aaa852d23e125b73d283c98f007e06d4595732104b65402f46e8ef24b588d9f8", size = 3133278 },
{ url = "https://files.pythonhosted.org/packages/a4/4d/8fbc203838b3d26269f944a89459d94c858f5b3f9a9b6ee9728cdcf69161/tokenizers-0.21.1-cp39-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a21a15d5c8e603331b8a59548bbe113564136dc0f5ad8306dd5033459a226da0", size = 3144253 },
{ url = "https://files.pythonhosted.org/packages/d8/1b/2bd062adeb7c7511b847b32e356024980c0ffcf35f28947792c2d8ad2288/tokenizers-0.21.1-cp39-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2fdbd4c067c60a0ac7eca14b6bd18a5bebace54eb757c706b47ea93204f7a37c", size = 3398225 },
{ url = "https://files.pythonhosted.org/packages/8a/63/38be071b0c8e06840bc6046991636bcb30c27f6bb1e670f4f4bc87cf49cc/tokenizers-0.21.1-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2dd9a0061e403546f7377df940e866c3e678d7d4e9643d0461ea442b4f89e61a", size = 3038874 },
{ url = "https://files.pythonhosted.org/packages/ec/83/afa94193c09246417c23a3c75a8a0a96bf44ab5630a3015538d0c316dd4b/tokenizers-0.21.1-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:db9484aeb2e200c43b915a1a0150ea885e35f357a5a8fabf7373af333dcc8dbf", size = 9014448 },
{ url = "https://files.pythonhosted.org/packages/ae/b3/0e1a37d4f84c0f014d43701c11eb8072704f6efe8d8fc2dcdb79c47d76de/tokenizers-0.21.1-cp39-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:ed248ab5279e601a30a4d67bdb897ecbe955a50f1e7bb62bd99f07dd11c2f5b6", size = 8937877 },
{ url = "https://files.pythonhosted.org/packages/ac/33/ff08f50e6d615eb180a4a328c65907feb6ded0b8f990ec923969759dc379/tokenizers-0.21.1-cp39-abi3-musllinux_1_2_i686.whl", hash = "sha256:9ac78b12e541d4ce67b4dfd970e44c060a2147b9b2a21f509566d556a509c67d", size = 9186645 },
{ url = "https://files.pythonhosted.org/packages/5f/aa/8ae85f69a9f6012c6f8011c6f4aa1c96154c816e9eea2e1b758601157833/tokenizers-0.21.1-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:e5a69c1a4496b81a5ee5d2c1f3f7fbdf95e90a0196101b0ee89ed9956b8a168f", size = 9384380 },
{ url = "https://files.pythonhosted.org/packages/e8/5b/a5d98c89f747455e8b7a9504910c865d5e51da55e825a7ae641fb5ff0a58/tokenizers-0.21.1-cp39-abi3-win32.whl", hash = "sha256:1039a3a5734944e09de1d48761ade94e00d0fa760c0e0551151d4dd851ba63e3", size = 2239506 },
{ url = "https://files.pythonhosted.org/packages/e6/b6/072a8e053ae600dcc2ac0da81a23548e3b523301a442a6ca900e92ac35be/tokenizers-0.21.1-cp39-abi3-win_amd64.whl", hash = "sha256:0f0dcbcc9f6e13e675a66d7a5f2f225a736745ce484c1a4e07476a89ccdad382", size = 2435481 },
]
[[package]]
name = "toml"
version = "0.10.2"
@ -4043,6 +4092,27 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/00/c0/8f5d070730d7836adc9c9b6408dec68c6ced86b304a9b26a14df072a6e8c/traitlets-5.14.3-py3-none-any.whl", hash = "sha256:b74e89e397b1ed28cc831db7aea759ba6640cb3de13090ca145426688ff1ac4f", size = 85359, upload-time = "2024-04-19T11:11:46.763Z" },
]
[[package]]
name = "transformers"
version = "4.50.3"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "filelock" },
{ name = "huggingface-hub" },
{ name = "numpy" },
{ name = "packaging" },
{ name = "pyyaml" },
{ name = "regex" },
{ name = "requests" },
{ name = "safetensors" },
{ name = "tokenizers" },
{ name = "tqdm" },
]
sdist = { url = "https://files.pythonhosted.org/packages/c0/29/37877123d6633a188997d75dc17d6f526745d63361794348ce748db23d49/transformers-4.50.3.tar.gz", hash = "sha256:1d795d24925e615a8e63687d077e4f7348c2702eb87032286eaa76d83cdc684f", size = 8774363 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/aa/22/733a6fc4a6445d835242f64c490fdd30f4a08d58f2b788613de3f9170692/transformers-4.50.3-py3-none-any.whl", hash = "sha256:6111610a43dec24ef32c3df0632c6b25b07d9711c01d9e1077bdd2ff6b14a38c", size = 10180411 },
]
[[package]]
name = "types-requests"
version = "2.32.0.20241016"