Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-10-04 12:07:34 +00:00)

Commit c8623607f5: Merge branch 'main' into use-openai-for-databricks
31 changed files with 665 additions and 691 deletions

@@ -17,11 +17,8 @@ export POSTGRES_PASSWORD=llamastack
 export INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
 export SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
 
-export MOCK_INFERENCE_MODEL=mock-inference
-
-export MOCK_INFERENCE_URL=openai-mock-service:8080
-
 export BENCHMARK_INFERENCE_MODEL=$INFERENCE_MODEL
+export LLAMA_STACK_WORKERS=4
 
 set -euo pipefail
 set -x

@@ -5,6 +5,7 @@ data:
 image_name: kubernetes-benchmark-demo
 apis:
 - agents
+- files
 - inference
 - files
 - safety
@@ -23,6 +24,14 @@ data:
   - provider_id: sentence-transformers
     provider_type: inline::sentence-transformers
     config: {}
+  files:
+  - provider_id: meta-reference-files
+    provider_type: inline::localfs
+    config:
+      storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter/files}
+      metadata_store:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db
   vector_io:
   - provider_id: ${env.ENABLE_CHROMADB:+chromadb}
     provider_type: remote::chromadb

@@ -52,9 +52,20 @@ spec:
           value: http://vllm-server-safety.default.svc.cluster.local:8001/v1
         - name: VLLM_TLS_VERIFY
           value: "false"
-        command: ["python", "-m", "llama_stack.core.server.server", "/etc/config/stack_run_config.yaml", "--port", "8323"]
+        - name: LLAMA_STACK_LOGGING
+          value: "all=WARNING"
+        - name: LLAMA_STACK_CONFIG
+          value: "/etc/config/stack_run_config.yaml"
+        - name: LLAMA_STACK_WORKERS
+          value: "${LLAMA_STACK_WORKERS}"
+        command: ["uvicorn", "llama_stack.core.server.server:create_app", "--host", "0.0.0.0", "--port", "8323", "--workers", "$LLAMA_STACK_WORKERS", "--factory"]
         ports:
         - containerPort: 8323
+        resources:
+          requests:
+            cpu: "${LLAMA_STACK_WORKERS}"
+          limits:
+            cpu: "${LLAMA_STACK_WORKERS}"
         volumeMounts:
         - name: llama-storage
          mountPath: /root/.llama

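The new container command starts the server through uvicorn's application factory with a configurable worker count instead of the single-process `python -m` entry point. A minimal sketch of the programmatic equivalent, assuming `llama_stack` is installed and `LLAMA_STACK_CONFIG` points at a valid run config (the path and worker count below are illustrative, mirroring the manifest):

```python
import os

import uvicorn

# Assumption: create_app() reads LLAMA_STACK_CONFIG when no explicit path is given.
os.environ.setdefault("LLAMA_STACK_CONFIG", "/etc/config/stack_run_config.yaml")
workers = int(os.environ.get("LLAMA_STACK_WORKERS", "4"))

if __name__ == "__main__":
    # factory=True makes uvicorn call create_app() once per worker process,
    # exactly what the --factory flag in the container command does.
    uvicorn.run(
        "llama_stack.core.server.server:create_app",
        host="0.0.0.0",
        port=8323,
        workers=workers,
        factory=True,
    )
```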
@@ -460,10 +460,12 @@ client = LlamaStackClient(base_url="http://localhost:8321")
 embed_lm = next(m for m in client.models.list() if m.model_type == "embedding")
 embedding_model = embed_lm.identifier
 vector_db_id = f"v{uuid.uuid4().hex}"
-client.vector_dbs.register(
+# The VectorDB API is deprecated; the server now returns its own authoritative ID.
+# We capture the correct ID from the response's .identifier attribute.
+vector_db_id = client.vector_dbs.register(
     vector_db_id=vector_db_id,
     embedding_model=embedding_model,
-)
+).identifier
 
 # Create Documents
 urls = [

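For context, the surrounding tutorial goes on to ingest documents into the vector store that was just registered. A hedged sketch of how the server-assigned ID is then used, assuming the RAG tool runtime is configured in the running stack (the document URL and chunk size are illustrative, and exact client types can vary across client versions):

```python
from llama_stack_client import RAGDocument

# Illustrative follow-up: index one document, then query it back.
document = RAGDocument(
    document_id="doc-0",
    content="https://www.example.com/example-doc.txt",  # illustrative URL
    mime_type="text/plain",
    metadata={},
)
client.tool_runtime.rag_tool.insert(
    documents=[document],
    vector_db_id=vector_db_id,  # the .identifier captured above
    chunk_size_in_tokens=512,
)
result = client.tool_runtime.rag_tool.query(
    vector_db_ids=[vector_db_id],
    content="What does the example document say?",
)
print(result)
```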
@@ -23,7 +23,13 @@ To use Milvus in your Llama Stack project, follow these steps:
 
 ## Installation
 
-You can install Milvus using pymilvus:
+If you want to use inline Milvus, you can install:
+
+```bash
+pip install pymilvus[milvus-lite]
+```
+
+If you want to use remote Milvus, you can install:
 
 ```bash
 pip install pymilvus

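To make the inline/remote distinction concrete, here is a small sketch using pymilvus directly (the local file name and server URI are illustrative; in Llama Stack this choice is normally expressed through the vector_io provider config rather than direct client calls):

```python
from pymilvus import MilvusClient

# Inline (Milvus Lite): data lives in a local file, no separate server required.
inline_client = MilvusClient("./milvus_demo.db")  # illustrative path

# Remote: connect to a standalone Milvus deployment instead.
remote_client = MilvusClient(uri="http://localhost:19530")  # illustrative URI
```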
@@ -121,10 +121,6 @@ class AutoRoutedProviderSpec(ProviderSpec):
         default=None,
     )
 
-    @property
-    def pip_packages(self) -> list[str]:
-        raise AssertionError("Should not be called on AutoRoutedProviderSpec")
-
 
 # Example: /models, /shields
 class RoutingTableProviderSpec(ProviderSpec):

@@ -16,11 +16,10 @@ from llama_stack.core.datatypes import BuildConfig, DistributionSpec
 from llama_stack.core.external import load_external_apis
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import (
-    AdapterSpec,
     Api,
     InlineProviderSpec,
     ProviderSpec,
-    remote_provider_spec,
+    RemoteProviderSpec,
 )
 
 logger = get_logger(name=__name__, category="core")
@@ -77,27 +76,12 @@ def providable_apis() -> list[Api]:
 
 
 def _load_remote_provider_spec(spec_data: dict[str, Any], api: Api) -> ProviderSpec:
-    adapter = AdapterSpec(**spec_data["adapter"])
-    spec = remote_provider_spec(
-        api=api,
-        adapter=adapter,
-        api_dependencies=[Api(dep) for dep in spec_data.get("api_dependencies", [])],
-    )
+    spec = RemoteProviderSpec(api=api, provider_type=f"remote::{spec_data['adapter_type']}", **spec_data)
     return spec
 
 
 def _load_inline_provider_spec(spec_data: dict[str, Any], api: Api, provider_name: str) -> ProviderSpec:
-    spec = InlineProviderSpec(
-        api=api,
-        provider_type=f"inline::{provider_name}",
-        pip_packages=spec_data.get("pip_packages", []),
-        module=spec_data["module"],
-        config_class=spec_data["config_class"],
-        api_dependencies=[Api(dep) for dep in spec_data.get("api_dependencies", [])],
-        optional_api_dependencies=[Api(dep) for dep in spec_data.get("optional_api_dependencies", [])],
-        provider_data_validator=spec_data.get("provider_data_validator"),
-        container_image=spec_data.get("container_image"),
-    )
+    spec = InlineProviderSpec(api=api, provider_type=f"inline::{provider_name}", **spec_data)
     return spec

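The loaders now hand the external spec dict straight to the pydantic models. A hedged sketch of what an external remote provider entry could look like and how this loader turns it into a `RemoteProviderSpec` (the adapter name, module, config class, and pip package are made up for illustration and assume they satisfy the model's required fields):

```python
from llama_stack.providers.datatypes import Api, RemoteProviderSpec

# Hypothetical external provider definition, e.g. parsed from a YAML file.
spec_data = {
    "adapter_type": "my_vector_store",
    "module": "my_company.providers.my_vector_store",
    "config_class": "my_company.providers.my_vector_store.MyVectorStoreConfig",
    "pip_packages": ["my-vector-store-client"],
}

# Mirrors _load_remote_provider_spec: adapter_type doubles as the provider_type suffix.
spec = RemoteProviderSpec(
    api=Api.vector_io,
    provider_type=f"remote::{spec_data['adapter_type']}",
    **spec_data,
)
print(spec.provider_type)  # remote::my_vector_store
```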
@@ -40,7 +40,7 @@ from llama_stack.core.request_headers import (
 from llama_stack.core.resolver import ProviderRegistry
 from llama_stack.core.server.routes import RouteImpls, find_matching_route, initialize_route_impls
 from llama_stack.core.stack import (
-    construct_stack,
+    Stack,
     get_stack_run_config_from_distro,
     replace_env_vars,
 )
@@ -252,7 +252,10 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
 
         try:
             self.route_impls = None
-            self.impls = await construct_stack(self.config, self.custom_provider_registry)
+            stack = Stack(self.config, self.custom_provider_registry)
+            await stack.initialize()
+            self.impls = stack.impls
         except ModuleNotFoundError as _e:
             cprint(_e.msg, color="red", file=sys.stderr)
             cprint(
@@ -289,6 +292,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
             )
             raise _e
 
+        assert self.impls is not None
         if Api.telemetry in self.impls:
             setup_logger(self.impls[Api.telemetry])
 

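With this change the library client builds its implementations through the new `Stack` object rather than the old `construct_stack` helper, but its public surface is unchanged. A minimal usage sketch, assuming the `starter` distribution is installed and configured (the distro name is illustrative):

```python
import asyncio

from llama_stack.core.library_client import AsyncLlamaStackAsLibraryClient


async def main() -> None:
    # initialize() internally creates Stack(config), awaits stack.initialize(),
    # and keeps stack.impls for in-process request routing.
    client = AsyncLlamaStackAsLibraryClient("starter")  # illustrative distro name
    await client.initialize()
    models = await client.models.list()
    print([m.identifier for m in models])


asyncio.run(main())
```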
@@ -6,6 +6,7 @@
 
 import argparse
 import asyncio
+import concurrent.futures
 import functools
 import inspect
 import json
@@ -50,17 +51,15 @@ from llama_stack.core.request_headers import (
     request_provider_data_context,
     user_from_scope,
 )
-from llama_stack.core.resolver import InvalidProviderError
 from llama_stack.core.server.routes import (
     find_matching_route,
     get_all_api_routes,
     initialize_route_impls,
 )
 from llama_stack.core.stack import (
+    Stack,
     cast_image_name_to_string,
-    construct_stack,
     replace_env_vars,
-    shutdown_stack,
     validate_env_pair,
 )
 from llama_stack.core.utils.config import redact_sensitive_fields
@@ -156,21 +155,34 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationError
     )
 
 
-async def shutdown(app):
-    """Initiate a graceful shutdown of the application.
-
-    Handled by the lifespan context manager. The shutdown process involves
-    shutting down all implementations registered in the application.
+class StackApp(FastAPI):
     """
-    await shutdown_stack(app.__llama_stack_impls__)
+    A wrapper around the FastAPI application to hold a reference to the Stack instance so that we can
+    start background tasks (e.g. refresh model registry periodically) from the lifespan context manager.
+    """
+
+    def __init__(self, config: StackRunConfig, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.stack: Stack = Stack(config)
+
+        # This code is called from a running event loop managed by uvicorn so we cannot simply call
+        # asyncio.run() to initialize the stack. We cannot await either since this is not an async
+        # function.
+        # As a workaround, we use a thread pool executor to run the initialize() method
+        # in a separate thread.
+        with concurrent.futures.ThreadPoolExecutor() as executor:
+            future = executor.submit(asyncio.run, self.stack.initialize())
+            future.result()
 
 
 @asynccontextmanager
-async def lifespan(app: FastAPI):
+async def lifespan(app: StackApp):
     logger.info("Starting up")
+    assert app.stack is not None
+    app.stack.create_registry_refresh_task()
     yield
     logger.info("Shutting down")
-    await shutdown(app)
+    await app.stack.shutdown()
 
 
 def is_streaming_request(func_name: str, request: Request, **kwargs):
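The `StackApp` constructor has to run the async `Stack.initialize()` while uvicorn's event loop is already running, so it offloads `asyncio.run` to a worker thread. A stripped-down sketch of that pattern in isolation (the `setup()` coroutine is a stand-in for `Stack.initialize()`):

```python
import asyncio
import concurrent.futures


async def setup() -> dict:
    # Stand-in for Stack.initialize(): any one-off async initialization.
    await asyncio.sleep(0.1)
    return {"ready": True}


def init_from_sync_context() -> dict:
    # Calling asyncio.run() directly would fail if the current thread already
    # has a running event loop (as it does inside uvicorn), so run it on a
    # fresh thread, which gets its own event loop.
    with concurrent.futures.ThreadPoolExecutor() as executor:
        future = executor.submit(asyncio.run, setup())
        return future.result()


print(init_from_sync_context())
```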
@@ -386,73 +398,61 @@ class ClientVersionMiddleware:
         return await self.app(scope, receive, send)
 
 
-def main(args: argparse.Namespace | None = None):
-    """Start the LlamaStack server."""
-    parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
+def create_app(
+    config_file: str | None = None,
+    env_vars: list[str] | None = None,
+) -> StackApp:
+    """Create and configure the FastAPI application.
 
-    add_config_distro_args(parser)
-    parser.add_argument(
-        "--port",
-        type=int,
-        default=int(os.getenv("LLAMA_STACK_PORT", 8321)),
-        help="Port to listen on",
-    )
-    parser.add_argument(
-        "--env",
-        action="append",
-        help="Environment variables in KEY=value format. Can be specified multiple times.",
-    )
+    Args:
+        config_file: Path to config file. If None, uses LLAMA_STACK_CONFIG env var or default resolution.
+        env_vars: List of environment variables in KEY=value format.
+        disable_version_check: Whether to disable version checking. If None, uses LLAMA_STACK_DISABLE_VERSION_CHECK env var.
 
-    # Determine whether the server args are being passed by the "run" command, if this is the case
-    # the args will be passed as a Namespace object to the main function, otherwise they will be
-    # parsed from the command line
-    if args is None:
-        args = parser.parse_args()
+    Returns:
+        Configured StackApp instance.
+    """
+    config_file = config_file or os.getenv("LLAMA_STACK_CONFIG")
+    if config_file is None:
+        raise ValueError("No config file provided and LLAMA_STACK_CONFIG env var is not set")
 
-    config_or_distro = get_config_from_args(args)
-    config_file = resolve_config_or_distro(config_or_distro, Mode.RUN)
+    config_file = resolve_config_or_distro(config_file, Mode.RUN)
 
+    # Load and process configuration
     logger_config = None
     with open(config_file) as fp:
         config_contents = yaml.safe_load(fp)
         if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
             logger_config = LoggingConfig(**cfg)
         logger = get_logger(name=__name__, category="core::server", config=logger_config)
-        if args.env:
-            for env_pair in args.env:
+        if env_vars:
+            for env_pair in env_vars:
                 try:
                     key, value = validate_env_pair(env_pair)
-                    logger.info(f"Setting CLI environment variable {key} => {value}")
+                    logger.info(f"Setting environment variable {key} => {value}")
                     os.environ[key] = value
                 except ValueError as e:
                     logger.error(f"Error: {str(e)}")
-                    sys.exit(1)
+                    raise ValueError(f"Invalid environment variable format: {env_pair}") from e
 
         config = replace_env_vars(config_contents)
         config = StackRunConfig(**cast_image_name_to_string(config))
 
     _log_run_config(run_config=config)
 
-    app = FastAPI(
+    app = StackApp(
         lifespan=lifespan,
         docs_url="/docs",
         redoc_url="/redoc",
         openapi_url="/openapi.json",
+        config=config,
     )
 
     if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
         app.add_middleware(ClientVersionMiddleware)
 
-    try:
-        # Create and set the event loop that will be used for both construction and server runtime
-        loop = asyncio.new_event_loop()
-        asyncio.set_event_loop(loop)
-
-        # Construct the stack in the persistent event loop
-        impls = loop.run_until_complete(construct_stack(config))
-
-    except InvalidProviderError as e:
-        logger.error(f"Error: {str(e)}")
-        sys.exit(1)
+    impls = app.stack.impls
 
     if config.server.auth:
         logger.info(f"Enabling authentication with provider: {config.server.auth.provider_config.type.value}")
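Because `create_app()` now returns a fully configured application object, the stack can be exercised without a uvicorn process at all. A hedged sketch using FastAPI's test client, assuming `./run.yaml` is a valid stack run config on disk (the path and the probed route are illustrative):

```python
from fastapi.testclient import TestClient

from llama_stack.core.server.server import create_app

# Passing the path explicitly avoids relying on the LLAMA_STACK_CONFIG env var.
app = create_app(config_file="./run.yaml")  # illustrative path

with TestClient(app) as client:
    # Any route exposed by the stack can be hit in-process, e.g. a version/health endpoint.
    response = client.get("/v1/version")
    print(response.status_code, response.json())
```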
@@ -553,9 +553,54 @@ def main(args: argparse.Namespace | None = None):
     app.exception_handler(RequestValidationError)(global_exception_handler)
     app.exception_handler(Exception)(global_exception_handler)
 
-    app.__llama_stack_impls__ = impls
     app.add_middleware(TracingMiddleware, impls=impls, external_apis=external_apis)
 
+    return app
+
+
+def main(args: argparse.Namespace | None = None):
+    """Start the LlamaStack server."""
+    parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
+
+    add_config_distro_args(parser)
+    parser.add_argument(
+        "--port",
+        type=int,
+        default=int(os.getenv("LLAMA_STACK_PORT", 8321)),
+        help="Port to listen on",
+    )
+    parser.add_argument(
+        "--env",
+        action="append",
+        help="Environment variables in KEY=value format. Can be specified multiple times.",
+    )
+
+    # Determine whether the server args are being passed by the "run" command, if this is the case
+    # the args will be passed as a Namespace object to the main function, otherwise they will be
+    # parsed from the command line
+    if args is None:
+        args = parser.parse_args()
+
+    config_or_distro = get_config_from_args(args)
+
+    try:
+        app = create_app(
+            config_file=config_or_distro,
+            env_vars=args.env,
+        )
+    except Exception as e:
+        logger.error(f"Error creating app: {str(e)}")
+        sys.exit(1)
+
+    config_file = resolve_config_or_distro(config_or_distro, Mode.RUN)
+    with open(config_file) as fp:
+        config_contents = yaml.safe_load(fp)
+        if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
+            logger_config = LoggingConfig(**cfg)
+        else:
+            logger_config = None
+        config = StackRunConfig(**cast_image_name_to_string(replace_env_vars(config_contents)))
+
     import uvicorn
 
     # Configure SSL if certificates are provided
@@ -593,7 +638,6 @@ def main(args: argparse.Namespace | None = None):
     if ssl_config:
         uvicorn_config.update(ssl_config)
 
-    # Run uvicorn in the existing event loop to preserve background tasks
     # We need to catch KeyboardInterrupt because uvicorn's signal handling
     # re-raises SIGINT signals using signal.raise_signal(), which Python
     # converts to KeyboardInterrupt. Without this catch, we'd get a confusing
@@ -604,13 +648,9 @@ def main(args: argparse.Namespace | None = None):
     # Another approach would be to ignore SIGINT entirely - let uvicorn handle it through its own
     # signal handling but this is quite intrusive and not worth the effort.
     try:
-        loop.run_until_complete(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve())
+        asyncio.run(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve())
     except (KeyboardInterrupt, SystemExit):
         logger.info("Received interrupt signal, shutting down gracefully...")
-    finally:
-        if not loop.is_closed():
-            logger.debug("Closing event loop")
-            loop.close()
 
 
 def _log_run_config(run_config: StackRunConfig):

@@ -315,11 +315,15 @@ def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConfig
     impls[Api.prompts] = prompts_impl
 
 
+class Stack:
+    def __init__(self, run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None):
+        self.run_config = run_config
+        self.provider_registry = provider_registry
+        self.impls = None
+
     # Produces a stack of providers for the given run config. Not all APIs may be
     # asked for in the run config.
-async def construct_stack(
-    run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None
-) -> dict[Api, Any]:
+    async def initialize(self):
         if "LLAMA_STACK_TEST_INFERENCE_MODE" in os.environ:
             from llama_stack.testing.inference_recorder import setup_inference_recording
 
@@ -329,24 +333,28 @@ async def construct_stack(
             TEST_RECORDING_CONTEXT.__enter__()
             logger.info(f"Inference recording enabled: mode={os.environ.get('LLAMA_STACK_TEST_INFERENCE_MODE')}")
 
-        dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name)
-        policy = run_config.server.auth.access_policy if run_config.server.auth else []
+        dist_registry, _ = await create_dist_registry(self.run_config.metadata_store, self.run_config.image_name)
+        policy = self.run_config.server.auth.access_policy if self.run_config.server.auth else []
         impls = await resolve_impls(
-            run_config, provider_registry or get_provider_registry(run_config), dist_registry, policy
+            self.run_config, self.provider_registry or get_provider_registry(self.run_config), dist_registry, policy
         )
 
         # Add internal implementations after all other providers are resolved
-        add_internal_implementations(impls, run_config)
+        add_internal_implementations(impls, self.run_config)
 
         if Api.prompts in impls:
             await impls[Api.prompts].initialize()
 
-        await register_resources(run_config, impls)
+        await register_resources(self.run_config, impls)
 
         await refresh_registry_once(impls)
+        self.impls = impls
+
+    def create_registry_refresh_task(self):
+        assert self.impls is not None, "Must call initialize() before starting"
 
         global REGISTRY_REFRESH_TASK
-        REGISTRY_REFRESH_TASK = asyncio.create_task(refresh_registry_task(impls))
+        REGISTRY_REFRESH_TASK = asyncio.create_task(refresh_registry_task(self.impls))
 
         def cb(task):
             import traceback
@@ -360,11 +368,9 @@ async def construct_stack(
             logger.debug("Model refresh task completed")
 
         REGISTRY_REFRESH_TASK.add_done_callback(cb)
-    return impls
 
-
-async def shutdown_stack(impls: dict[Api, Any]):
-    for impl in impls.values():
+    async def shutdown(self):
+        for impl in self.impls.values():
             impl_name = impl.__class__.__name__
             logger.info(f"Shutting down {impl_name}")
             try:

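Taken together, `construct_stack`/`shutdown_stack` are replaced by a `Stack` object that owns its run config and resolved implementations. A minimal sketch of the new lifecycle, assuming `run_config` is an already-parsed `StackRunConfig` (loading and validating the YAML is omitted here):

```python
import asyncio

from llama_stack.core.datatypes import StackRunConfig
from llama_stack.core.stack import Stack


async def run_once(run_config: StackRunConfig) -> None:
    stack = Stack(run_config)
    await stack.initialize()                  # resolves providers, fills stack.impls
    try:
        stack.create_registry_refresh_task()  # optional background model refresh
        print(sorted(api.value for api in stack.impls))
    finally:
        await stack.shutdown()                # shuts every implementation down


# asyncio.run(run_once(my_run_config))  # my_run_config: a StackRunConfig built elsewhere
```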
@@ -78,12 +78,12 @@ def get_remote_inference_providers() -> list[Provider]:
     remote_providers = [
         provider
         for provider in available_providers()
-        if isinstance(provider, RemoteProviderSpec) and provider.adapter.adapter_type in ENABLED_INFERENCE_PROVIDERS
+        if isinstance(provider, RemoteProviderSpec) and provider.adapter_type in ENABLED_INFERENCE_PROVIDERS
     ]
 
     inference_providers = []
     for provider_spec in remote_providers:
-        provider_type = provider_spec.adapter.adapter_type
+        provider_type = provider_spec.adapter_type
 
         if provider_type in INFERENCE_PROVIDER_IDS:
             provider_id = INFERENCE_PROVIDER_IDS[provider_type]

@@ -131,6 +131,15 @@ class ProviderSpec(BaseModel):
     """,
     )
 
+    pip_packages: list[str] = Field(
+        default_factory=list,
+        description="The pip dependencies needed for this implementation",
+    )
+
+    provider_data_validator: str | None = Field(
+        default=None,
+    )
+
     is_external: bool = Field(default=False, description="Notes whether this provider is an external provider.")
 
     # used internally by the resolver; this is a hack for now
@@ -145,45 +154,8 @@ class RoutingTable(Protocol):
     async def get_provider_impl(self, routing_key: str) -> Any: ...
 
 
-# TODO: this can now be inlined into RemoteProviderSpec
-@json_schema_type
-class AdapterSpec(BaseModel):
-    adapter_type: str = Field(
-        ...,
-        description="Unique identifier for this adapter",
-    )
-    module: str = Field(
-        default_factory=str,
-        description="""
-Fully-qualified name of the module to import. The module is expected to have:
-
- - `get_adapter_impl(config, deps)`: returns the adapter implementation
-""",
-    )
-    pip_packages: list[str] = Field(
-        default_factory=list,
-        description="The pip dependencies needed for this implementation",
-    )
-    config_class: str = Field(
-        description="Fully-qualified classname of the config for this provider",
-    )
-    provider_data_validator: str | None = Field(
-        default=None,
-    )
-    description: str | None = Field(
-        default=None,
-        description="""
-A description of the provider. This is used to display in the documentation.
-""",
-    )
-
-
 @json_schema_type
 class InlineProviderSpec(ProviderSpec):
-    pip_packages: list[str] = Field(
-        default_factory=list,
-        description="The pip dependencies needed for this implementation",
-    )
     container_image: str | None = Field(
         default=None,
         description="""
@@ -191,10 +163,6 @@ The container image to use for this implementation. If one is provided, pip_packages
 If a provider depends on other providers, the dependencies MUST NOT specify a container image.
 """,
     )
-    # module field is inherited from ProviderSpec
-    provider_data_validator: str | None = Field(
-        default=None,
-    )
     description: str | None = Field(
         default=None,
         description="""
@@ -223,10 +191,15 @@ class RemoteProviderConfig(BaseModel):
 
 @json_schema_type
 class RemoteProviderSpec(ProviderSpec):
-    adapter: AdapterSpec = Field(
+    adapter_type: str = Field(
+        ...,
+        description="Unique identifier for this adapter",
+    )
+
+    description: str | None = Field(
+        default=None,
         description="""
-If some code is needed to convert the remote responses into Llama Stack compatible
-API responses, specify the adapter here.
+A description of the provider. This is used to display in the documentation.
 """,
     )
 
@@ -234,33 +207,6 @@ API responses, specify the adapter here.
     def container_image(self) -> str | None:
         return None
 
-    # module field is inherited from ProviderSpec
-
-    @property
-    def pip_packages(self) -> list[str]:
-        return self.adapter.pip_packages
-
-    @property
-    def provider_data_validator(self) -> str | None:
-        return self.adapter.provider_data_validator
-
-
-def remote_provider_spec(
-    api: Api,
-    adapter: AdapterSpec,
-    api_dependencies: list[Api] | None = None,
-    optional_api_dependencies: list[Api] | None = None,
-) -> RemoteProviderSpec:
-    return RemoteProviderSpec(
-        api=api,
-        provider_type=f"remote::{adapter.adapter_type}",
-        config_class=adapter.config_class,
-        module=adapter.module,
-        adapter=adapter,
-        api_dependencies=api_dependencies or [],
-        optional_api_dependencies=optional_api_dependencies or [],
-    )
-
 
 class HealthStatus(StrEnum):
     OK = "OK"

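The net effect of these datatype changes is that a remote provider is now declared with a single `RemoteProviderSpec` instead of `remote_provider_spec(...)` wrapping an `AdapterSpec`. A hedged sketch of the new-style declaration (the adapter name, module, config class, and pip package are illustrative, not a real provider):

```python
from llama_stack.providers.datatypes import Api, RemoteProviderSpec

example_spec = RemoteProviderSpec(
    api=Api.inference,
    adapter_type="acme",            # illustrative adapter name
    provider_type="remote::acme",   # by convention, "remote::" + adapter_type
    pip_packages=["acme-client"],   # pip_packages is now a plain field on ProviderSpec
    module="llama_stack.providers.remote.inference.acme",  # illustrative module path
    config_class="llama_stack.providers.remote.inference.acme.AcmeImplConfig",  # illustrative
    description="Example inference provider used only for illustration.",
)
```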
@@ -44,7 +44,7 @@ class LocalfsFilesImpl(Files):
         storage_path.mkdir(parents=True, exist_ok=True)
 
         # Initialize SQL store for metadata
-        self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.config.metadata_store))
+        self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.config.metadata_store), self.policy)
         await self.sql_store.create_table(
             "openai_files",
             {
@@ -74,7 +74,7 @@ class LocalfsFilesImpl(Files):
         if not self.sql_store:
             raise RuntimeError("Files provider not initialized")
 
-        row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id})
+        row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
         if not row:
             raise ResourceNotFoundError(file_id, "File", "client.files.list()")
 
@@ -150,7 +150,6 @@ class LocalfsFilesImpl(Files):
 
         paginated_result = await self.sql_store.fetch_all(
             table="openai_files",
-            policy=self.policy,
             where=where_conditions if where_conditions else None,
             order_by=[("created_at", order.value)],
             cursor=("id", after) if after else None,

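The access policy now travels with the store instead of with every query. A minimal sketch of the new shape, assuming an SQLite-backed metadata store config and an access policy list already exist (the argument names below are illustrative):

```python
# Sketch only: the policy is bound once when the AuthorizedSqlStore is built,
# so individual queries no longer repeat it.
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl


async def lookup_file(metadata_store_config, policy, file_id: str):
    store = AuthorizedSqlStore(sqlstore_impl(metadata_store_config), policy)  # policy bound here
    # create_table(...) would normally happen during provider initialization.
    return await store.fetch_one("openai_files", where={"id": file_id})  # no policy kwarg
```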
@@ -6,11 +6,10 @@
 
 
 from llama_stack.providers.datatypes import (
-    AdapterSpec,
    Api,
     InlineProviderSpec,
     ProviderSpec,
-    remote_provider_spec,
+    RemoteProviderSpec,
 )
 
 
@@ -25,10 +24,10 @@ def available_providers() -> list[ProviderSpec]:
             api_dependencies=[],
             description="Local filesystem-based dataset I/O provider for reading and writing datasets to local storage.",
         ),
-        remote_provider_spec(
+        RemoteProviderSpec(
             api=Api.datasetio,
-            adapter=AdapterSpec(
             adapter_type="huggingface",
+            provider_type="remote::huggingface",
             pip_packages=[
                 "datasets>=4.0.0",
             ],
@@ -36,17 +35,15 @@ def available_providers() -> list[ProviderSpec]:
             config_class="llama_stack.providers.remote.datasetio.huggingface.HuggingfaceDatasetIOConfig",
             description="HuggingFace datasets provider for accessing and managing datasets from the HuggingFace Hub.",
         ),
-        ),
-        remote_provider_spec(
+        RemoteProviderSpec(
             api=Api.datasetio,
-            adapter=AdapterSpec(
             adapter_type="nvidia",
+            provider_type="remote::nvidia",
+            module="llama_stack.providers.remote.datasetio.nvidia",
+            config_class="llama_stack.providers.remote.datasetio.nvidia.NvidiaDatasetIOConfig",
             pip_packages=[
                 "datasets>=4.0.0",
             ],
-            module="llama_stack.providers.remote.datasetio.nvidia",
-            config_class="llama_stack.providers.remote.datasetio.nvidia.NvidiaDatasetIOConfig",
             description="NVIDIA's dataset I/O provider for accessing datasets from NVIDIA's data platform.",
         ),
-        ),
     ]

@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 
 
-from llama_stack.providers.datatypes import AdapterSpec, Api, InlineProviderSpec, ProviderSpec, remote_provider_spec
+from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec
 
 
 def available_providers() -> list[ProviderSpec]:
@@ -25,17 +25,16 @@ def available_providers() -> list[ProviderSpec]:
             ],
             description="Meta's reference implementation of evaluation tasks with support for multiple languages and evaluation metrics.",
         ),
-        remote_provider_spec(
+        RemoteProviderSpec(
             api=Api.eval,
-            adapter=AdapterSpec(
             adapter_type="nvidia",
             pip_packages=[
                 "requests",
             ],
+            provider_type="remote::nvidia",
             module="llama_stack.providers.remote.eval.nvidia",
             config_class="llama_stack.providers.remote.eval.nvidia.NVIDIAEvalConfig",
             description="NVIDIA's evaluation provider for running evaluation tasks on NVIDIA's platform.",
-            ),
             api_dependencies=[
                 Api.datasetio,
                 Api.datasets,

@@ -4,13 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from llama_stack.providers.datatypes import (
-    AdapterSpec,
-    Api,
-    InlineProviderSpec,
-    ProviderSpec,
-    remote_provider_spec,
-)
+from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec
 from llama_stack.providers.utils.sqlstore.sqlstore import sql_store_pip_packages
 
 
@@ -25,14 +19,13 @@ def available_providers() -> list[ProviderSpec]:
             config_class="llama_stack.providers.inline.files.localfs.config.LocalfsFilesImplConfig",
             description="Local filesystem-based file storage provider for managing files and documents locally.",
         ),
-        remote_provider_spec(
+        RemoteProviderSpec(
             api=Api.files,
-            adapter=AdapterSpec(
+            provider_type="remote::s3",
             adapter_type="s3",
             pip_packages=["boto3"] + sql_store_pip_packages,
             module="llama_stack.providers.remote.files.s3",
             config_class="llama_stack.providers.remote.files.s3.config.S3FilesImplConfig",
             description="AWS S3-based file storage provider for scalable cloud file management with metadata persistence.",
         ),
-        ),
     ]

@@ -6,11 +6,10 @@
 
 
 from llama_stack.providers.datatypes import (
-    AdapterSpec,
     Api,
     InlineProviderSpec,
     ProviderSpec,
-    remote_provider_spec,
+    RemoteProviderSpec,
 )
 
 META_REFERENCE_DEPS = [
@@ -49,10 +48,10 @@ def available_providers() -> list[ProviderSpec]:
             config_class="llama_stack.providers.inline.inference.sentence_transformers.config.SentenceTransformersInferenceConfig",
             description="Sentence Transformers inference provider for text embeddings and similarity search.",
         ),
-        remote_provider_spec(
+        RemoteProviderSpec(
             api=Api.inference,
-            adapter=AdapterSpec(
             adapter_type="cerebras",
+            provider_type="remote::cerebras",
             pip_packages=[
                 "cerebras_cloud_sdk",
             ],
@@ -60,62 +59,56 @@ def available_providers() -> list[ProviderSpec]:
             config_class="llama_stack.providers.remote.inference.cerebras.CerebrasImplConfig",
             description="Cerebras inference provider for running models on Cerebras Cloud platform.",
         ),
-        ),
-        remote_provider_spec(
+        RemoteProviderSpec(
             api=Api.inference,
-            adapter=AdapterSpec(
             adapter_type="ollama",
+            provider_type="remote::ollama",
             pip_packages=["ollama", "aiohttp", "h11>=0.16.0"],
             config_class="llama_stack.providers.remote.inference.ollama.OllamaImplConfig",
             module="llama_stack.providers.remote.inference.ollama",
             description="Ollama inference provider for running local models through the Ollama runtime.",
         ),
-        ),
-        remote_provider_spec(
+        RemoteProviderSpec(
             api=Api.inference,
-            adapter=AdapterSpec(
             adapter_type="vllm",
+            provider_type="remote::vllm",
             pip_packages=[],
             module="llama_stack.providers.remote.inference.vllm",
             config_class="llama_stack.providers.remote.inference.vllm.VLLMInferenceAdapterConfig",
             provider_data_validator="llama_stack.providers.remote.inference.vllm.VLLMProviderDataValidator",
             description="Remote vLLM inference provider for connecting to vLLM servers.",
         ),
-        ),
-        remote_provider_spec(
+        RemoteProviderSpec(
             api=Api.inference,
-            adapter=AdapterSpec(
             adapter_type="tgi",
+            provider_type="remote::tgi",
             pip_packages=["huggingface_hub", "aiohttp"],
             module="llama_stack.providers.remote.inference.tgi",
             config_class="llama_stack.providers.remote.inference.tgi.TGIImplConfig",
             description="Text Generation Inference (TGI) provider for HuggingFace model serving.",
         ),
-        ),
-        remote_provider_spec(
+        RemoteProviderSpec(
             api=Api.inference,
-            adapter=AdapterSpec(
             adapter_type="hf::serverless",
+            provider_type="remote::hf::serverless",
             pip_packages=["huggingface_hub", "aiohttp"],
             module="llama_stack.providers.remote.inference.tgi",
             config_class="llama_stack.providers.remote.inference.tgi.InferenceAPIImplConfig",
             description="HuggingFace Inference API serverless provider for on-demand model inference.",
         ),
-        ),
-        remote_provider_spec(
+        RemoteProviderSpec(
             api=Api.inference,
-            adapter=AdapterSpec(
+            provider_type="remote::hf::endpoint",
             adapter_type="hf::endpoint",
             pip_packages=["huggingface_hub", "aiohttp"],
             module="llama_stack.providers.remote.inference.tgi",
             config_class="llama_stack.providers.remote.inference.tgi.InferenceEndpointImplConfig",
             description="HuggingFace Inference Endpoints provider for dedicated model serving.",
         ),
-        ),
-        remote_provider_spec(
+        RemoteProviderSpec(
             api=Api.inference,
-            adapter=AdapterSpec(
             adapter_type="fireworks",
+            provider_type="remote::fireworks",
             pip_packages=[
                 "fireworks-ai<=0.17.16",
             ],
@@ -124,11 +117,10 @@ def available_providers() -> list[ProviderSpec]:
             provider_data_validator="llama_stack.providers.remote.inference.fireworks.FireworksProviderDataValidator",
             description="Fireworks AI inference provider for Llama models and other AI models on the Fireworks platform.",
         ),
-        ),
-        remote_provider_spec(
+        RemoteProviderSpec(
             api=Api.inference,
-            adapter=AdapterSpec(
             adapter_type="together",
+            provider_type="remote::together",
             pip_packages=[
                 "together",
             ],
@@ -137,85 +129,82 @@ def available_providers() -> list[ProviderSpec]:
             provider_data_validator="llama_stack.providers.remote.inference.together.TogetherProviderDataValidator",
             description="Together AI inference provider for open-source models and collaborative AI development.",
         ),
-        ),
-        remote_provider_spec(
+        RemoteProviderSpec(
             api=Api.inference,
-            adapter=AdapterSpec(
             adapter_type="bedrock",
+            provider_type="remote::bedrock",
             pip_packages=["boto3"],
             module="llama_stack.providers.remote.inference.bedrock",
             config_class="llama_stack.providers.remote.inference.bedrock.BedrockConfig",
             description="AWS Bedrock inference provider for accessing various AI models through AWS's managed service.",
         ),
-        ),
-        remote_provider_spec(
+        RemoteProviderSpec(
             api=Api.inference,
-            adapter=AdapterSpec(
             adapter_type="databricks",
+            provider_type="remote::databricks",
             pip_packages=["databricks-sdk"],
             module="llama_stack.providers.remote.inference.databricks",
             config_class="llama_stack.providers.remote.inference.databricks.DatabricksImplConfig",
             description="Databricks inference provider for running models on Databricks' unified analytics platform.",
         ),
-        ),
-        remote_provider_spec(
+        RemoteProviderSpec(
             api=Api.inference,
-            adapter=AdapterSpec(
             adapter_type="nvidia",
+            provider_type="remote::nvidia",
             pip_packages=[],
             module="llama_stack.providers.remote.inference.nvidia",
             config_class="llama_stack.providers.remote.inference.nvidia.NVIDIAConfig",
             description="NVIDIA inference provider for accessing NVIDIA NIM models and AI services.",
         ),
-        ),
-        remote_provider_spec(
+        RemoteProviderSpec(
             api=Api.inference,
-            adapter=AdapterSpec(
             adapter_type="runpod",
+            provider_type="remote::runpod",
             pip_packages=[],
             module="llama_stack.providers.remote.inference.runpod",
             config_class="llama_stack.providers.remote.inference.runpod.RunpodImplConfig",
             description="RunPod inference provider for running models on RunPod's cloud GPU platform.",
         ),
-        ),
-        remote_provider_spec(
+        RemoteProviderSpec(
             api=Api.inference,
-            adapter=AdapterSpec(
             adapter_type="openai",
+            provider_type="remote::openai",
             pip_packages=["litellm"],
             module="llama_stack.providers.remote.inference.openai",
             config_class="llama_stack.providers.remote.inference.openai.OpenAIConfig",
             provider_data_validator="llama_stack.providers.remote.inference.openai.config.OpenAIProviderDataValidator",
             description="OpenAI inference provider for accessing GPT models and other OpenAI services.",
         ),
-        ),
-        remote_provider_spec(
+        RemoteProviderSpec(
             api=Api.inference,
-            adapter=AdapterSpec(
             adapter_type="anthropic",
+            provider_type="remote::anthropic",
             pip_packages=["litellm"],
             module="llama_stack.providers.remote.inference.anthropic",
             config_class="llama_stack.providers.remote.inference.anthropic.AnthropicConfig",
             provider_data_validator="llama_stack.providers.remote.inference.anthropic.config.AnthropicProviderDataValidator",
             description="Anthropic inference provider for accessing Claude models and Anthropic's AI services.",
         ),
-        ),
-        remote_provider_spec(
+        RemoteProviderSpec(
             api=Api.inference,
-            adapter=AdapterSpec(
             adapter_type="gemini",
-            pip_packages=["litellm"],
+            provider_type="remote::gemini",
+            pip_packages=[
+                "litellm",
+            ],
             module="llama_stack.providers.remote.inference.gemini",
             config_class="llama_stack.providers.remote.inference.gemini.GeminiConfig",
             provider_data_validator="llama_stack.providers.remote.inference.gemini.config.GeminiProviderDataValidator",
             description="Google Gemini inference provider for accessing Gemini models and Google's AI services.",
         ),
-        ),
-        remote_provider_spec(
+        RemoteProviderSpec(
             api=Api.inference,
-            adapter=AdapterSpec(
             adapter_type="vertexai",
-            pip_packages=["litellm", "google-cloud-aiplatform"],
+            provider_type="remote::vertexai",
+            pip_packages=[
+                "litellm",
+                "google-cloud-aiplatform",
+            ],
             module="llama_stack.providers.remote.inference.vertexai",
             config_class="llama_stack.providers.remote.inference.vertexai.VertexAIConfig",
             provider_data_validator="llama_stack.providers.remote.inference.vertexai.config.VertexAIProviderDataValidator",
@@ -240,65 +229,63 @@ Available Models:
 - vertex_ai/gemini-2.5-flash
 - vertex_ai/gemini-2.5-pro""",
         ),
-        ),
-        remote_provider_spec(
+        RemoteProviderSpec(
             api=Api.inference,
-            adapter=AdapterSpec(
             adapter_type="groq",
-            pip_packages=["litellm"],
+            provider_type="remote::groq",
+            pip_packages=[
+                "litellm",
+            ],
             module="llama_stack.providers.remote.inference.groq",
             config_class="llama_stack.providers.remote.inference.groq.GroqConfig",
             provider_data_validator="llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator",
             description="Groq inference provider for ultra-fast inference using Groq's LPU technology.",
         ),
-        ),
-        remote_provider_spec(
+        RemoteProviderSpec(
             api=Api.inference,
-            adapter=AdapterSpec(
             adapter_type="llama-openai-compat",
+            provider_type="remote::llama-openai-compat",
             pip_packages=["litellm"],
             module="llama_stack.providers.remote.inference.llama_openai_compat",
             config_class="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaCompatConfig",
             provider_data_validator="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaProviderDataValidator",
             description="Llama OpenAI-compatible provider for using Llama models with OpenAI API format.",
         ),
-        ),
-        remote_provider_spec(
+        RemoteProviderSpec(
             api=Api.inference,
-            adapter=AdapterSpec(
             adapter_type="sambanova",
-            pip_packages=["litellm"],
+            provider_type="remote::sambanova",
+            pip_packages=[
+                "litellm",
+            ],
             module="llama_stack.providers.remote.inference.sambanova",
             config_class="llama_stack.providers.remote.inference.sambanova.SambaNovaImplConfig",
             provider_data_validator="llama_stack.providers.remote.inference.sambanova.config.SambaNovaProviderDataValidator",
             description="SambaNova inference provider for running models on SambaNova's dataflow architecture.",
         ),
-        ),
-        remote_provider_spec(
+        RemoteProviderSpec(
             api=Api.inference,
-            adapter=AdapterSpec(
             adapter_type="passthrough",
+            provider_type="remote::passthrough",
             pip_packages=[],
             module="llama_stack.providers.remote.inference.passthrough",
             config_class="llama_stack.providers.remote.inference.passthrough.PassthroughImplConfig"
|
config_class="llama_stack.providers.remote.inference.passthrough.PassthroughImplConfig",
|
||||||
provider_data_validator="llama_stack.providers.remote.inference.passthrough.PassthroughProviderDataValidator",
|
provider_data_validator="llama_stack.providers.remote.inference.passthrough.PassthroughProviderDataValidator",
|
||||||
description="Passthrough inference provider for connecting to any external inference service not directly supported.",
|
description="Passthrough inference provider for connecting to any external inference service not directly supported.",
|
||||||
),
|
),
|
||||||
),
|
RemoteProviderSpec(
|
||||||
remote_provider_spec(
|
|
||||||
api=Api.inference,
|
api=Api.inference,
|
||||||
adapter=AdapterSpec(
|
|
||||||
adapter_type="watsonx",
|
adapter_type="watsonx",
|
||||||
|
provider_type="remote::watsonx",
|
||||||
pip_packages=["ibm_watsonx_ai"],
|
pip_packages=["ibm_watsonx_ai"],
|
||||||
module="llama_stack.providers.remote.inference.watsonx",
|
module="llama_stack.providers.remote.inference.watsonx",
|
||||||
config_class="llama_stack.providers.remote.inference.watsonx.WatsonXConfig",
|
config_class="llama_stack.providers.remote.inference.watsonx.WatsonXConfig",
|
||||||
provider_data_validator="llama_stack.providers.remote.inference.watsonx.WatsonXProviderDataValidator",
|
provider_data_validator="llama_stack.providers.remote.inference.watsonx.WatsonXProviderDataValidator",
|
||||||
description="IBM WatsonX inference provider for accessing AI models on IBM's WatsonX platform.",
|
description="IBM WatsonX inference provider for accessing AI models on IBM's WatsonX platform.",
|
||||||
),
|
),
|
||||||
),
|
RemoteProviderSpec(
|
||||||
remote_provider_spec(
|
|
||||||
api=Api.inference,
|
api=Api.inference,
|
||||||
adapter=AdapterSpec(
|
provider_type="remote::azure",
|
||||||
adapter_type="azure",
|
adapter_type="azure",
|
||||||
pip_packages=["litellm"],
|
pip_packages=["litellm"],
|
||||||
module="llama_stack.providers.remote.inference.azure",
|
module="llama_stack.providers.remote.inference.azure",
|
||||||
|
@ -310,5 +297,4 @@ Provider documentation
|
||||||
https://learn.microsoft.com/en-us/azure/ai-foundry/openai/overview
|
https://learn.microsoft.com/en-us/azure/ai-foundry/openai/overview
|
||||||
""",
|
""",
|
||||||
),
|
),
|
||||||
),
|
|
||||||
]
|
]
|
||||||
|
|
|
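Note on the registry change above: the nested `adapter=AdapterSpec(...)` wrapper is dropped and its fields (`adapter_type`, `pip_packages`, `module`, `config_class`, `provider_data_validator`, `description`) now sit directly on `RemoteProviderSpec`, alongside an explicit `provider_type`. A minimal sketch of one flattened entry, using only field values that appear in this diff (it assumes `RemoteProviderSpec` accepts them as keyword arguments, as the new code implies):

```python
from llama_stack.providers.datatypes import Api, RemoteProviderSpec

# Flattened registration: no AdapterSpec nesting, provider_type spelled out explicitly.
anthropic = RemoteProviderSpec(
    api=Api.inference,
    adapter_type="anthropic",
    provider_type="remote::anthropic",
    pip_packages=["litellm"],
    module="llama_stack.providers.remote.inference.anthropic",
    config_class="llama_stack.providers.remote.inference.anthropic.AnthropicConfig",
    provider_data_validator="llama_stack.providers.remote.inference.anthropic.config.AnthropicProviderDataValidator",
    description="Anthropic inference provider for accessing Claude models and Anthropic's AI services.",
)
```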
@@ -7,7 +7,7 @@

 from typing import cast

-from llama_stack.providers.datatypes import AdapterSpec, Api, InlineProviderSpec, ProviderSpec, remote_provider_spec
+from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec

 # We provide two versions of these providers so that distributions can package the appropriate version of torch.
 # The CPU version is used for distributions that don't have GPU support -- they result in smaller container images.
@@ -57,14 +57,13 @@ def available_providers() -> list[ProviderSpec]:
         ],
         description="HuggingFace-based post-training provider for fine-tuning models using the HuggingFace ecosystem.",
     ),
-    remote_provider_spec(
+    RemoteProviderSpec(
         api=Api.post_training,
-        adapter=AdapterSpec(
         adapter_type="nvidia",
+        provider_type="remote::nvidia",
         pip_packages=["requests", "aiohttp"],
         module="llama_stack.providers.remote.post_training.nvidia",
         config_class="llama_stack.providers.remote.post_training.nvidia.NvidiaPostTrainingConfig",
         description="NVIDIA's post-training provider for fine-tuning models on NVIDIA's platform.",
-        ),
     ),
 ]
@@ -6,11 +6,10 @@

 from llama_stack.providers.datatypes import (
-    AdapterSpec,
     Api,
     InlineProviderSpec,
     ProviderSpec,
-    remote_provider_spec,
+    RemoteProviderSpec,
 )

@@ -48,35 +47,32 @@ def available_providers() -> list[ProviderSpec]:
         config_class="llama_stack.providers.inline.safety.code_scanner.CodeScannerConfig",
         description="Code Scanner safety provider for detecting security vulnerabilities and unsafe code patterns.",
     ),
-    remote_provider_spec(
+    RemoteProviderSpec(
         api=Api.safety,
-        adapter=AdapterSpec(
         adapter_type="bedrock",
+        provider_type="remote::bedrock",
         pip_packages=["boto3"],
         module="llama_stack.providers.remote.safety.bedrock",
         config_class="llama_stack.providers.remote.safety.bedrock.BedrockSafetyConfig",
         description="AWS Bedrock safety provider for content moderation using AWS's safety services.",
     ),
-    ),
-    remote_provider_spec(
+    RemoteProviderSpec(
         api=Api.safety,
-        adapter=AdapterSpec(
         adapter_type="nvidia",
+        provider_type="remote::nvidia",
         pip_packages=["requests"],
         module="llama_stack.providers.remote.safety.nvidia",
         config_class="llama_stack.providers.remote.safety.nvidia.NVIDIASafetyConfig",
         description="NVIDIA's safety provider for content moderation and safety filtering.",
     ),
-    ),
-    remote_provider_spec(
+    RemoteProviderSpec(
         api=Api.safety,
-        adapter=AdapterSpec(
         adapter_type="sambanova",
+        provider_type="remote::sambanova",
         pip_packages=["litellm", "requests"],
         module="llama_stack.providers.remote.safety.sambanova",
         config_class="llama_stack.providers.remote.safety.sambanova.SambaNovaSafetyConfig",
         provider_data_validator="llama_stack.providers.remote.safety.sambanova.config.SambaNovaProviderDataValidator",
         description="SambaNova's safety provider for content moderation and safety filtering.",
     ),
-    ),
 ]
@@ -6,11 +6,10 @@

 from llama_stack.providers.datatypes import (
-    AdapterSpec,
     Api,
     InlineProviderSpec,
     ProviderSpec,
-    remote_provider_spec,
+    RemoteProviderSpec,
 )

@@ -35,59 +34,54 @@ def available_providers() -> list[ProviderSpec]:
         api_dependencies=[Api.vector_io, Api.inference, Api.files],
         description="RAG (Retrieval-Augmented Generation) tool runtime for document ingestion, chunking, and semantic search.",
     ),
-    remote_provider_spec(
+    RemoteProviderSpec(
         api=Api.tool_runtime,
-        adapter=AdapterSpec(
         adapter_type="brave-search",
+        provider_type="remote::brave-search",
         module="llama_stack.providers.remote.tool_runtime.brave_search",
         config_class="llama_stack.providers.remote.tool_runtime.brave_search.config.BraveSearchToolConfig",
         pip_packages=["requests"],
         provider_data_validator="llama_stack.providers.remote.tool_runtime.brave_search.BraveSearchToolProviderDataValidator",
         description="Brave Search tool for web search capabilities with privacy-focused results.",
     ),
-    ),
-    remote_provider_spec(
+    RemoteProviderSpec(
         api=Api.tool_runtime,
-        adapter=AdapterSpec(
         adapter_type="bing-search",
+        provider_type="remote::bing-search",
         module="llama_stack.providers.remote.tool_runtime.bing_search",
         config_class="llama_stack.providers.remote.tool_runtime.bing_search.config.BingSearchToolConfig",
         pip_packages=["requests"],
         provider_data_validator="llama_stack.providers.remote.tool_runtime.bing_search.BingSearchToolProviderDataValidator",
         description="Bing Search tool for web search capabilities using Microsoft's search engine.",
     ),
-    ),
-    remote_provider_spec(
+    RemoteProviderSpec(
         api=Api.tool_runtime,
-        adapter=AdapterSpec(
         adapter_type="tavily-search",
+        provider_type="remote::tavily-search",
         module="llama_stack.providers.remote.tool_runtime.tavily_search",
         config_class="llama_stack.providers.remote.tool_runtime.tavily_search.config.TavilySearchToolConfig",
         pip_packages=["requests"],
         provider_data_validator="llama_stack.providers.remote.tool_runtime.tavily_search.TavilySearchToolProviderDataValidator",
         description="Tavily Search tool for AI-optimized web search with structured results.",
     ),
-    ),
-    remote_provider_spec(
+    RemoteProviderSpec(
         api=Api.tool_runtime,
-        adapter=AdapterSpec(
         adapter_type="wolfram-alpha",
+        provider_type="remote::wolfram-alpha",
         module="llama_stack.providers.remote.tool_runtime.wolfram_alpha",
         config_class="llama_stack.providers.remote.tool_runtime.wolfram_alpha.config.WolframAlphaToolConfig",
         pip_packages=["requests"],
         provider_data_validator="llama_stack.providers.remote.tool_runtime.wolfram_alpha.WolframAlphaToolProviderDataValidator",
         description="Wolfram Alpha tool for computational knowledge and mathematical calculations.",
     ),
-    ),
-    remote_provider_spec(
+    RemoteProviderSpec(
         api=Api.tool_runtime,
-        adapter=AdapterSpec(
         adapter_type="model-context-protocol",
+        provider_type="remote::model-context-protocol",
         module="llama_stack.providers.remote.tool_runtime.model_context_protocol",
         config_class="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.MCPProviderConfig",
         pip_packages=["mcp>=1.8.1"],
         provider_data_validator="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.MCPProviderDataValidator",
         description="Model Context Protocol (MCP) tool for standardized tool calling and context management.",
     ),
-    ),
 ]
@@ -6,11 +6,10 @@

 from llama_stack.providers.datatypes import (
-    AdapterSpec,
     Api,
     InlineProviderSpec,
     ProviderSpec,
-    remote_provider_spec,
+    RemoteProviderSpec,
 )

@@ -300,13 +299,15 @@ See [sqlite-vec's GitHub repo](https://github.com/asg017/sqlite-vec/tree/main) f
 Please refer to the sqlite-vec provider documentation.
 """,
     ),
-    remote_provider_spec(
-        Api.vector_io,
-        AdapterSpec(
+    RemoteProviderSpec(
+        api=Api.vector_io,
         adapter_type="chromadb",
+        provider_type="remote::chromadb",
         pip_packages=["chromadb-client"],
         module="llama_stack.providers.remote.vector_io.chroma",
         config_class="llama_stack.providers.remote.vector_io.chroma.ChromaVectorIOConfig",
+        api_dependencies=[Api.inference],
+        optional_api_dependencies=[Api.files],
         description="""
 [Chroma](https://www.trychroma.com/) is an inline and remote vector
 database provider for Llama Stack. It allows you to store and query vectors directly within a Chroma database.
@@ -341,9 +342,6 @@ pip install chromadb
 See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introduction) for more details about Chroma in general.
 """,
     ),
-        ),
-        api_dependencies=[Api.inference],
-        optional_api_dependencies=[Api.files],
-    ),
     InlineProviderSpec(
         api=Api.vector_io,
         provider_type="inline::chromadb",
@@ -387,13 +385,15 @@ See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introducti

 """,
     ),
-    remote_provider_spec(
-        Api.vector_io,
-        AdapterSpec(
+    RemoteProviderSpec(
+        api=Api.vector_io,
         adapter_type="pgvector",
+        provider_type="remote::pgvector",
         pip_packages=["psycopg2-binary"],
         module="llama_stack.providers.remote.vector_io.pgvector",
         config_class="llama_stack.providers.remote.vector_io.pgvector.PGVectorVectorIOConfig",
+        api_dependencies=[Api.inference],
+        optional_api_dependencies=[Api.files],
         description="""
 [PGVector](https://github.com/pgvector/pgvector) is a remote vector database provider for Llama Stack. It
 allows you to store and query vectors directly in memory.
@@ -496,17 +496,16 @@ docker pull pgvector/pgvector:pg17
 See [PGVector's documentation](https://github.com/pgvector/pgvector) for more details about PGVector in general.
 """,
     ),
-        ),
-        api_dependencies=[Api.inference],
-        optional_api_dependencies=[Api.files],
-    ),
-    remote_provider_spec(
-        Api.vector_io,
-        AdapterSpec(
+    RemoteProviderSpec(
+        api=Api.vector_io,
         adapter_type="weaviate",
+        provider_type="remote::weaviate",
         pip_packages=["weaviate-client"],
         module="llama_stack.providers.remote.vector_io.weaviate",
         config_class="llama_stack.providers.remote.vector_io.weaviate.WeaviateVectorIOConfig",
         provider_data_validator="llama_stack.providers.remote.vector_io.weaviate.WeaviateRequestProviderData",
+        api_dependencies=[Api.inference],
+        optional_api_dependencies=[Api.files],
         description="""
 [Weaviate](https://weaviate.io/) is a vector database provider for Llama Stack.
 It allows you to store and query vectors directly within a Weaviate database.
@@ -539,9 +538,6 @@ To install Weaviate see the [Weaviate quickstart documentation](https://weaviate
 See [Weaviate's documentation](https://weaviate.io/developers/weaviate) for more details about Weaviate in general.
 """,
     ),
-        ),
-        api_dependencies=[Api.inference],
-        optional_api_dependencies=[Api.files],
-    ),
     InlineProviderSpec(
         api=Api.vector_io,
         provider_type="inline::qdrant",
@@ -594,27 +590,28 @@ docker pull qdrant/qdrant
 See the [Qdrant documentation](https://qdrant.tech/documentation/) for more details about Qdrant in general.
 """,
     ),
-    remote_provider_spec(
-        Api.vector_io,
-        AdapterSpec(
+    RemoteProviderSpec(
+        api=Api.vector_io,
         adapter_type="qdrant",
+        provider_type="remote::qdrant",
         pip_packages=["qdrant-client"],
         module="llama_stack.providers.remote.vector_io.qdrant",
         config_class="llama_stack.providers.remote.vector_io.qdrant.QdrantVectorIOConfig",
+        api_dependencies=[Api.inference],
+        optional_api_dependencies=[Api.files],
         description="""
 Please refer to the inline provider documentation.
 """,
     ),
-        ),
-        api_dependencies=[Api.inference],
-        optional_api_dependencies=[Api.files],
-    ),
-    remote_provider_spec(
-        Api.vector_io,
-        AdapterSpec(
+    RemoteProviderSpec(
+        api=Api.vector_io,
         adapter_type="milvus",
+        provider_type="remote::milvus",
         pip_packages=["pymilvus>=2.4.10"],
         module="llama_stack.providers.remote.vector_io.milvus",
         config_class="llama_stack.providers.remote.vector_io.milvus.MilvusVectorIOConfig",
+        api_dependencies=[Api.inference],
+        optional_api_dependencies=[Api.files],
         description="""
 [Milvus](https://milvus.io/) is an inline and remote vector database provider for Llama Stack. It
 allows you to store and query vectors directly within a Milvus database.
@@ -636,7 +633,13 @@ To use Milvus in your Llama Stack project, follow these steps:

 ## Installation

-You can install Milvus using pymilvus:
+If you want to use inline Milvus, you can install:
+
+```bash
+pip install pymilvus[milvus-lite]
+```
+
+If you want to use remote Milvus, you can install:

 ```bash
 pip install pymilvus
@@ -807,13 +810,10 @@ See the [Milvus documentation](https://milvus.io/docs/install-overview.md) for m
 For more details on TLS configuration, refer to the [TLS setup guide](https://milvus.io/docs/tls.md).
 """,
     ),
-        ),
-        api_dependencies=[Api.inference],
-        optional_api_dependencies=[Api.files],
-    ),
     InlineProviderSpec(
         api=Api.vector_io,
         provider_type="inline::milvus",
-        pip_packages=["pymilvus>=2.4.10"],
+        pip_packages=["pymilvus[milvus-lite]>=2.4.10"],
         module="llama_stack.providers.inline.vector_io.milvus",
         config_class="llama_stack.providers.inline.vector_io.milvus.MilvusVectorIOConfig",
         api_dependencies=[Api.inference],
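For the vector_io providers, the same flattening also moves `api_dependencies` and `optional_api_dependencies` onto the spec itself instead of trailing positional arguments. A sketch of the chromadb entry in its new shape, using only values visible in this diff (the description string is abbreviated here):

```python
from llama_stack.providers.datatypes import Api, RemoteProviderSpec

chromadb = RemoteProviderSpec(
    api=Api.vector_io,
    adapter_type="chromadb",
    provider_type="remote::chromadb",
    pip_packages=["chromadb-client"],
    module="llama_stack.providers.remote.vector_io.chroma",
    config_class="llama_stack.providers.remote.vector_io.chroma.ChromaVectorIOConfig",
    api_dependencies=[Api.inference],          # now a field of the spec itself
    optional_api_dependencies=[Api.files],     # rather than a trailing argument
    description="Chroma remote vector database provider (abbreviated).",
)
```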
@@ -137,7 +137,7 @@ class S3FilesImpl(Files):
         where: dict[str, str | dict] = {"id": file_id}
         if not return_expired:
             where["expires_at"] = {">": self._now()}
-        if not (row := await self.sql_store.fetch_one("openai_files", policy=self.policy, where=where)):
+        if not (row := await self.sql_store.fetch_one("openai_files", where=where)):
             raise ResourceNotFoundError(file_id, "File", "files.list()")
         return row

@@ -164,7 +164,7 @@ class S3FilesImpl(Files):
         self._client = _create_s3_client(self._config)
         await _create_bucket_if_not_exists(self._client, self._config)

-        self._sql_store = AuthorizedSqlStore(sqlstore_impl(self._config.metadata_store))
+        self._sql_store = AuthorizedSqlStore(sqlstore_impl(self._config.metadata_store), self.policy)
         await self._sql_store.create_table(
             "openai_files",
             {
@@ -268,7 +268,6 @@ class S3FilesImpl(Files):

         paginated_result = await self.sql_store.fetch_all(
             table="openai_files",
-            policy=self.policy,
             where=where_conditions,
             order_by=[("created_at", order.value)],
             cursor=("id", after) if after else None,
@@ -54,7 +54,7 @@ class InferenceStore:

     async def initialize(self):
         """Create the necessary tables if they don't exist."""
-        self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config))
+        self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config), self.policy)
         await self.sql_store.create_table(
             "chat_completions",
             {
@@ -202,7 +202,6 @@ class InferenceStore:
             order_by=[("created", order.value)],
             cursor=("id", after) if after else None,
             limit=limit,
-            policy=self.policy,
         )

         data = [
@@ -229,7 +228,6 @@ class InferenceStore:
         row = await self.sql_store.fetch_one(
             table="chat_completions",
             where={"id": completion_id},
-            policy=self.policy,
         )

         if not row:
@@ -28,8 +28,7 @@ class ResponsesStore:
             sql_store_config = SqliteSqlStoreConfig(
                 db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(),
             )
-        self.sql_store = AuthorizedSqlStore(sqlstore_impl(sql_store_config))
-        self.policy = policy
+        self.sql_store = AuthorizedSqlStore(sqlstore_impl(sql_store_config), policy)

     async def initialize(self):
         """Create the necessary tables if they don't exist."""
@@ -87,7 +86,6 @@ class ResponsesStore:
             order_by=[("created_at", order.value)],
             cursor=("id", after) if after else None,
             limit=limit,
-            policy=self.policy,
         )

         data = [OpenAIResponseObjectWithInput(**row["response_object"]) for row in paginated_result.data]
@@ -105,7 +103,6 @@ class ResponsesStore:
         row = await self.sql_store.fetch_one(
             "openai_responses",
             where={"id": response_id},
-            policy=self.policy,
         )

         if not row:
@@ -116,7 +113,7 @@ class ResponsesStore:
         return OpenAIResponseObjectWithInput(**row["response_object"])

     async def delete_response_object(self, response_id: str) -> OpenAIDeleteResponseObject:
-        row = await self.sql_store.fetch_one("openai_responses", where={"id": response_id}, policy=self.policy)
+        row = await self.sql_store.fetch_one("openai_responses", where={"id": response_id})
         if not row:
             raise ValueError(f"Response with id {response_id} not found")
         await self.sql_store.delete("openai_responses", where={"id": response_id})
@@ -53,13 +53,15 @@ class AuthorizedSqlStore:
     access control policies, user attribute capture, and SQL filtering optimization.
     """

-    def __init__(self, sql_store: SqlStore):
+    def __init__(self, sql_store: SqlStore, policy: list[AccessRule]):
         """
         Initialize the authorization layer.

         :param sql_store: Base SqlStore implementation to wrap
+        :param policy: Access control policy to use for authorization
         """
         self.sql_store = sql_store
+        self.policy = policy
         self._detect_database_type()
         self._validate_sql_optimized_policy()

@@ -117,14 +119,13 @@ class AuthorizedSqlStore:
     async def fetch_all(
         self,
         table: str,
-        policy: list[AccessRule],
         where: Mapping[str, Any] | None = None,
         limit: int | None = None,
         order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
         cursor: tuple[str, str] | None = None,
     ) -> PaginatedResponse:
         """Fetch all rows with automatic access control filtering."""
-        access_where = self._build_access_control_where_clause(policy)
+        access_where = self._build_access_control_where_clause(self.policy)
         rows = await self.sql_store.fetch_all(
             table=table,
             where=where,
@@ -146,7 +147,7 @@ class AuthorizedSqlStore:
                 str(record_id), table, User(principal=stored_owner_principal, attributes=stored_access_attrs)
             )

-            if is_action_allowed(policy, Action.READ, sql_record, current_user):
+            if is_action_allowed(self.policy, Action.READ, sql_record, current_user):
                 filtered_rows.append(row)

         return PaginatedResponse(
@@ -157,14 +158,12 @@ class AuthorizedSqlStore:
     async def fetch_one(
         self,
         table: str,
-        policy: list[AccessRule],
         where: Mapping[str, Any] | None = None,
         order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
     ) -> dict[str, Any] | None:
         """Fetch one row with automatic access control checking."""
         results = await self.fetch_all(
             table=table,
-            policy=policy,
             where=where,
             limit=1,
             order_by=order_by,
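The `AuthorizedSqlStore` change above binds the access policy once at construction time instead of threading it through every `fetch_all`/`fetch_one` call. A minimal sketch of the new calling convention; the import paths are assumptions inferred from the module names in this diff, and only the constructor and read-call shapes are taken from it:

```python
# Sketch only: import paths below are assumed, not confirmed by this diff.
from llama_stack.core.access_control.access_control import default_policy
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig, sqlstore_impl


async def read_documents() -> None:
    config = SqliteSqlStoreConfig(db_path="/tmp/demo_sqlstore.db")  # hypothetical path
    # The access policy is bound once, at construction time.
    store = AuthorizedSqlStore(sqlstore_impl(config), default_policy())

    # Reads no longer take a policy= argument; the stored policy is applied automatically.
    page = await store.fetch_all("documents", where={"id": "1"})
    row = await store.fetch_one("documents", where={"id": "1"})
    print(len(page.data), row)
```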
tests/external/kaze.yaml (vendored)
@@ -1,4 +1,3 @@
-adapter:
 adapter_type: kaze
 pip_packages: ["tests/external/llama-stack-provider-kaze"]
 config_class: llama_stack_provider_kaze.config.KazeProviderConfig

@@ -6,7 +6,7 @@

 from typing import Protocol

-from llama_stack.providers.datatypes import AdapterSpec, Api, ProviderSpec, RemoteProviderSpec
+from llama_stack.providers.datatypes import Api, ProviderSpec, RemoteProviderSpec
 from llama_stack.schema_utils import webmethod

@@ -16,12 +16,9 @@ def available_providers() -> list[ProviderSpec]:
         api=Api.weather,
         provider_type="remote::kaze",
         config_class="llama_stack_provider_kaze.KazeProviderConfig",
-        adapter=AdapterSpec(
         adapter_type="kaze",
         module="llama_stack_provider_kaze",
         pip_packages=["llama_stack_provider_kaze"],
-            config_class="llama_stack_provider_kaze.KazeProviderConfig",
-        ),
     ),
 ]
@@ -57,7 +57,7 @@ def authorized_store(backend_config):
     config = config_func()

     base_sqlstore = sqlstore_impl(config)
-    authorized_store = AuthorizedSqlStore(base_sqlstore)
+    authorized_store = AuthorizedSqlStore(base_sqlstore, default_policy())

     yield authorized_store

@@ -106,7 +106,7 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz
         await authorized_store.insert(table_name, {"id": "1", "data": "public_data"})

         # Test fetching with no user - should not error on JSON comparison
-        result = await authorized_store.fetch_all(table_name, policy=default_policy())
+        result = await authorized_store.fetch_all(table_name)
         assert len(result.data) == 1
         assert result.data[0]["id"] == "1"
         assert result.data[0]["access_attributes"] is None
@@ -119,7 +119,7 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz
         await authorized_store.insert(table_name, {"id": "2", "data": "admin_data"})

         # Fetch all - admin should see both
-        result = await authorized_store.fetch_all(table_name, policy=default_policy())
+        result = await authorized_store.fetch_all(table_name)
         assert len(result.data) == 2

         # Test with non-admin user
@@ -127,7 +127,7 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz
         mock_get_authenticated_user.return_value = regular_user

         # Should only see public record
-        result = await authorized_store.fetch_all(table_name, policy=default_policy())
+        result = await authorized_store.fetch_all(table_name)
         assert len(result.data) == 1
         assert result.data[0]["id"] == "1"

@@ -156,7 +156,7 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz

         # Now test with the multi-user who has both roles=admin and teams=dev
         mock_get_authenticated_user.return_value = multi_user
-        result = await authorized_store.fetch_all(table_name, policy=default_policy())
+        result = await authorized_store.fetch_all(table_name)

         # Should see:
         # - public record (1) - no access_attributes
@@ -217,21 +217,24 @@ async def test_user_ownership_policy(mock_get_authenticated_user, authorized_sto
             ),
         ]

+        # Create a new authorized store with the owner-only policy
+        owner_only_store = AuthorizedSqlStore(authorized_store.sql_store, owner_only_policy)
+
         # Test user1 access - should only see their own record
         mock_get_authenticated_user.return_value = user1
-        result = await authorized_store.fetch_all(table_name, policy=owner_only_policy)
+        result = await owner_only_store.fetch_all(table_name)
         assert len(result.data) == 1, f"Expected user1 to see 1 record, got {len(result.data)}"
        assert result.data[0]["id"] == "1", f"Expected user1's record, got {result.data[0]['id']}"

         # Test user2 access - should only see their own record
         mock_get_authenticated_user.return_value = user2
-        result = await authorized_store.fetch_all(table_name, policy=owner_only_policy)
+        result = await owner_only_store.fetch_all(table_name)
         assert len(result.data) == 1, f"Expected user2 to see 1 record, got {len(result.data)}"
         assert result.data[0]["id"] == "2", f"Expected user2's record, got {result.data[0]['id']}"

         # Test with anonymous user - should see no records
         mock_get_authenticated_user.return_value = None
-        result = await authorized_store.fetch_all(table_name, policy=owner_only_policy)
+        result = await owner_only_store.fetch_all(table_name)
         assert len(result.data) == 0, f"Expected anonymous user to see 0 records, got {len(result.data)}"

     finally:
@@ -66,7 +66,6 @@ def base_config(tmp_path):
 def provider_spec_yaml():
     """Common provider spec YAML for testing."""
     return """
-adapter:
 adapter_type: test_provider
 config_class: test_provider.config.TestProviderConfig
 module: test_provider
@@ -182,9 +181,9 @@ class TestProviderRegistry:
         assert Api.inference in registry
         assert "remote::test_provider" in registry[Api.inference]
         provider = registry[Api.inference]["remote::test_provider"]
-        assert provider.adapter.adapter_type == "test_provider"
-        assert provider.adapter.module == "test_provider"
-        assert provider.adapter.config_class == "test_provider.config.TestProviderConfig"
+        assert provider.adapter_type == "test_provider"
+        assert provider.module == "test_provider"
+        assert provider.config_class == "test_provider.config.TestProviderConfig"
         assert Api.safety in provider.api_dependencies

     def test_external_inline_providers(self, api_directories, mock_providers, base_config, inline_provider_spec_yaml):
@@ -246,7 +245,6 @@ class TestProviderRegistry:
         """Test handling of malformed remote provider spec (missing required fields)."""
         remote_dir, _ = api_directories
         malformed_spec = """
-adapter:
 adapter_type: test_provider
 # Missing required fields
 api_dependencies:
@@ -270,7 +268,7 @@ pip_packages:
         with open(inline_dir / "malformed.yaml", "w") as f:
             f.write(malformed_spec)

-        with pytest.raises(KeyError) as exc_info:
+        with pytest.raises(ValidationError) as exc_info:
             get_provider_registry(base_config)
         assert "config_class" in str(exc_info.value)
@@ -27,13 +27,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization:
         mock_impls = {}
         mock_route_impls = RouteImpls({})

-        async def mock_construct_stack(config, custom_provider_registry):
-            return mock_impls
+        class MockStack:
+            def __init__(self, config, custom_provider_registry=None):
+                self.impls = mock_impls
+
+            async def initialize(self):
+                pass

         def mock_initialize_route_impls(impls):
             return mock_route_impls

-        monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
+        monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack)
         monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)

         client = LlamaStackAsLibraryClient("ci-tests")
@@ -46,13 +50,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization:
         mock_impls = {}
         mock_route_impls = RouteImpls({})

-        async def mock_construct_stack(config, custom_provider_registry):
-            return mock_impls
+        class MockStack:
+            def __init__(self, config, custom_provider_registry=None):
+                self.impls = mock_impls
+
+            async def initialize(self):
+                pass

         def mock_initialize_route_impls(impls):
             return mock_route_impls

-        monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
+        monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack)
         monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)

         client = AsyncLlamaStackAsLibraryClient("ci-tests")
@@ -68,13 +76,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization:
         mock_impls = {}
         mock_route_impls = RouteImpls({})

-        async def mock_construct_stack(config, custom_provider_registry):
-            return mock_impls
+        class MockStack:
+            def __init__(self, config, custom_provider_registry=None):
+                self.impls = mock_impls
+
+            async def initialize(self):
+                pass

         def mock_initialize_route_impls(impls):
             return mock_route_impls

-        monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
+        monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack)
         monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)

         client = LlamaStackAsLibraryClient("ci-tests")
@@ -90,13 +102,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization:
         mock_impls = {}
         mock_route_impls = RouteImpls({})

-        async def mock_construct_stack(config, custom_provider_registry):
-            return mock_impls
+        class MockStack:
+            def __init__(self, config, custom_provider_registry=None):
+                self.impls = mock_impls
+
+            async def initialize(self):
+                pass

         def mock_initialize_route_impls(impls):
             return mock_route_impls

-        monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
+        monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack)
         monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)

         client = AsyncLlamaStackAsLibraryClient("ci-tests")
@@ -112,13 +128,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization:
         mock_impls = {}
         mock_route_impls = RouteImpls({})

-        async def mock_construct_stack(config, custom_provider_registry):
-            return mock_impls
+        class MockStack:
+            def __init__(self, config, custom_provider_registry=None):
+                self.impls = mock_impls
+
+            async def initialize(self):
+                pass

         def mock_initialize_route_impls(impls):
             return mock_route_impls

-        monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
+        monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack)
         monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)

         sync_client = LlamaStackAsLibraryClient("ci-tests")
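The library-client tests above all make the same substitution: instead of patching a `construct_stack` coroutine, they patch the `Stack` class with a stand-in whose instances expose `impls` after `initialize()`. A condensed sketch of that pattern (it assumes pytest's `monkeypatch` fixture; the patch targets are the ones used in the diff, the client import path is assumed from them, and the final assertion is illustrative):

```python
# Condensed sketch of the monkeypatch pattern used in these tests.
class MockStack:
    def __init__(self, config, custom_provider_registry=None):
        self.impls = {}  # what the library client reads after initialize()

    async def initialize(self):
        pass


def test_auto_initialize(monkeypatch):
    monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack)
    monkeypatch.setattr(
        "llama_stack.core.library_client.initialize_route_impls",
        lambda impls: {},  # stand-in for RouteImpls({})
    )

    # Import path assumed from the patch targets above.
    from llama_stack.core.library_client import LlamaStackAsLibraryClient

    client = LlamaStackAsLibraryClient("ci-tests")
    assert client is not None
```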
@@ -26,7 +26,7 @@ async def test_authorized_fetch_with_where_sql_access_control(mock_get_authentic
             db_path=tmp_dir + "/" + db_name,
         )
     )
-    sqlstore = AuthorizedSqlStore(base_sqlstore)
+    sqlstore = AuthorizedSqlStore(base_sqlstore, default_policy())

     # Create table with access control
     await sqlstore.create_table(
@@ -56,24 +56,24 @@ async def test_authorized_fetch_with_where_sql_access_control(mock_get_authentic
     mock_get_authenticated_user.return_value = admin_user

     # Admin should see both documents
-    result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 1})
+    result = await sqlstore.fetch_all("documents", where={"id": 1})
     assert len(result.data) == 1
     assert result.data[0]["title"] == "Admin Document"

     # User should only see their document
     mock_get_authenticated_user.return_value = regular_user

-    result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 1})
+    result = await sqlstore.fetch_all("documents", where={"id": 1})
     assert len(result.data) == 0

-    result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 2})
+    result = await sqlstore.fetch_all("documents", where={"id": 2})
     assert len(result.data) == 1
     assert result.data[0]["title"] == "User Document"

-    row = await sqlstore.fetch_one("documents", policy=default_policy(), where={"id": 1})
+    row = await sqlstore.fetch_one("documents", where={"id": 1})
     assert row is None

-    row = await sqlstore.fetch_one("documents", policy=default_policy(), where={"id": 2})
+    row = await sqlstore.fetch_one("documents", where={"id": 2})
     assert row is not None
     assert row["title"] == "User Document"

@@ -88,7 +88,7 @@ async def test_sql_policy_consistency(mock_get_authenticated_user):
             db_path=tmp_dir + "/" + db_name,
         )
     )
-    sqlstore = AuthorizedSqlStore(base_sqlstore)
+    sqlstore = AuthorizedSqlStore(base_sqlstore, default_policy())

     await sqlstore.create_table(
         table="resources",
@@ -144,7 +144,7 @@ async def test_sql_policy_consistency(mock_get_authenticated_user):
         user = User(principal=user_data["principal"], attributes=user_data["attributes"])
         mock_get_authenticated_user.return_value = user

-        sql_results = await sqlstore.fetch_all("resources", policy=policy)
+        sql_results = await sqlstore.fetch_all("resources")
         sql_ids = {row["id"] for row in sql_results.data}
         policy_ids = set()
         for scenario in test_scenarios:
@@ -174,7 +174,7 @@ async def test_authorized_store_user_attribute_capture(mock_get_authenticated_us
             db_path=tmp_dir + "/" + db_name,
         )
     )
-    authorized_store = AuthorizedSqlStore(base_sqlstore)
+    authorized_store = AuthorizedSqlStore(base_sqlstore, default_policy())

     await authorized_store.create_table(
         table="user_data",