add basic integration test

This commit is contained in:
Ashwin Bharambe 2025-05-20 18:20:16 -07:00
parent 6e57929ede
commit e6ddf5dac7
43 changed files with 342 additions and 44 deletions

View file

@ -106,10 +106,10 @@ class DistributionCredentialsImpl(CredentialsAPI, CredentialsStore):
principal = get_principal() principal = get_principal()
# check that provider_id is registered # check that provider_id is registered
run_config = self.deps[Api.inspect].run_config run_config = self.deps[Api.inspect].config.run_config
# TODO: we should make provider_ids unique across all APIs which is not enforced yet # TODO: we should make provider_ids unique across all APIs which is not enforced yet
provider_ids = [p.provider_id for p in run_config.providers.values()] provider_ids = [p.provider_id for plist in run_config.providers.values() for p in plist]
if provider_id not in provider_ids: if provider_id not in provider_ids:
raise ValueError(f"Provider {provider_id} is not registered") raise ValueError(f"Provider {provider_id} is not registered")

View file

@ -9,6 +9,7 @@ from typing import Any
from llama_stack.apis.agents import Agents from llama_stack.apis.agents import Agents
from llama_stack.apis.benchmarks import Benchmarks from llama_stack.apis.benchmarks import Benchmarks
from llama_stack.apis.credentials import Credentials
from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets from llama_stack.apis.datasets import Datasets
from llama_stack.apis.eval import Eval from llama_stack.apis.eval import Eval
@ -61,6 +62,7 @@ class InvalidProviderError(Exception):
def api_protocol_map() -> dict[Api, Any]: def api_protocol_map() -> dict[Api, Any]:
return { return {
Api.providers: ProvidersAPI, Api.providers: ProvidersAPI,
Api.credentials: Credentials,
Api.agents: Agents, Api.agents: Agents,
Api.inference: Inference, Api.inference: Inference,
Api.inspect: Inspect, Api.inspect: Inspect,

View file

@ -41,9 +41,11 @@ from llama_stack.distribution.inspect import DistributionInspectConfig, Distribu
from llama_stack.distribution.providers import ProviderImpl, ProviderImplConfig from llama_stack.distribution.providers import ProviderImpl, ProviderImplConfig
from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls
from llama_stack.distribution.store.registry import create_dist_registry from llama_stack.distribution.store.registry import create_dist_registry
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.distribution.utils.dynamic import instantiate_class_type from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.log import get_logger from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api from llama_stack.providers.datatypes import Api
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
logger = get_logger(name=__name__, category="core") logger = get_logger(name=__name__, category="core")
@ -214,8 +216,12 @@ async def instantiate_internal_impls(impls: dict[Api, Any], run_config: StackRun
) )
await providers_impl.initialize() await providers_impl.initialize()
# TODO: make metadata_store and credentials_store non-optional by including it in the templates
credentials_impl = DistributionCredentialsImpl( credentials_impl = DistributionCredentialsImpl(
DistributionCredentialsConfig(kvstore=run_config.credentials_store), DistributionCredentialsConfig(
kvstore=run_config.credentials_store
or SqliteKVStoreConfig(db_path=(DISTRIBS_BASE_DIR / run_config.image_name / "credentials.db").as_posix())
),
deps=impls, deps=impls,
) )
await credentials_impl.initialize() await credentials_impl.initialize()
@ -231,18 +237,26 @@ async def instantiate_internal_impls(impls: dict[Api, Any], run_config: StackRun
async def construct_stack( async def construct_stack(
run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None
) -> dict[Api, Any]: ) -> dict[Api, Any]:
dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name) kvstore_config = run_config.metadata_store or SqliteKVStoreConfig(
db_path=(DISTRIBS_BASE_DIR / run_config.image_name / "kvstore.db").as_posix()
)
dist_registry, _ = await create_dist_registry(kvstore_config)
impls = await resolve_impls(run_config, provider_registry or get_provider_registry(run_config), dist_registry) impls = await resolve_impls(run_config, provider_registry or get_provider_registry(run_config), dist_registry)
# Add internal implementations after all other providers are resolved # Add internal implementations after all other providers are resolved
internal_impls = await instantiate_internal_impls(impls, run_config) internal_impls = await instantiate_internal_impls(impls, run_config)
impls.update(internal_impls) impls.update(internal_impls)
# credentials_store = internal_impls[Api.credentials] # HACK: this is a hack to work around circular dependency issues. we probably need to
# for impl in impls.values(): # make resolving internal implementations be part of `resolve_impls` again (as it used to be
# # in an ideal world, we would pass the credentials store as a dependency # a while ago) so that dependencies can be expressed properly.
# if hasattr(impl, "credentials_store"): for impl in impls.values():
# impl.credentials_store = credentials_store from llama_stack.distribution.routers.routing_tables import CommonRoutingTableImpl
if isinstance(impl, CommonRoutingTableImpl):
for provider_impl in impl.impls_by_provider_id.values():
if hasattr(provider_impl, "credentials_store"):
provider_impl.credentials_store = internal_impls[Api.credentials]
await register_resources(run_config, impls) await register_resources(run_config, impls)
return impls return impls

View file

@ -11,10 +11,8 @@ from typing import Protocol
import pydantic import pydantic
from llama_stack.distribution.datatypes import KVStoreConfig, RoutableObjectWithProvider from llama_stack.distribution.datatypes import KVStoreConfig, RoutableObjectWithProvider
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.log import get_logger from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
logger = get_logger(__name__, category="core") logger = get_logger(__name__, category="core")
@ -189,16 +187,9 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry):
async def create_dist_registry( async def create_dist_registry(
metadata_store: KVStoreConfig | None, kvstore_config: KVStoreConfig,
image_name: str,
) -> tuple[CachedDiskDistributionRegistry, KVStore]: ) -> tuple[CachedDiskDistributionRegistry, KVStore]:
# instantiate kvstore for storing and retrieving distribution metadata dist_kvstore = await kvstore_impl(kvstore_config)
if metadata_store:
dist_kvstore = await kvstore_impl(metadata_store)
else:
dist_kvstore = await kvstore_impl(
SqliteKVStoreConfig(db_path=(DISTRIBS_BASE_DIR / image_name / "kvstore.db").as_posix())
)
dist_registry = CachedDiskDistributionRegistry(dist_kvstore) dist_registry = CachedDiskDistributionRegistry(dist_kvstore)
await dist_registry.initialize() await dist_registry.initialize()
return dist_registry, dist_kvstore return dist_registry, dist_kvstore

View file

@ -5,7 +5,7 @@
# the root directory of this source tree. # the root directory of this source tree.
from enum import Enum from enum import Enum
from typing import Any, Protocol from typing import Any, Protocol, runtime_checkable
from urllib.parse import urlparse from urllib.parse import urlparse
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@ -112,6 +112,7 @@ class ProviderSpec(BaseModel):
return self.provider_type in ("sample", "remote::sample") return self.provider_type in ("sample", "remote::sample")
@runtime_checkable
class RoutingTable(Protocol): class RoutingTable(Protocol):
def get_provider_impl(self, routing_key: str) -> Any: ... def get_provider_impl(self, routing_key: str) -> Any: ...

View file

@ -83,6 +83,5 @@ def available_providers() -> list[ProviderSpec]:
config_class="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.ModelContextProtocolConfig", config_class="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.ModelContextProtocolConfig",
pip_packages=["mcp"], pip_packages=["mcp"],
), ),
api_dependencies=[Api.credentials],
), ),
] ]

View file

@ -4,16 +4,17 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from collections.abc import AsyncGenerator from contextlib import asynccontextmanager
from typing import Any, cast from typing import Any
from urllib.parse import urlparse from urllib.parse import urlparse
import exceptiongroup import exceptiongroup
import httpx import httpx
from mcp import ClientSession from mcp import ClientSession
from mcp import types as mcp_types
from mcp.client.sse import sse_client from mcp.client.sse import sse_client
from llama_stack.apis.common.content_types import URL from llama_stack.apis.common.content_types import URL, ImageContentItem, TextContentItem
from llama_stack.apis.datatypes import Api from llama_stack.apis.datatypes import Api
from llama_stack.apis.tools import ( from llama_stack.apis.tools import (
ListToolDefsResponse, ListToolDefsResponse,
@ -23,12 +24,16 @@ from llama_stack.apis.tools import (
ToolRuntime, ToolRuntime,
) )
from llama_stack.distribution.credentials import AuthenticationRequiredError, CredentialsStore from llama_stack.distribution.credentials import AuthenticationRequiredError, CredentialsStore
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ToolsProtocolPrivate from llama_stack.providers.datatypes import ToolsProtocolPrivate
from .config import ModelContextProtocolConfig from .config import ModelContextProtocolConfig
logger = get_logger(__name__, category="tools")
async def sse_client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerator[ClientSession, None]:
@asynccontextmanager
async def sse_client_wrapper(endpoint: str, headers: dict[str, str]):
try: try:
async with sse_client(endpoint, headers=headers) as streams: async with sse_client(endpoint, headers=headers) as streams:
async with ClientSession(*streams) as session: async with ClientSession(*streams) as session:
@ -48,9 +53,13 @@ async def sse_client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGen
class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime): class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
def __init__(self, config: ModelContextProtocolConfig, deps: dict[Api, Any]): # HACK: this is filled in by the Stack resolver magically right now to work around
# circular dependency issues.
credentials_store: CredentialsStore
def __init__(self, config: ModelContextProtocolConfig, _deps: dict[Api, Any]):
self.config = config self.config = config
self.credentials_store = cast(CredentialsStore, deps[Api.credentials]) self.credentials_store = None
async def initialize(self): async def initialize(self):
pass pass
@ -99,14 +108,27 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
async with sse_client_wrapper(endpoint, headers) as session: async with sse_client_wrapper(endpoint, headers) as session:
result = await session.call_tool(tool.identifier, kwargs) result = await session.call_tool(tool.identifier, kwargs)
content = []
for item in result.content:
if isinstance(item, mcp_types.TextContent):
content.append(TextContentItem(text=item.text))
elif isinstance(item, mcp_types.ImageContent):
content.append(ImageContentItem(image=item.data))
elif isinstance(item, mcp_types.EmbeddedResource):
logger.warning(f"EmbeddedResource is not supported: {item}")
else:
raise ValueError(f"Unknown content type: {type(item)}")
return ToolInvocationResult( return ToolInvocationResult(
content=[result.model_dump_json() for result in result.content], content=content,
error_code=1 if result.isError else 0, error_code=1 if result.isError else 0,
) )
async def get_headers(self) -> dict[str, str]: async def get_headers(self) -> dict[str, str]:
if self.credentials_store is None:
raise ValueError("credentials_store is not set")
headers = {} headers = {}
credentials = await self.credentials_store.get_credential(self.__provider_id__) token = await self.credentials_store.read_decrypted_credential(self.__provider_id__)
if credentials: if token:
headers["Authorization"] = f"Bearer {credentials.token}" headers["Authorization"] = f"Bearer {token}"
return headers return headers

View file

@ -96,6 +96,9 @@ providers:
metadata_store: metadata_store:
type: sqlite type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/bedrock}/registry.db db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/bedrock}/registry.db
credentials_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/bedrock}/credentials.db
models: models:
- metadata: {} - metadata: {}
model_id: meta.llama3-1-8b-instruct-v1:0 model_id: meta.llama3-1-8b-instruct-v1:0

View file

@ -99,6 +99,9 @@ providers:
metadata_store: metadata_store:
type: sqlite type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/cerebras}/registry.db db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/cerebras}/registry.db
credentials_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/cerebras}/credentials.db
models: models:
- metadata: {} - metadata: {}
model_id: llama3.1-8b model_id: llama3.1-8b

View file

@ -99,6 +99,9 @@ providers:
metadata_store: metadata_store:
type: sqlite type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ci-tests}/registry.db db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ci-tests}/registry.db
credentials_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ci-tests}/credentials.db
models: models:
- metadata: {} - metadata: {}
model_id: accounts/fireworks/models/llama-v3p1-8b-instruct model_id: accounts/fireworks/models/llama-v3p1-8b-instruct

View file

@ -99,6 +99,9 @@ providers:
metadata_store: metadata_store:
type: sqlite type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dell}/registry.db db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dell}/registry.db
credentials_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dell}/credentials.db
models: models:
- metadata: {} - metadata: {}
model_id: ${env.INFERENCE_MODEL} model_id: ${env.INFERENCE_MODEL}

View file

@ -95,6 +95,9 @@ providers:
metadata_store: metadata_store:
type: sqlite type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dell}/registry.db db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dell}/registry.db
credentials_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dell}/credentials.db
models: models:
- metadata: {} - metadata: {}
model_id: ${env.INFERENCE_MODEL} model_id: ${env.INFERENCE_MODEL}

View file

@ -111,6 +111,9 @@ providers:
metadata_store: metadata_store:
type: sqlite type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/fireworks}/registry.db db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/fireworks}/registry.db
credentials_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/fireworks}/credentials.db
models: models:
- metadata: {} - metadata: {}
model_id: accounts/fireworks/models/llama-v3p1-8b-instruct model_id: accounts/fireworks/models/llama-v3p1-8b-instruct

View file

@ -106,6 +106,9 @@ providers:
metadata_store: metadata_store:
type: sqlite type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/fireworks}/registry.db db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/fireworks}/registry.db
credentials_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/fireworks}/credentials.db
models: models:
- metadata: {} - metadata: {}
model_id: accounts/fireworks/models/llama-v3p1-8b-instruct model_id: accounts/fireworks/models/llama-v3p1-8b-instruct

View file

@ -99,6 +99,9 @@ providers:
metadata_store: metadata_store:
type: sqlite type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/groq}/registry.db db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/groq}/registry.db
credentials_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/groq}/credentials.db
models: models:
- metadata: {} - metadata: {}
model_id: groq/llama3-8b-8192 model_id: groq/llama3-8b-8192

View file

@ -107,6 +107,9 @@ providers:
metadata_store: metadata_store:
type: sqlite type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/registry.db db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/registry.db
credentials_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/credentials.db
models: models:
- metadata: {} - metadata: {}
model_id: ${env.INFERENCE_MODEL} model_id: ${env.INFERENCE_MODEL}

View file

@ -102,6 +102,9 @@ providers:
metadata_store: metadata_store:
type: sqlite type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/registry.db db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/registry.db
credentials_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/credentials.db
models: models:
- metadata: {} - metadata: {}
model_id: ${env.INFERENCE_MODEL} model_id: ${env.INFERENCE_MODEL}

View file

@ -107,6 +107,9 @@ providers:
metadata_store: metadata_store:
type: sqlite type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-serverless}/registry.db db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-serverless}/registry.db
credentials_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-serverless}/credentials.db
models: models:
- metadata: {} - metadata: {}
model_id: ${env.INFERENCE_MODEL} model_id: ${env.INFERENCE_MODEL}

View file

@ -102,6 +102,9 @@ providers:
metadata_store: metadata_store:
type: sqlite type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-serverless}/registry.db db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-serverless}/registry.db
credentials_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-serverless}/credentials.db
models: models:
- metadata: {} - metadata: {}
model_id: ${env.INFERENCE_MODEL} model_id: ${env.INFERENCE_MODEL}

View file

@ -111,6 +111,9 @@ providers:
metadata_store: metadata_store:
type: sqlite type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/llama_api}/registry.db db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/llama_api}/registry.db
credentials_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/llama_api}/credentials.db
models: models:
- metadata: {} - metadata: {}
model_id: Llama-3.3-70B-Instruct model_id: Llama-3.3-70B-Instruct

View file

@ -117,6 +117,9 @@ providers:
metadata_store: metadata_store:
type: sqlite type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/registry.db db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/registry.db
credentials_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/credentials.db
models: models:
- metadata: {} - metadata: {}
model_id: ${env.INFERENCE_MODEL} model_id: ${env.INFERENCE_MODEL}

View file

@ -107,6 +107,9 @@ providers:
metadata_store: metadata_store:
type: sqlite type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/registry.db db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/registry.db
credentials_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/credentials.db
models: models:
- metadata: {} - metadata: {}
model_id: ${env.INFERENCE_MODEL} model_id: ${env.INFERENCE_MODEL}

View file

@ -92,6 +92,9 @@ providers:
metadata_store: metadata_store:
type: sqlite type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/registry.db db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/registry.db
credentials_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/credentials.db
models: models:
- metadata: {} - metadata: {}
model_id: ${env.INFERENCE_MODEL} model_id: ${env.INFERENCE_MODEL}

View file

@ -80,6 +80,9 @@ providers:
metadata_store: metadata_store:
type: sqlite type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/registry.db db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/registry.db
credentials_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/credentials.db
models: models:
- metadata: {} - metadata: {}
model_id: meta/llama3-8b-instruct model_id: meta/llama3-8b-instruct

View file

@ -112,6 +112,9 @@ providers:
metadata_store: metadata_store:
type: sqlite type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/registry.db db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/registry.db
credentials_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/credentials.db
models: models:
- metadata: {} - metadata: {}
model_id: ${env.INFERENCE_MODEL} model_id: ${env.INFERENCE_MODEL}

View file

@ -110,6 +110,9 @@ providers:
metadata_store: metadata_store:
type: sqlite type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/registry.db db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/registry.db
credentials_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/credentials.db
models: models:
- metadata: {} - metadata: {}
model_id: ${env.INFERENCE_MODEL} model_id: ${env.INFERENCE_MODEL}

View file

@ -125,6 +125,9 @@ providers:
metadata_store: metadata_store:
type: sqlite type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/open-benchmark}/registry.db db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/open-benchmark}/registry.db
credentials_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/open-benchmark}/credentials.db
models: models:
- metadata: {} - metadata: {}
model_id: openai/gpt-4o model_id: openai/gpt-4o

View file

@ -111,6 +111,9 @@ providers:
metadata_store: metadata_store:
type: sqlite type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/passthrough}/registry.db db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/passthrough}/registry.db
credentials_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/passthrough}/credentials.db
models: models:
- metadata: {} - metadata: {}
model_id: meta-llama/Llama-3.1-8B-Instruct model_id: meta-llama/Llama-3.1-8B-Instruct

View file

@ -106,6 +106,9 @@ providers:
metadata_store: metadata_store:
type: sqlite type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/passthrough}/registry.db db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/passthrough}/registry.db
credentials_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/passthrough}/credentials.db
models: models:
- metadata: {} - metadata: {}
model_id: meta-llama/Llama-3.1-8B-Instruct model_id: meta-llama/Llama-3.1-8B-Instruct

View file

@ -115,6 +115,9 @@ providers:
metadata_store: metadata_store:
type: sqlite type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/registry.db db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/registry.db
credentials_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/credentials.db
models: models:
- metadata: {} - metadata: {}
model_id: ${env.INFERENCE_MODEL} model_id: ${env.INFERENCE_MODEL}

View file

@ -108,6 +108,9 @@ providers:
metadata_store: metadata_store:
type: sqlite type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/registry.db db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/registry.db
credentials_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/credentials.db
models: models:
- metadata: {} - metadata: {}
model_id: ${env.INFERENCE_MODEL} model_id: ${env.INFERENCE_MODEL}

View file

@ -81,6 +81,9 @@ providers:
metadata_store: metadata_store:
type: sqlite type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/sambanova}/registry.db db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/sambanova}/registry.db
credentials_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/sambanova}/credentials.db
models: models:
- metadata: {} - metadata: {}
model_id: sambanova/Meta-Llama-3.1-8B-Instruct model_id: sambanova/Meta-Llama-3.1-8B-Instruct

View file

@ -133,6 +133,9 @@ providers:
metadata_store: metadata_store:
type: sqlite type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/starter}/registry.db db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/starter}/registry.db
credentials_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/starter}/credentials.db
models: models:
- metadata: {} - metadata: {}
model_id: openai/gpt-4o model_id: openai/gpt-4o

View file

@ -117,6 +117,10 @@ class RunConfigSettings(BaseModel):
__distro_dir__=f"~/.llama/distributions/{name}", __distro_dir__=f"~/.llama/distributions/{name}",
db_name="registry.db", db_name="registry.db",
), ),
credentials_store=SqliteKVStoreConfig.sample_run_config(
__distro_dir__=f"~/.llama/distributions/{name}",
db_name="credentials.db",
),
models=self.default_models or [], models=self.default_models or [],
shields=self.default_shields or [], shields=self.default_shields or [],
tool_groups=self.default_tool_groups or [], tool_groups=self.default_tool_groups or [],

View file

@ -102,6 +102,9 @@ providers:
metadata_store: metadata_store:
type: sqlite type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/tgi}/registry.db db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/tgi}/registry.db
credentials_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/tgi}/credentials.db
models: models:
- metadata: {} - metadata: {}
model_id: ${env.INFERENCE_MODEL} model_id: ${env.INFERENCE_MODEL}

View file

@ -101,6 +101,9 @@ providers:
metadata_store: metadata_store:
type: sqlite type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/tgi}/registry.db db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/tgi}/registry.db
credentials_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/tgi}/credentials.db
models: models:
- metadata: {} - metadata: {}
model_id: ${env.INFERENCE_MODEL} model_id: ${env.INFERENCE_MODEL}

View file

@ -111,6 +111,9 @@ providers:
metadata_store: metadata_store:
type: sqlite type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/together}/registry.db db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/together}/registry.db
credentials_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/together}/credentials.db
models: models:
- metadata: {} - metadata: {}
model_id: meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo model_id: meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo

View file

@ -106,6 +106,9 @@ providers:
metadata_store: metadata_store:
type: sqlite type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/together}/registry.db db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/together}/registry.db
credentials_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/together}/credentials.db
models: models:
- metadata: {} - metadata: {}
model_id: meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo model_id: meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo

View file

@ -135,6 +135,9 @@ providers:
metadata_store: metadata_store:
type: sqlite type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/verification}/registry.db db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/verification}/registry.db
credentials_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/verification}/credentials.db
models: models:
- metadata: {} - metadata: {}
model_id: openai/gpt-4o model_id: openai/gpt-4o

View file

@ -106,6 +106,9 @@ providers:
metadata_store: metadata_store:
type: sqlite type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/vllm-gpu}/registry.db db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/vllm-gpu}/registry.db
credentials_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/vllm-gpu}/credentials.db
models: models:
- metadata: {} - metadata: {}
model_id: ${env.INFERENCE_MODEL} model_id: ${env.INFERENCE_MODEL}

View file

@ -103,6 +103,9 @@ providers:
metadata_store: metadata_store:
type: sqlite type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/watsonx}/registry.db db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/watsonx}/registry.db
credentials_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/watsonx}/credentials.db
models: models:
- metadata: {} - metadata: {}
model_id: meta-llama/llama-3-3-70b-instruct model_id: meta-llama/llama-3-3-70b-instruct

View file

@ -0,0 +1,178 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import socket
import threading
import time
import httpx
import mcp.types as types
import pytest
import uvicorn
from llama_stack_client import Agent
from mcp.server.fastmcp import Context, FastMCP
from mcp.server.sse import SseServerTransport
from starlette.applications import Starlette
from starlette.exceptions import HTTPException
from starlette.responses import Response
from starlette.routing import Mount, Route
AUTH_TOKEN = "test-token"
@pytest.fixture(scope="module")
def mcp_server():
server = FastMCP("FastMCP Test Server")
@server.tool()
async def greet_everyone(
url: str, ctx: Context
) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
return [types.TextContent(type="text", text="Hello, world!")]
sse = SseServerTransport("/messages/")
async def handle_sse(request):
auth_header = request.headers.get("Authorization")
auth_token = None
if auth_header and auth_header.startswith("Bearer "):
auth_token = auth_header.split(" ")[1]
if auth_token != AUTH_TOKEN:
raise HTTPException(status_code=401, detail="Unauthorized")
async with sse.connect_sse(request.scope, request.receive, request._send) as streams:
await server._mcp_server.run(
streams[0],
streams[1],
server._mcp_server.create_initialization_options(),
)
return Response()
app = Starlette(
routes=[
Route("/sse", endpoint=handle_sse),
Mount("/messages/", app=sse.handle_post_message),
],
)
def get_open_port():
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.bind(("", 0))
return sock.getsockname()[1]
port = get_open_port()
config = uvicorn.Config(app, host="0.0.0.0", port=port)
server_instance = uvicorn.Server(config)
app.state.uvicorn_server = server_instance
def run_server():
server_instance.run()
# Start the server in a new thread
server_thread = threading.Thread(target=run_server, daemon=True)
server_thread.start()
# Polling until the server is ready
timeout = 10
start_time = time.time()
while time.time() - start_time < timeout:
try:
response = httpx.get(f"http://localhost:{port}/sse")
if response.status_code == 401:
break
except httpx.RequestError:
pass
time.sleep(0.1)
yield port
# Tell server to exit
server_instance.should_exit = True
server_thread.join(timeout=5)
def test_mcp_invocation(llama_stack_client, mcp_server):
port = mcp_server
test_toolgroup_id = "remote::mcptest"
# registering itself should fail since it requires listing tools
with pytest.raises(Exception, match="Unauthorized"):
llama_stack_client.toolgroups.register(
toolgroup_id=test_toolgroup_id,
provider_id="model-context-protocol",
mcp_endpoint=dict(uri=f"http://localhost:{port}/sse"),
)
llama_stack_client.credentials.create(
provider_id="model-context-protocol",
token_type="access_token",
token=AUTH_TOKEN,
ttl_seconds=100,
)
llama_stack_client.toolgroups.register(
toolgroup_id=test_toolgroup_id,
provider_id="model-context-protocol",
mcp_endpoint=dict(uri=f"http://localhost:{port}/sse"),
)
response = llama_stack_client.tools.list(toolgroup_id=test_toolgroup_id)
assert len(response) == 1
assert response[0].identifier == "greet_everyone"
assert response[0].type == "tool"
assert len(response[0].parameters) == 1
p = response[0].parameters[0]
assert p.name == "url"
assert p.parameter_type == "string"
assert p.required
response = llama_stack_client.tool_runtime.invoke_tool(
tool_name=response[0].identifier,
kwargs=dict(url="https://www.google.com"),
)
content = response.content
assert len(content) == 1
assert content[0].type == "text"
assert content[0].text == "Hello, world!"
models = llama_stack_client.models.list()
model_id = models[0].identifier
print(f"Using model: {model_id}")
agent = Agent(
client=llama_stack_client,
model=model_id,
instructions="You are a helpful assistant.",
tools=[test_toolgroup_id],
)
session_id = agent.create_session("test-session")
response = agent.create_turn(
session_id=session_id,
messages=[
{
"role": "user",
"content": "Yo. Use tools.",
}
],
stream=False,
)
steps = response.steps
first = steps[0]
assert first.step_type == "inference"
assert len(first.api_model_response.tool_calls) == 1
tool_call = first.api_model_response.tool_calls[0]
assert tool_call.tool_name == "greet_everyone"
second = steps[1]
assert second.step_type == "tool_execution"
tool_response_content = second.tool_responses[0].content
assert len(tool_response_content) == 1
assert tool_response_content[0].type == "text"
assert tool_response_content[0].text == "Hello, world!"
third = steps[2]
assert third.step_type == "inference"
assert len(third.api_model_response.tool_calls) == 0

View file

@ -1,12 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
def test_toolsgroups_unregister(llama_stack_client):
client = llama_stack_client
client.toolgroups.unregister(
toolgroup_id="builtin::websearch",
)