Merge branch 'main' into sambanova-inferene

This commit is contained in:
snova-edwardm 2025-01-14 10:04:52 -08:00 committed by GitHub
commit 89ab2be302
GPG key ID: B5690EEEBB952194
385 changed files with 39001 additions and 9280 deletions

View file

@ -4,22 +4,24 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import importlib.resources
import logging
from enum import Enum
from typing import List
import pkg_resources
from pydantic import BaseModel
from llama_stack.distribution.utils.exec import run_with_pty
from llama_stack.distribution.datatypes import * # noqa: F403
from pathlib import Path
from typing import Dict, List
from pydantic import BaseModel
from termcolor import cprint
from llama_stack.distribution.datatypes import BuildConfig, Provider
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR
from llama_stack.distribution.utils.exec import run_with_pty
from llama_stack.providers.datatypes import Api
log = logging.getLogger(__name__)
@ -37,6 +39,7 @@ SERVER_DEPENDENCIES = [
class ImageType(Enum):
docker = "docker"
conda = "conda"
venv = "venv"
class ApiInput(BaseModel):
@ -45,7 +48,7 @@ class ApiInput(BaseModel):
def get_provider_dependencies(
config_providers: Dict[str, List[Provider]]
config_providers: Dict[str, List[Provider]],
) -> tuple[list[str], list[str]]:
"""Get normal and special dependencies from provider configuration."""
all_providers = get_provider_registry()
@ -90,11 +93,12 @@ def get_provider_dependencies(
def print_pip_install_help(providers: Dict[str, List[Provider]]):
normal_deps, special_deps = get_provider_dependencies(providers)
print(
f"Please install needed dependencies using the following commands:\n\n\tpip install {' '.join(normal_deps)}"
cprint(
f"Please install needed dependencies using the following commands:\n\npip install {' '.join(normal_deps)}",
"yellow",
)
for special_dep in special_deps:
log.info(f"\tpip install {special_dep}")
cprint(f"pip install {special_dep}", "yellow")
print()
@ -107,8 +111,8 @@ def build_image(build_config: BuildConfig, build_file_path: Path):
normal_deps += SERVER_DEPENDENCIES
if build_config.image_type == ImageType.docker.value:
script = pkg_resources.resource_filename(
"llama_stack", "distribution/build_container.sh"
script = (
importlib.resources.files("llama_stack") / "distribution/build_container.sh"
)
args = [
script,
@ -118,9 +122,9 @@ def build_image(build_config: BuildConfig, build_file_path: Path):
str(BUILDS_BASE_DIR / ImageType.docker.value),
" ".join(normal_deps),
]
else:
script = pkg_resources.resource_filename(
"llama_stack", "distribution/build_conda_env.sh"
elif build_config.image_type == ImageType.conda.value:
script = (
importlib.resources.files("llama_stack") / "distribution/build_conda_env.sh"
)
args = [
script,
@ -128,6 +132,14 @@ def build_image(build_config: BuildConfig, build_file_path: Path):
str(build_file_path),
" ".join(normal_deps),
]
elif build_config.image_type == ImageType.venv.value:
script = importlib.resources.files("llama_stack") / "distribution/build_venv.sh"
args = [
script,
build_config.name,
str(build_file_path),
" ".join(normal_deps),
]
if special_deps:
args.append("#".join(special_deps))

View file

@ -83,7 +83,9 @@ ensure_conda_env_python310() {
# these packages are damaged in test-pypi, so install them first
$CONDA_PREFIX/bin/pip install fastapi libcst
$CONDA_PREFIX/bin/pip install --extra-index-url https://test.pypi.org/simple/ \
llama-models==$TEST_PYPI_VERSION llama-stack==$TEST_PYPI_VERSION \
llama-models==$TEST_PYPI_VERSION \
llama-stack-client==$TEST_PYPI_VERSION \
llama-stack==$TEST_PYPI_VERSION \
$pip_dependencies
if [ -n "$special_pip_deps" ]; then
IFS='#' read -ra parts <<<"$special_pip_deps"

View file

@ -51,7 +51,19 @@ add_to_docker() {
fi
}
add_to_docker <<EOF
# Update and install UBI9 components if UBI9 base image is used
if [[ $docker_base == *"registry.access.redhat.com/ubi9"* ]]; then
add_to_docker <<EOF
FROM $docker_base
WORKDIR /app
RUN microdnf -y update && microdnf install -y iputils net-tools wget \
vim-minimal python3.11 python3.11-pip python3.11-wheel \
python3.11-setuptools && ln -s /bin/pip3.11 /bin/pip && ln -s /bin/python3.11 /bin/python && microdnf clean all
EOF
else
add_to_docker <<EOF
FROM $docker_base
WORKDIR /app
@ -64,6 +76,7 @@ RUN apt-get update && apt-get install -y \
&& rm -rf /var/lib/apt/lists/*
EOF
fi
# Add pip dependencies first since llama-stack is what will change most often
# so we can reuse layers.
@ -126,7 +139,7 @@ ENTRYPOINT ["python", "-m", "llama_stack.distribution.server.server", "--templat
EOF
printf "Dockerfile created successfully in $TEMP_DIR/Dockerfile"
printf "Dockerfile created successfully in $TEMP_DIR/Dockerfile\n\n"
cat $TEMP_DIR/Dockerfile
printf "\n"

View file

@ -0,0 +1,105 @@
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# TODO: combine this with build_conda_env.sh since it is almost identical
# the only difference is that we don't do any conda-specific setup
LLAMA_MODELS_DIR=${LLAMA_MODELS_DIR:-}
LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
if [ -n "$LLAMA_STACK_DIR" ]; then
echo "Using llama-stack-dir=$LLAMA_STACK_DIR"
fi
if [ -n "$LLAMA_MODELS_DIR" ]; then
echo "Using llama-models-dir=$LLAMA_MODELS_DIR"
fi
if [ "$#" -lt 3 ]; then
echo "Usage: $0 <distribution_type> <build_name> <build_file_path> <pip_dependencies> [<special_pip_deps>]" >&2
echo "Example: $0 <distribution_type> mybuild ./my-stack-build.yaml 'numpy pandas scipy'" >&2
exit 1
fi
special_pip_deps="$4"
set -euo pipefail
build_name="$1"
env_name="llamastack-$build_name"
build_file_path="$2"
pip_dependencies="$3"
# Define color codes
RED='\033[0;31m'
GREEN='\033[0;32m'
NC='\033[0m' # No Color
# this is set if we actually create a new conda in which case we need to clean up
ENVNAME=""
SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
source "$SCRIPT_DIR/common.sh"
run() {
local env_name="$1"
local pip_dependencies="$2"
local special_pip_deps="$3"
if [ -n "$TEST_PYPI_VERSION" ]; then
# these packages are damaged in test-pypi, so install them first
pip install fastapi libcst
pip install --extra-index-url https://test.pypi.org/simple/ \
llama-models==$TEST_PYPI_VERSION llama-stack==$TEST_PYPI_VERSION \
$pip_dependencies
if [ -n "$special_pip_deps" ]; then
IFS='#' read -ra parts <<<"$special_pip_deps"
for part in "${parts[@]}"; do
echo "$part"
pip install $part
done
fi
else
# Re-installing llama-stack in the new virtual environment
if [ -n "$LLAMA_STACK_DIR" ]; then
if [ ! -d "$LLAMA_STACK_DIR" ]; then
printf "${RED}Warning: LLAMA_STACK_DIR is set but directory does not exist: $LLAMA_STACK_DIR${NC}\n" >&2
exit 1
fi
printf "Installing from LLAMA_STACK_DIR: $LLAMA_STACK_DIR\n"
pip install --no-cache-dir -e "$LLAMA_STACK_DIR"
else
pip install --no-cache-dir llama-stack
fi
if [ -n "$LLAMA_MODELS_DIR" ]; then
if [ ! -d "$LLAMA_MODELS_DIR" ]; then
printf "${RED}Warning: LLAMA_MODELS_DIR is set but directory does not exist: $LLAMA_MODELS_DIR${NC}\n" >&2
exit 1
fi
printf "Installing from LLAMA_MODELS_DIR: $LLAMA_MODELS_DIR\n"
pip uninstall -y llama-models
pip install --no-cache-dir -e "$LLAMA_MODELS_DIR"
fi
# Install pip dependencies
printf "Installing pip dependencies\n"
pip install $pip_dependencies
if [ -n "$special_pip_deps" ]; then
IFS='#' read -ra parts <<<"$special_pip_deps"
for part in "${parts[@]}"; do
echo "$part"
pip install $part
done
fi
fi
}
run "$env_name" "$pip_dependencies" "$special_pip_deps"

View file

@ -6,10 +6,14 @@
import logging
import textwrap
from typing import Any
from llama_stack.distribution.datatypes import * # noqa: F403
from typing import Any, Dict
from llama_stack.distribution.datatypes import (
DistributionSpec,
LLAMA_STACK_RUN_CONFIG_VERSION,
Provider,
StackRunConfig,
)
from llama_stack.distribution.distribution import (
builtin_automatically_routed_apis,
get_provider_registry,
@ -17,10 +21,7 @@ from llama_stack.distribution.distribution import (
from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.distribution.utils.prompt_for_config import prompt_for_config
from llama_stack.apis.models import * # noqa: F403
from llama_stack.apis.shields import * # noqa: F403
from llama_stack.apis.memory_banks import * # noqa: F403
from llama_stack.providers.datatypes import Api, ProviderSpec
logger = logging.getLogger(__name__)

View file

@ -4,23 +4,24 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict, List, Optional, Union
from typing import Annotated, Any, Dict, List, Optional, Union
from pydantic import BaseModel, Field
from llama_stack.providers.datatypes import * # noqa: F403
from llama_stack.apis.models import * # noqa: F403
from llama_stack.apis.shields import * # noqa: F403
from llama_stack.apis.memory_banks import * # noqa: F403
from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.apis.scoring_functions import * # noqa: F403
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Dataset, DatasetInput
from llama_stack.apis.eval import Eval
from llama_stack.apis.eval_tasks import EvalTaskInput
from llama_stack.apis.eval_tasks import EvalTask, EvalTaskInput
from llama_stack.apis.inference import Inference
from llama_stack.apis.memory import Memory
from llama_stack.apis.memory_banks import MemoryBank, MemoryBankInput
from llama_stack.apis.models import Model, ModelInput
from llama_stack.apis.safety import Safety
from llama_stack.apis.scoring import Scoring
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnInput
from llama_stack.apis.shields import Shield, ShieldInput
from llama_stack.apis.tools import Tool, ToolGroup, ToolGroupInput, ToolRuntime
from llama_stack.providers.datatypes import Api, ProviderSpec
from llama_stack.providers.utils.kvstore.config import KVStoreConfig
LLAMA_STACK_BUILD_CONFIG_VERSION = "2"
@ -37,6 +38,8 @@ RoutableObject = Union[
Dataset,
ScoringFn,
EvalTask,
Tool,
ToolGroup,
]
@ -48,6 +51,8 @@ RoutableObjectWithProvider = Annotated[
Dataset,
ScoringFn,
EvalTask,
Tool,
ToolGroup,
],
Field(discriminator="type"),
]
@ -59,6 +64,7 @@ RoutedProtocol = Union[
DatasetIO,
Scoring,
Eval,
ToolRuntime,
]
@ -155,6 +161,7 @@ a default SQLite store will be used.""",
datasets: List[DatasetInput] = Field(default_factory=list)
scoring_fns: List[ScoringFnInput] = Field(default_factory=list)
eval_tasks: List[EvalTaskInput] = Field(default_factory=list)
tool_groups: List[ToolGroupInput] = Field(default_factory=list)
class BuildConfig(BaseModel):
@ -165,5 +172,5 @@ class BuildConfig(BaseModel):
)
image_type: str = Field(
default="conda",
description="Type of package to build (conda | container)",
description="Type of package to build (conda | docker | venv)",
)

View file

@ -47,6 +47,10 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
routing_table_api=Api.eval_tasks,
router_api=Api.eval,
),
AutoRoutedApiInfo(
routing_table_api=Api.tool_groups,
router_api=Api.tool_runtime,
),
]

View file

@ -4,13 +4,20 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from importlib.metadata import version
from typing import Dict, List
from llama_stack.apis.inspect import * # noqa: F403
from pydantic import BaseModel
from llama_stack.apis.inspect import (
HealthInfo,
Inspect,
ProviderInfo,
RouteInfo,
VersionInfo,
)
from llama_stack.distribution.datatypes import StackRunConfig
from llama_stack.distribution.server.endpoints import get_all_api_endpoints
from llama_stack.providers.datatypes import * # noqa: F403
from llama_stack.distribution.datatypes import * # noqa: F403
class DistributionInspectConfig(BaseModel):
@ -65,3 +72,6 @@ class DistributionInspectImpl(Inspect):
async def health(self) -> HealthInfo:
return HealthInfo(status="OK")
async def version(self) -> VersionInfo:
return VersionInfo(version=version("llama-stack"))
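The new `version` endpoint simply reports the installed package version; the same lookup can be reproduced standalone:

```python
# What the version() handler above returns, reproduced outside the server.
from importlib.metadata import version

print(version("llama-stack"))  # e.g. "0.0.63", whichever build is installed
```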

View file

@ -0,0 +1,456 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import inspect
import json
import logging
import os
import queue
import threading
from concurrent.futures import ThreadPoolExecutor
from enum import Enum
from pathlib import Path
from typing import Any, Generator, get_args, get_origin, Optional, TypeVar
import httpx
import yaml
from llama_stack_client import (
APIResponse,
AsyncAPIResponse,
AsyncLlamaStackClient,
AsyncStream,
LlamaStackClient,
NOT_GIVEN,
)
from pydantic import BaseModel, TypeAdapter
from rich.console import Console
from termcolor import cprint
from llama_stack.distribution.build import print_pip_install_help
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.datatypes import Api
from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.resolver import ProviderRegistry
from llama_stack.distribution.server.endpoints import get_all_api_endpoints
from llama_stack.distribution.stack import (
construct_stack,
get_stack_run_config_from_template,
redact_sensitive_fields,
replace_env_vars,
)
from llama_stack.providers.utils.telemetry.tracing import (
end_trace,
setup_logger,
start_trace,
)
T = TypeVar("T")
def in_notebook():
try:
from IPython import get_ipython
if "IPKernelApp" not in get_ipython().config: # pragma: no cover
return False
except ImportError:
return False
except AttributeError:
return False
return True
def stream_across_asyncio_run_boundary(
async_gen_maker,
pool_executor: ThreadPoolExecutor,
path: Optional[str] = None,
provider_data: Optional[dict[str, Any]] = None,
) -> Generator[T, None, None]:
result_queue = queue.Queue()
stop_event = threading.Event()
async def consumer():
# make sure we create the generator in the event loop context
gen = await async_gen_maker()
await start_trace(path, {"__location__": "library_client"})
if provider_data:
set_request_provider_data(
{"X-LlamaStack-Provider-Data": json.dumps(provider_data)}
)
try:
async for item in await gen:
result_queue.put(item)
except Exception as e:
print(f"Error in generator {e}")
result_queue.put(e)
except asyncio.CancelledError:
return
finally:
result_queue.put(StopIteration)
stop_event.set()
await end_trace()
def run_async():
# Run our own loop to avoid double async generator cleanup which is done
# by asyncio.run()
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
task = loop.create_task(consumer())
loop.run_until_complete(task)
finally:
# Handle pending tasks like a generator's athrow()
pending = asyncio.all_tasks(loop)
if pending:
loop.run_until_complete(
asyncio.gather(*pending, return_exceptions=True)
)
loop.close()
future = pool_executor.submit(run_async)
try:
# yield results as they come in
while not stop_event.is_set() or not result_queue.empty():
try:
item = result_queue.get(timeout=0.1)
if item is StopIteration:
break
if isinstance(item, Exception):
raise item
yield item
except queue.Empty:
continue
finally:
future.result()
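For readers unfamiliar with this bridge, here is a self-contained, simplified illustration of the same technique (an async generator driven on a worker thread's own event loop and consumed synchronously through a queue), with the tracing and provider-data plumbing omitted:

```python
# Standalone sketch of the queue/thread bridge used above (simplified).
import asyncio
import queue
import threading
from concurrent.futures import ThreadPoolExecutor


async def numbers():
    for i in range(3):
        await asyncio.sleep(0.01)
        yield i


def stream_sync(async_gen_maker, pool: ThreadPoolExecutor):
    result_queue: queue.Queue = queue.Queue()
    stop_event = threading.Event()

    async def consumer():
        try:
            async for item in async_gen_maker():
                result_queue.put(item)
        finally:
            result_queue.put(StopIteration)  # sentinel, mirroring the code above
            stop_event.set()

    future = pool.submit(lambda: asyncio.run(consumer()))
    try:
        while not stop_event.is_set() or not result_queue.empty():
            try:
                item = result_queue.get(timeout=0.1)
            except queue.Empty:
                continue
            if item is StopIteration:
                break
            yield item
    finally:
        future.result()  # surface any exception raised on the worker thread


with ThreadPoolExecutor(max_workers=1) as pool:
    print(list(stream_sync(numbers, pool)))  # -> [0, 1, 2]
```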
def convert_pydantic_to_json_value(value: Any) -> Any:
if isinstance(value, Enum):
return value.value
elif isinstance(value, list):
return [convert_pydantic_to_json_value(item) for item in value]
elif isinstance(value, dict):
return {k: convert_pydantic_to_json_value(v) for k, v in value.items()}
elif isinstance(value, BaseModel):
return json.loads(value.model_dump_json())
else:
return value
def convert_to_pydantic(annotation: Any, value: Any) -> Any:
if isinstance(annotation, type) and annotation in {str, int, float, bool}:
return value
origin = get_origin(annotation)
if origin is list:
item_type = get_args(annotation)[0]
try:
return [convert_to_pydantic(item_type, item) for item in value]
except Exception:
print(f"Error converting list {value}")
return value
elif origin is dict:
key_type, val_type = get_args(annotation)
try:
return {k: convert_to_pydantic(val_type, v) for k, v in value.items()}
except Exception:
print(f"Error converting dict {value}")
return value
try:
# Handle Pydantic models and discriminated unions
return TypeAdapter(annotation).validate_python(value)
except Exception as e:
cprint(
f"Warning: direct client failed to convert parameter {value} into {annotation}: {e}",
"yellow",
)
return value
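A quick, self-contained round trip through the two helpers above:

```python
# Illustration only: JSON-ish values in, Pydantic models out, and back again.
from typing import List

from pydantic import BaseModel

from llama_stack.distribution.library_client import (
    convert_pydantic_to_json_value,
    convert_to_pydantic,
)


class Point(BaseModel):
    x: int
    y: int


payload = [{"x": 1, "y": 2}, {"x": 3, "y": 4}]

points = convert_to_pydantic(List[Point], payload)
print(points)  # [Point(x=1, y=2), Point(x=3, y=4)]

print(convert_pydantic_to_json_value(points))
# [{'x': 1, 'y': 2}, {'x': 3, 'y': 4}]
```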
class LlamaStackAsLibraryClient(LlamaStackClient):
def __init__(
self,
config_path_or_template_name: str,
skip_logger_removal: bool = False,
custom_provider_registry: Optional[ProviderRegistry] = None,
provider_data: Optional[dict[str, Any]] = None,
):
super().__init__()
self.async_client = AsyncLlamaStackAsLibraryClient(
config_path_or_template_name, custom_provider_registry
)
self.pool_executor = ThreadPoolExecutor(max_workers=4)
self.skip_logger_removal = skip_logger_removal
self.provider_data = provider_data
def initialize(self):
if in_notebook():
import nest_asyncio
nest_asyncio.apply()
if not self.skip_logger_removal:
self._remove_root_logger_handlers()
return asyncio.run(self.async_client.initialize())
def _remove_root_logger_handlers(self):
"""
Remove all handlers from the root logger. Needed to avoid polluting the console with logs.
"""
root_logger = logging.getLogger()
for handler in root_logger.handlers[:]:
root_logger.removeHandler(handler)
print(f"Removed handler {handler.__class__.__name__} from root logger")
def _get_path(
self,
cast_to: Any,
options: Any,
*,
stream=False,
stream_cls=None,
):
return options.url
def request(self, *args, **kwargs):
path = self._get_path(*args, **kwargs)
if kwargs.get("stream"):
return stream_across_asyncio_run_boundary(
lambda: self.async_client.request(*args, **kwargs),
self.pool_executor,
path=path,
provider_data=self.provider_data,
)
else:
async def _traced_request():
if self.provider_data:
set_request_provider_data(
{"X-LlamaStack-Provider-Data": json.dumps(self.provider_data)}
)
await start_trace(path, {"__location__": "library_client"})
try:
return await self.async_client.request(*args, **kwargs)
finally:
await end_trace()
return asyncio.run(_traced_request())
class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
def __init__(
self,
config_path_or_template_name: str,
custom_provider_registry: Optional[ProviderRegistry] = None,
):
super().__init__()
# when using the library client, we should not log to console since many
# of our logs are intended for server-side usage
current_sinks = os.environ.get("TELEMETRY_SINKS", "sqlite").split(",")
os.environ["TELEMETRY_SINKS"] = ",".join(
sink for sink in current_sinks if sink != "console"
)
if config_path_or_template_name.endswith(".yaml"):
config_path = Path(config_path_or_template_name)
if not config_path.exists():
raise ValueError(f"Config file {config_path} does not exist")
config_dict = replace_env_vars(yaml.safe_load(config_path.read_text()))
config = parse_and_maybe_upgrade_config(config_dict)
else:
# template
config = get_stack_run_config_from_template(config_path_or_template_name)
self.config_path_or_template_name = config_path_or_template_name
self.config = config
self.custom_provider_registry = custom_provider_registry
async def initialize(self):
try:
self.impls = await construct_stack(
self.config, self.custom_provider_registry
)
except ModuleNotFoundError as _e:
cprint(_e.msg, "red")
cprint(
"Using llama-stack as a library requires installing dependencies depending on the template (providers) you choose.\n",
"yellow",
)
if self.config_path_or_template_name.endswith(".yaml"):
print_pip_install_help(self.config.providers)
else:
prefix = "!" if in_notebook() else ""
cprint(
f"Please run:\n\n{prefix}llama stack build --template {self.config_path_or_template_name} --image-type venv\n\n",
"yellow",
)
return False
if Api.telemetry in self.impls:
setup_logger(self.impls[Api.telemetry])
console = Console()
console.print(f"Using config [blue]{self.config_path_or_template_name}[/blue]:")
# Redact sensitive information before printing
safe_config = redact_sensitive_fields(self.config.model_dump())
console.print(yaml.dump(safe_config, indent=2))
endpoints = get_all_api_endpoints()
endpoint_impls = {}
for api, api_endpoints in endpoints.items():
if api not in self.impls:
continue
for endpoint in api_endpoints:
impl = self.impls[api]
func = getattr(impl, endpoint.name)
endpoint_impls[endpoint.route] = func
self.endpoint_impls = endpoint_impls
return True
async def request(
self,
cast_to: Any,
options: Any,
*,
stream=False,
stream_cls=None,
):
if not self.endpoint_impls:
raise ValueError("Client not initialized")
if stream:
return self._call_streaming(
cast_to=cast_to,
options=options,
stream_cls=stream_cls,
)
else:
return await self._call_non_streaming(
cast_to=cast_to,
options=options,
)
async def _call_non_streaming(
self,
*,
cast_to: Any,
options: Any,
):
path = options.url
body = options.params or {}
body |= options.json_data or {}
func = self.endpoint_impls.get(path)
if not func:
raise ValueError(f"No endpoint found for {path}")
body = self._convert_body(path, body)
result = await func(**body)
json_content = json.dumps(convert_pydantic_to_json_value(result))
mock_response = httpx.Response(
status_code=httpx.codes.OK,
content=json_content.encode("utf-8"),
headers={
"Content-Type": "application/json",
},
request=httpx.Request(
method=options.method,
url=options.url,
params=options.params,
headers=options.headers,
json=options.json_data,
),
)
response = APIResponse(
raw=mock_response,
client=self,
cast_to=cast_to,
options=options,
stream=False,
stream_cls=None,
)
return response.parse()
async def _call_streaming(
self,
*,
cast_to: Any,
options: Any,
stream_cls: Any,
):
path = options.url
body = options.params or {}
body |= options.json_data or {}
func = self.endpoint_impls.get(path)
if not func:
raise ValueError(f"No endpoint found for {path}")
body = self._convert_body(path, body)
async def gen():
async for chunk in await func(**body):
data = json.dumps(convert_pydantic_to_json_value(chunk))
sse_event = f"data: {data}\n\n"
yield sse_event.encode("utf-8")
mock_response = httpx.Response(
status_code=httpx.codes.OK,
content=gen(),
headers={
"Content-Type": "application/json",
},
request=httpx.Request(
method=options.method,
url=options.url,
params=options.params,
headers=options.headers,
json=options.json_data,
),
)
# we use asynchronous impl always internally and channel all requests to AsyncLlamaStackClient
# however, the top-level caller may be a SyncAPIClient -- so its stream_cls might be a Stream (SyncStream)
# so we need to convert it to AsyncStream
args = get_args(stream_cls)
stream_cls = AsyncStream[args[0]]
response = AsyncAPIResponse(
raw=mock_response,
client=self,
cast_to=cast_to,
options=options,
stream=True,
stream_cls=stream_cls,
)
return await response.parse()
def _convert_body(self, path: str, body: Optional[dict] = None) -> dict:
if not body:
return {}
func = self.endpoint_impls[path]
sig = inspect.signature(func)
# Strip NOT_GIVENs to use the defaults in signature
body = {k: v for k, v in body.items() if v is not NOT_GIVEN}
# Convert parameters to Pydantic models where needed
converted_body = {}
for param_name, param in sig.parameters.items():
if param_name in body:
value = body.get(param_name)
converted_body[param_name] = convert_to_pydantic(
param.annotation, value
)
return converted_body

View file

@ -40,8 +40,8 @@ class NeedsRequestProviderData:
def set_request_provider_data(headers: Dict[str, str]):
keys = [
"X-LlamaStack-ProviderData",
"x-llamastack-providerdata",
"X-LlamaStack-Provider-Data",
"x-llamastack-provider-data",
]
for key in keys:
val = headers.get(key, None)

View file

@ -5,14 +5,8 @@
# the root directory of this source tree.
import importlib
import inspect
from typing import Any, Dict, List, Set
from llama_stack.providers.datatypes import * # noqa: F403
from llama_stack.distribution.datatypes import * # noqa: F403
import logging
from typing import Any, Dict, List, Set
from llama_stack.apis.agents import Agents
from llama_stack.apis.datasetio import DatasetIO
@ -24,15 +18,37 @@ from llama_stack.apis.inspect import Inspect
from llama_stack.apis.memory import Memory
from llama_stack.apis.memory_banks import MemoryBanks
from llama_stack.apis.models import Models
from llama_stack.apis.post_training import PostTraining
from llama_stack.apis.safety import Safety
from llama_stack.apis.scoring import Scoring
from llama_stack.apis.scoring_functions import ScoringFunctions
from llama_stack.apis.shields import Shields
from llama_stack.apis.telemetry import Telemetry
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.distribution.client import get_client_impl
from llama_stack.distribution.datatypes import (
AutoRoutedProviderSpec,
Provider,
RoutingTableProviderSpec,
StackRunConfig,
)
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.distribution.store import DistributionRegistry
from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.providers.datatypes import (
Api,
DatasetsProtocolPrivate,
EvalTasksProtocolPrivate,
InlineProviderSpec,
MemoryBanksProtocolPrivate,
ModelsProtocolPrivate,
ProviderSpec,
RemoteProviderConfig,
RemoteProviderSpec,
ScoringFunctionsProtocolPrivate,
ShieldsProtocolPrivate,
ToolsProtocolPrivate,
)
log = logging.getLogger(__name__)
@ -58,12 +74,16 @@ def api_protocol_map() -> Dict[Api, Any]:
Api.scoring_functions: ScoringFunctions,
Api.eval: Eval,
Api.eval_tasks: EvalTasks,
Api.post_training: PostTraining,
Api.tool_groups: ToolGroups,
Api.tool_runtime: ToolRuntime,
}
def additional_protocols_map() -> Dict[Api, Any]:
return {
Api.inference: (ModelsProtocolPrivate, Models, Api.models),
Api.tool_groups: (ToolsProtocolPrivate, ToolGroups, Api.tool_groups),
Api.memory: (MemoryBanksProtocolPrivate, MemoryBanks, Api.memory_banks),
Api.safety: (ShieldsProtocolPrivate, Shields, Api.shields),
Api.datasetio: (DatasetsProtocolPrivate, Datasets, Api.datasets),

View file

@ -4,11 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from typing import Any, Dict
from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.distribution.datatypes import RoutedProtocol
from llama_stack.distribution.store import DistributionRegistry
from llama_stack.providers.datatypes import Api, RoutingTable
from .routing_tables import (
DatasetsRoutingTable,
@ -17,6 +18,7 @@ from .routing_tables import (
ModelsRoutingTable,
ScoringFunctionsRoutingTable,
ShieldsRoutingTable,
ToolGroupsRoutingTable,
)
@ -33,6 +35,7 @@ async def get_routing_table_impl(
"datasets": DatasetsRoutingTable,
"scoring_functions": ScoringFunctionsRoutingTable,
"eval_tasks": EvalTasksRoutingTable,
"tool_groups": ToolGroupsRoutingTable,
}
if api.value not in api_to_tables:
@ -51,6 +54,7 @@ async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) ->
MemoryRouter,
SafetyRouter,
ScoringRouter,
ToolRuntimeRouter,
)
api_to_routers = {
@ -60,6 +64,7 @@ async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) ->
"datasetio": DatasetIORouter,
"scoring": ScoringRouter,
"eval": EvalRouter,
"tool_runtime": ToolRuntimeRouter,
}
if api.value not in api_to_routers:
raise ValueError(f"API {api.value} not found in router map")

View file

@ -6,15 +6,40 @@
from typing import Any, AsyncGenerator, Dict, List, Optional
from llama_stack.apis.datasetio.datasetio import DatasetIO
from llama_stack.apis.common.content_types import InterleavedContent, URL
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
from llama_stack.apis.eval import (
AppEvalTaskConfig,
Eval,
EvalTaskConfig,
EvaluateResponse,
Job,
JobStatus,
)
from llama_stack.apis.inference import (
EmbeddingsResponse,
Inference,
LogProbConfig,
Message,
ResponseFormat,
SamplingParams,
ToolChoice,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.memory import Memory, MemoryBankDocument, QueryDocumentsResponse
from llama_stack.apis.memory_banks.memory_banks import BankParams
from llama_stack.distribution.datatypes import RoutingTable
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403
from llama_stack.apis.scoring import * # noqa: F403
from llama_stack.apis.eval import * # noqa: F403
from llama_stack.apis.models import ModelType
from llama_stack.apis.safety import RunShieldResponse, Safety
from llama_stack.apis.scoring import (
ScoreBatchResponse,
ScoreResponse,
Scoring,
ScoringFnParams,
)
from llama_stack.apis.shields import Shield
from llama_stack.apis.tools import ToolDef, ToolRuntime
from llama_stack.providers.datatypes import RoutingTable
class MemoryRouter(Memory):
@ -59,7 +84,7 @@ class MemoryRouter(Memory):
async def query_documents(
self,
bank_id: str,
query: InterleavedTextMedia,
query: InterleavedContent,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse:
return await self.routing_table.get_provider_impl(bank_id).query_documents(
@ -88,9 +113,10 @@ class InferenceRouter(Inference):
provider_model_id: Optional[str] = None,
provider_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
model_type: Optional[ModelType] = None,
) -> None:
await self.routing_table.register_model(
model_id, provider_model_id, provider_id, metadata
model_id, provider_model_id, provider_id, metadata, model_type
)
async def chat_completion(
@ -101,10 +127,17 @@ class InferenceRouter(Inference):
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
model = await self.routing_table.get_model(model_id)
if model is None:
raise ValueError(f"Model '{model_id}' not found")
if model.model_type == ModelType.embedding:
raise ValueError(
f"Model '{model_id}' is an embedding model and does not support chat completions"
)
params = dict(
model_id=model_id,
messages=messages,
@ -125,12 +158,19 @@ class InferenceRouter(Inference):
async def completion(
self,
model_id: str,
content: InterleavedTextMedia,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
model = await self.routing_table.get_model(model_id)
if model is None:
raise ValueError(f"Model '{model_id}' not found")
if model.model_type == ModelType.embedding:
raise ValueError(
f"Model '{model_id}' is an embedding model and does not support chat completions"
)
provider = self.routing_table.get_provider_impl(model_id)
params = dict(
model_id=model_id,
@ -148,8 +188,15 @@ class InferenceRouter(Inference):
async def embeddings(
self,
model_id: str,
contents: List[InterleavedTextMedia],
contents: List[InterleavedContent],
) -> EmbeddingsResponse:
model = await self.routing_table.get_model(model_id)
if model is None:
raise ValueError(f"Model '{model_id}' not found")
if model.model_type == ModelType.llm:
raise ValueError(
f"Model '{model_id}' is an LLM model and does not support embeddings"
)
return await self.routing_table.get_provider_impl(model_id).embeddings(
model_id=model_id,
contents=contents,
@ -307,7 +354,6 @@ class EvalRouter(Eval):
task_config=task_config,
)
@webmethod(route="/eval/evaluate_rows", method="POST")
async def evaluate_rows(
self,
task_id: str,
@ -350,3 +396,30 @@ class EvalRouter(Eval):
task_id,
job_id,
)
class ToolRuntimeRouter(ToolRuntime):
def __init__(
self,
routing_table: RoutingTable,
) -> None:
self.routing_table = routing_table
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
async def invoke_tool(self, tool_name: str, args: Dict[str, Any]) -> Any:
return await self.routing_table.get_provider_impl(tool_name).invoke_tool(
tool_name=tool_name,
args=args,
)
async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]:
return await self.routing_table.get_provider_impl(tool_group_id).list_tools(
tool_group_id, mcp_endpoint
)
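A hedged sketch of what calling the new router looks like, using the `ToolRuntimeRouter` defined above; the routing table and all identifiers are placeholders:

```python
# Sketch only: `routing_table` is an already-built routing table, and the
# tool/group identifiers are placeholders rather than names shipped by a provider.
async def demo(routing_table) -> None:
    router = ToolRuntimeRouter(routing_table)
    await router.initialize()

    # Lists the ToolDefs exposed by the provider backing the given tool group.
    tools = await router.list_runtime_tools(tool_group_id="example::tools")
    print([t.name for t in tools])

    # Dispatches to whichever provider the routing table resolves for this tool.
    result = await router.invoke_tool(tool_name="example_tool", args={"query": "hello"})
    print(result)
```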

View file

@ -6,22 +6,34 @@
from typing import Any, Dict, List, Optional
from pydantic import parse_obj_as
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.models import * # noqa: F403
from llama_stack.apis.shields import * # noqa: F403
from llama_stack.apis.memory_banks import * # noqa: F403
from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.apis.eval_tasks import * # noqa: F403
from llama_models.llama3.api.datatypes import URL
from pydantic import TypeAdapter
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.datasets import Dataset, Datasets
from llama_stack.apis.eval_tasks import EvalTask, EvalTasks
from llama_stack.apis.memory_banks import (
BankParams,
MemoryBank,
MemoryBanks,
MemoryBankType,
)
from llama_stack.apis.models import Model, Models, ModelType
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.scoring_functions import (
ScoringFn,
ScoringFnParams,
ScoringFunctions,
)
from llama_stack.apis.shields import Shield, Shields
from llama_stack.apis.tools import Tool, ToolGroup, ToolGroups, ToolHost
from llama_stack.distribution.datatypes import (
RoutableObject,
RoutableObjectWithProvider,
RoutedProtocol,
)
from llama_stack.distribution.store import DistributionRegistry
from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.providers.datatypes import Api, RoutingTable
def get_impl_api(p: Any) -> Api:
@ -30,7 +42,6 @@ def get_impl_api(p: Any) -> Api:
# TODO: this should return the registered object for all APIs
async def register_object_with_provider(obj: RoutableObject, p: Any) -> RoutableObject:
api = get_impl_api(p)
assert obj.provider_id != "remote", "Remote provider should not be registered"
@ -47,6 +58,8 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable
return await p.register_scoring_function(obj)
elif api == Api.eval:
return await p.register_eval_task(obj)
elif api == Api.tool_runtime:
return await p.register_tool(obj)
else:
raise ValueError(f"Unknown API {api} for registering object with provider")
@ -59,6 +72,8 @@ async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
return await p.unregister_model(obj.identifier)
elif api == Api.datasetio:
return await p.unregister_dataset(obj.identifier)
elif api == Api.tool_runtime:
return await p.unregister_tool(obj.identifier)
else:
raise ValueError(f"Unregister not supported for {api}")
@ -76,7 +91,6 @@ class CommonRoutingTableImpl(RoutingTable):
self.dist_registry = dist_registry
async def initialize(self) -> None:
async def add_objects(
objs: List[RoutableObjectWithProvider], provider_id: str, cls
) -> None:
@ -107,6 +121,8 @@ class CommonRoutingTableImpl(RoutingTable):
await add_objects(scoring_functions, pid, ScoringFn)
elif api == Api.eval:
p.eval_task_store = self
elif api == Api.tool_runtime:
p.tool_store = self
async def shutdown(self) -> None:
for p in self.impls_by_provider_id.values():
@ -128,6 +144,8 @@ class CommonRoutingTableImpl(RoutingTable):
return ("Scoring", "scoring_function")
elif isinstance(self, EvalTasksRoutingTable):
return ("Eval", "eval_task")
elif isinstance(self, ToolGroupsRoutingTable):
return ("Tools", "tool")
else:
raise ValueError("Unknown routing table type")
@ -209,6 +227,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
provider_model_id: Optional[str] = None,
provider_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
model_type: Optional[ModelType] = None,
) -> Model:
if provider_model_id is None:
provider_model_id = model_id
@ -222,11 +241,18 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
)
if metadata is None:
metadata = {}
if model_type is None:
model_type = ModelType.llm
if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
raise ValueError(
"Embedding model must have an embedding dimension in its metadata"
)
model = Model(
identifier=model_id,
provider_resource_id=provider_model_id,
provider_id=provider_id,
metadata=metadata,
model_type=model_type,
)
registered_model = await self.register_object(model)
return registered_model
@ -298,16 +324,36 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
raise ValueError(
"No provider specified and multiple providers available. Please specify a provider_id."
)
memory_bank = parse_obj_as(
MemoryBank,
{
"identifier": memory_bank_id,
"type": ResourceType.memory_bank.value,
"provider_id": provider_id,
"provider_resource_id": provider_memory_bank_id,
**params.model_dump(),
},
)
model = await self.get_object_by_identifier("model", params.embedding_model)
if model is None:
if params.embedding_model == "all-MiniLM-L6-v2":
raise ValueError(
"Embeddings are now served via Inference providers. "
"Please upgrade your run.yaml to include inline::sentence-transformer as an additional inference provider. "
"See https://github.com/meta-llama/llama-stack/blob/main/llama_stack/templates/together/run.yaml for an example."
)
else:
raise ValueError(f"Model {params.embedding_model} not found")
if model.model_type != ModelType.embedding:
raise ValueError(
f"Model {params.embedding_model} is not an embedding model"
)
if "embedding_dimension" not in model.metadata:
raise ValueError(
f"Model {params.embedding_model} does not have an embedding dimension"
)
memory_bank_data = {
"identifier": memory_bank_id,
"type": ResourceType.memory_bank.value,
"provider_id": provider_id,
"provider_resource_id": provider_memory_bank_id,
**params.model_dump(),
}
if params.memory_bank_type == MemoryBankType.vector.value:
memory_bank_data["embedding_dimension"] = model.metadata[
"embedding_dimension"
]
memory_bank = TypeAdapter(MemoryBank).validate_python(memory_bank_data)
await self.register_object(memory_bank)
return memory_bank
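The practical upshot of the checks above: a vector memory bank can now only be registered after an embedding model carrying an `embedding_dimension` in its metadata has been registered. A hedged client-side sketch, where the provider id, bank params, and client method shapes are assumptions:

```python
# Sketch only: the registration order implied by the new validation above.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:5000")

# 1) Register the embedding model first; embedding_dimension is now required.
client.models.register(
    model_id="all-MiniLM-L6-v2",
    provider_id="sentence-transformers",   # assumed provider id
    model_type="embedding",
    metadata={"embedding_dimension": 384},
)

# 2) Only then can a vector bank that references the model be registered.
client.memory_banks.register(
    memory_bank_id="my-docs",
    params={                                # assumed parameter shape
        "memory_bank_type": "vector",
        "embedding_model": "all-MiniLM-L6-v2",
        "chunk_size_in_tokens": 512,
    },
)
```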
@ -436,3 +482,80 @@ class EvalTasksRoutingTable(CommonRoutingTableImpl, EvalTasks):
provider_resource_id=provider_eval_task_id,
)
await self.register_object(eval_task)
class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
async def list_tools(self, tool_group_id: Optional[str] = None) -> List[Tool]:
tools = await self.get_all_with_type("tool")
if tool_group_id:
tools = [tool for tool in tools if tool.toolgroup_id == tool_group_id]
return tools
async def list_tool_groups(self) -> List[ToolGroup]:
return await self.get_all_with_type("tool_group")
async def get_tool_group(self, toolgroup_id: str) -> ToolGroup:
return await self.get_object_by_identifier("tool_group", toolgroup_id)
async def get_tool(self, tool_name: str) -> Tool:
return await self.get_object_by_identifier("tool", tool_name)
async def register_tool_group(
self,
toolgroup_id: str,
provider_id: str,
mcp_endpoint: Optional[URL] = None,
args: Optional[Dict[str, Any]] = None,
) -> None:
tools = []
tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(
toolgroup_id, mcp_endpoint
)
tool_host = (
ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution
)
for tool_def in tool_defs:
tools.append(
Tool(
identifier=tool_def.name,
toolgroup_id=toolgroup_id,
description=tool_def.description or "",
parameters=tool_def.parameters or [],
provider_id=provider_id,
provider_resource_id=tool_def.name,
metadata=tool_def.metadata,
tool_host=tool_host,
)
)
for tool in tools:
existing_tool = await self.get_tool(tool.identifier)
# Compare existing and new object if one exists
if existing_tool:
existing_dict = existing_tool.model_dump()
new_dict = tool.model_dump()
if existing_dict != new_dict:
raise ValueError(
f"Object {tool.identifier} already exists in registry. Please use a different identifier."
)
await self.register_object(tool)
await self.dist_registry.register(
ToolGroup(
identifier=toolgroup_id,
provider_id=provider_id,
provider_resource_id=toolgroup_id,
mcp_endpoint=mcp_endpoint,
args=args,
)
)
async def unregister_tool_group(self, tool_group_id: str) -> None:
tool_group = await self.get_tool_group(tool_group_id)
if tool_group is None:
raise ValueError(f"Tool group {tool_group_id} not found")
tools = await self.list_tools(tool_group_id)
for tool in tools:
await self.unregister_object(tool)
await self.unregister_object(tool_group)
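And the corresponding flow for the new tool-group table, using only the method signatures shown above; all identifiers are placeholders:

```python
# Sketch only: `table` is an initialized ToolGroupsRoutingTable (defined above).
async def register_example(table) -> None:
    await table.register_tool_group(
        toolgroup_id="example::group",      # placeholder group id
        provider_id="example-provider",     # placeholder provider id
        mcp_endpoint=None,                  # pass a URL for MCP-hosted tool groups
        args=None,
    )
    print(await table.list_tools("example::group"))
    await table.unregister_tool_group("example::group")
```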

View file

@ -16,6 +16,8 @@ import traceback
import warnings
from contextlib import asynccontextmanager
from importlib.metadata import version as parse_version
from pathlib import Path
from typing import Any, Union
@ -28,25 +30,29 @@ from pydantic import BaseModel, ValidationError
from termcolor import cprint
from typing_extensions import Annotated
from llama_stack.distribution.datatypes import StackRunConfig
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.resolver import InvalidProviderError
from llama_stack.distribution.stack import (
construct_stack,
redact_sensitive_fields,
replace_env_vars,
validate_env_pair,
)
from llama_stack.providers.datatypes import Api
from llama_stack.providers.inline.telemetry.meta_reference.config import TelemetryConfig
from llama_stack.providers.inline.telemetry.meta_reference.telemetry import (
TelemetryAdapter,
)
from llama_stack.providers.utils.telemetry.tracing import (
end_trace,
setup_logger,
start_trace,
)
from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.resolver import InvalidProviderError
from llama_stack.distribution.stack import (
construct_stack,
replace_env_vars,
validate_env_pair,
)
from llama_stack.providers.inline.telemetry.meta_reference import (
TelemetryAdapter,
TelemetryConfig,
)
from .endpoints import get_all_api_endpoints
@ -217,13 +223,59 @@ class TracingMiddleware:
async def __call__(self, scope, receive, send):
path = scope["path"]
await start_trace(path, {"location": "server"})
await start_trace(path, {"__location__": "server"})
try:
return await self.app(scope, receive, send)
finally:
await end_trace()
class ClientVersionMiddleware:
def __init__(self, app):
self.app = app
self.server_version = parse_version("llama-stack")
async def __call__(self, scope, receive, send):
if scope["type"] == "http":
headers = dict(scope.get("headers", []))
client_version = headers.get(b"x-llamastack-client-version", b"").decode()
if client_version:
try:
client_version_parts = tuple(
map(int, client_version.split(".")[:2])
)
server_version_parts = tuple(
map(int, self.server_version.split(".")[:2])
)
if client_version_parts != server_version_parts:
async def send_version_error(send):
await send(
{
"type": "http.response.start",
"status": 426,
"headers": [[b"content-type", b"application/json"]],
}
)
error_msg = json.dumps(
{
"error": {
"message": f"Client version {client_version} is not compatible with server version {self.server_version}. Please upgrade your client."
}
}
).encode()
await send(
{"type": "http.response.body", "body": error_msg}
)
return await send_version_error(send)
except (ValueError, IndexError):
# If version parsing fails, let the request through
pass
return await self.app(scope, receive, send)
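From the outside, the middleware above turns a major.minor version mismatch into an HTTP 426 with a JSON error body. A hedged probe (the route is a placeholder; any route served by the stack would do):

```python
# Sketch only: exercising the client-version check against a running server.
import httpx

resp = httpx.get(
    "http://localhost:5000/alpha/health",              # placeholder route
    headers={"X-LlamaStack-Client-Version": "0.0.1"},  # deliberately stale
)
if resp.status_code == 426:
    print(resp.json()["error"]["message"])  # "... Please upgrade your client."
else:
    print("client and server versions are compatible")
```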
def main():
"""Start the LlamaStack server."""
parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
@ -235,7 +287,12 @@ def main():
"--template",
help="One of the template names in llama_stack/templates (e.g., tgi, fireworks, remote-vllm, etc.)",
)
parser.add_argument("--port", type=int, default=5000, help="Port to listen on")
parser.add_argument(
"--port",
type=int,
default=int(os.getenv("LLAMA_STACK_PORT", 5000)),
help="Port to listen on",
)
parser.add_argument(
"--disable-ipv6", action="store_true", help="Whether to disable IPv6 support"
)
@ -277,10 +334,12 @@ def main():
config = StackRunConfig(**config)
print("Run configuration:")
print(yaml.dump(config.model_dump(), indent=2))
safe_config = redact_sensitive_fields(config.model_dump())
print(yaml.dump(safe_config, indent=2))
app = FastAPI(lifespan=lifespan)
app.add_middleware(TracingMiddleware)
app.add_middleware(ClientVersionMiddleware)
try:
impls = asyncio.run(construct_stack(config))

View file

@ -4,43 +4,40 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import importlib.resources
import logging
import os
from pathlib import Path
from typing import Any, Dict
import re
from typing import Any, Dict, Optional
import pkg_resources
import yaml
from termcolor import colored
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.agents import * # noqa: F403
from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403
from llama_stack.apis.scoring import * # noqa: F403
from llama_stack.apis.scoring_functions import * # noqa: F403
from llama_stack.apis.eval import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.batch_inference import * # noqa: F403
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.apis.telemetry import * # noqa: F403
from llama_stack.apis.post_training import * # noqa: F403
from llama_stack.apis.synthetic_data_generation import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403
from llama_stack.apis.models import * # noqa: F403
from llama_stack.apis.memory_banks import * # noqa: F403
from llama_stack.apis.shields import * # noqa: F403
from llama_stack.apis.inspect import * # noqa: F403
from llama_stack.apis.eval_tasks import * # noqa: F403
from llama_stack.apis.agents import Agents
from llama_stack.apis.batch_inference import BatchInference
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.eval import Eval
from llama_stack.apis.eval_tasks import EvalTasks
from llama_stack.apis.inference import Inference
from llama_stack.apis.inspect import Inspect
from llama_stack.apis.memory import Memory
from llama_stack.apis.memory_banks import MemoryBanks
from llama_stack.apis.models import Models
from llama_stack.apis.post_training import PostTraining
from llama_stack.apis.safety import Safety
from llama_stack.apis.scoring import Scoring
from llama_stack.apis.scoring_functions import ScoringFunctions
from llama_stack.apis.shields import Shields
from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration
from llama_stack.apis.telemetry import Telemetry
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.distribution.datatypes import StackRunConfig
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls
from llama_stack.distribution.store.registry import create_dist_registry
from llama_stack.providers.datatypes import Api
log = logging.getLogger(__name__)
LLAMA_STACK_API_VERSION = "alpha"
@ -65,6 +62,8 @@ class LlamaStack(
Models,
Shields,
Inspect,
ToolGroups,
ToolRuntime,
):
pass
@ -81,6 +80,7 @@ RESOURCES = [
"list_scoring_functions",
),
("eval_tasks", Api.eval_tasks, "register_eval_task", "list_eval_tasks"),
("tool_groups", Api.tool_groups, "register_tool_group", "list_tool_groups"),
]
@ -112,6 +112,26 @@ class EnvVarError(Exception):
)
def redact_sensitive_fields(data: Dict[str, Any]) -> Dict[str, Any]:
"""Redact sensitive information from config before printing."""
sensitive_patterns = ["api_key", "api_token", "password", "secret"]
def _redact_dict(d: Dict[str, Any]) -> Dict[str, Any]:
result = {}
for k, v in d.items():
if isinstance(v, dict):
result[k] = _redact_dict(v)
elif isinstance(v, list):
result[k] = [_redact_dict(i) if isinstance(i, dict) else i for i in v]
elif any(pattern in k.lower() for pattern in sensitive_patterns):
result[k] = "********"
else:
result[k] = v
return result
return _redact_dict(data)
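A quick check of the redaction behavior: keys matching the sensitive patterns are masked, everything else (including dicts nested inside lists) passes through unchanged.

```python
# Illustration only: a config-shaped dict through redact_sensitive_fields.
from llama_stack.distribution.stack import redact_sensitive_fields

config = {
    "providers": {
        "inference": [
            {
                "provider_id": "sambanova",
                "config": {"api_key": "sk-live-123", "url": "https://api.sambanova.ai"},
            }
        ]
    }
}
print(redact_sensitive_fields(config))
# {'providers': {'inference': [{'provider_id': 'sambanova',
#   'config': {'api_key': '********', 'url': 'https://api.sambanova.ai'}}]}}
```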
def replace_env_vars(config: Any, path: str = "") -> Any:
if isinstance(config, dict):
result = {}
@ -190,14 +210,13 @@ async def construct_stack(
def get_stack_run_config_from_template(template: str) -> StackRunConfig:
template_path = pkg_resources.resource_filename(
"llama_stack", f"templates/{template}/run.yaml"
template_path = (
importlib.resources.files("llama_stack") / f"templates/{template}/run.yaml"
)
if not Path(template_path).exists():
raise ValueError(f"Template '{template}' not found at {template_path}")
with open(template_path) as f:
run_config = yaml.safe_load(f)
with importlib.resources.as_file(template_path) as path:
if not path.exists():
raise ValueError(f"Template '{template}' not found at {template_path}")
run_config = yaml.safe_load(path.open())
return StackRunConfig(**replace_env_vars(run_config))

View file

@ -90,7 +90,6 @@ $DOCKER_BINARY run $DOCKER_OPTS -it \
$env_vars \
-v "$yaml_config:/app/config.yaml" \
$mounts \
$docker_image:$version_tag \
python -m llama_stack.distribution.server.server \
--yaml-config /app/config.yaml \
--port "$port"
--env LLAMA_STACK_PORT=$port \
--entrypoint='["python", "-m", "llama_stack.distribution.server.server", "--yaml-config", "/app/config.yaml"]' \
$docker_image:$version_tag

View file

@ -5,7 +5,6 @@
# the root directory of this source tree.
import asyncio
import json
from contextlib import asynccontextmanager
from typing import Dict, List, Optional, Protocol, Tuple
@ -13,12 +12,8 @@ import pydantic
from llama_stack.distribution.datatypes import KVStoreConfig, RoutableObjectWithProvider
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.providers.utils.kvstore import (
KVStore,
kvstore_impl,
SqliteKVStoreConfig,
)
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
class DistributionRegistry(Protocol):
@ -40,7 +35,7 @@ class DistributionRegistry(Protocol):
REGISTER_PREFIX = "distributions:registry"
KEY_VERSION = "v2"
KEY_VERSION = "v5"
KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"
@ -54,10 +49,7 @@ def _parse_registry_values(values: List[str]) -> List[RoutableObjectWithProvider
"""Utility function to parse registry values into RoutableObjectWithProvider objects."""
all_objects = []
for value in values:
obj = pydantic.parse_obj_as(
RoutableObjectWithProvider,
json.loads(value),
)
obj = pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(value)
all_objects.append(obj)
return all_objects
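The `parse_obj_as` to `TypeAdapter` change above is the standard Pydantic v2 migration, and the adapter also resolves the discriminated `RoutableObjectWithProvider` union. A minimal, generic illustration with stand-in types:

```python
# Generic Pydantic v2 pattern used above (simplified stand-in models).
from typing import Annotated, Literal, Union

from pydantic import BaseModel, Field, TypeAdapter


class Cat(BaseModel):
    type: Literal["cat"] = "cat"
    name: str


class Dog(BaseModel):
    type: Literal["dog"] = "dog"
    name: str


Pet = Annotated[Union[Cat, Dog], Field(discriminator="type")]

# validate_json replaces parse_raw_as; validate_python replaces parse_obj_as.
print(TypeAdapter(Pet).validate_json('{"type": "dog", "name": "rex"}'))
# Dog(type='dog', name='rex')
```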
@ -89,14 +81,7 @@ class DiskDistributionRegistry(DistributionRegistry):
if not json_str:
return None
objects_data = json.loads(json_str)
# Return only the first object if any exist
if objects_data:
return pydantic.parse_obj_as(
RoutableObjectWithProvider,
json.loads(objects_data),
)
return None
return pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(json_str)
async def update(self, obj: RoutableObjectWithProvider) -> None:
await self.kvstore.set(

View file

@ -8,11 +8,14 @@ import os
import pytest
import pytest_asyncio
from llama_stack.distribution.store import * # noqa F403
from llama_stack.apis.inference import Model
from llama_stack.apis.memory_banks import VectorMemoryBank
from llama_stack.distribution.store.registry import (
CachedDiskDistributionRegistry,
DiskDistributionRegistry,
)
from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig
from llama_stack.distribution.datatypes import * # noqa F403
@pytest.fixture

View file

@ -0,0 +1,129 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import os
from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.event_logger import EventLogger as AgentEventLogger
from llama_stack_client.lib.inference.event_logger import EventLogger
from llama_stack_client.types import Attachment, UserMessage
from llama_stack_client.types.agent_create_params import AgentConfig
def main(config_path: str):
client = LlamaStackAsLibraryClient(config_path)
if not client.initialize():
return
models = client.models.list()
print("\nModels:")
for model in models:
print(model)
if not models:
print("No models found, skipping chat completion test")
return
model_id = next(m.identifier for m in models if "8b" in m.identifier.lower())
print(f"Using model: {model_id}")
response = client.inference.chat_completion(
messages=[UserMessage(content="What is the capital of France?", role="user")],
model_id=model_id,
stream=False,
)
print("\nChat completion response (non-stream):")
print(response)
response = client.inference.chat_completion(
messages=[UserMessage(content="What is the capital of France?", role="user")],
model_id=model_id,
stream=True,
)
print("\nChat completion response (stream):")
for log in EventLogger().log(response):
log.print()
print("\nAgent test:")
agent_config = AgentConfig(
model=model_id,
instructions="You are a helpful assistant",
sampling_params={
"strategy": "greedy",
"temperature": 1.0,
"top_p": 0.9,
},
tools=(
[
{
"type": "brave_search",
"engine": "brave",
"api_key": os.getenv("BRAVE_SEARCH_API_KEY"),
}
]
if os.getenv("BRAVE_SEARCH_API_KEY")
else []
)
+ (
[
{
"type": "code_interpreter",
}
]
),
tool_choice="required",
input_shields=[],
output_shields=[],
enable_session_persistence=False,
)
agent = Agent(client, agent_config)
user_prompts = [
"Hello",
"Which players played in the winning team of the NBA western conference semifinals of 2024, please use tools",
]
user_prompts = [
(
"Here is a csv, can you describe it ?",
[
Attachment(
content="https://raw.githubusercontent.com/meta-llama/llama-stack-apps/main/examples/resources/inflation.csv",
mime_type="test/csv",
)
],
),
("Which year ended with the highest inflation ?", None),
(
"What macro economic situations that led to such high inflation in that period?",
None,
),
("Plot average yearly inflation as a time series", None),
]
session_id = agent.create_session("test-session")
for prompt, attachments in user_prompts:
response = agent.create_turn(
messages=[
{
"role": "user",
"content": prompt,
}
],
attachments=attachments,
session_id=session_id,
)
for log in AgentEventLogger().log(response):
log.print()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("config_path", help="Path to the config YAML file")
args = parser.parse_args()
main(args.config_path)

View file

@ -1,128 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import inspect
import json
from functools import wraps
from typing import Any, AsyncGenerator, Callable, Type, TypeVar
from pydantic import BaseModel
from llama_stack.providers.utils.telemetry import tracing
T = TypeVar("T")
def serialize_value(value: Any) -> str:
"""Helper function to serialize values to string representation."""
try:
if isinstance(value, BaseModel):
return value.model_dump_json()
elif isinstance(value, list) and value and isinstance(value[0], BaseModel):
return json.dumps([item.model_dump_json() for item in value])
elif hasattr(value, "to_dict"):
return json.dumps(value.to_dict())
elif isinstance(value, (dict, list, int, float, str, bool)):
return json.dumps(value)
else:
return str(value)
except Exception:
return str(value)
def trace_protocol(cls: Type[T]) -> Type[T]:
"""
A class decorator that automatically traces all methods in a protocol/base class
and its inheriting classes.
"""
def trace_method(method: Callable) -> Callable:
is_async = asyncio.iscoroutinefunction(method)
is_async_gen = inspect.isasyncgenfunction(method)
def create_span_context(self: Any, *args: Any, **kwargs: Any) -> tuple:
class_name = self.__class__.__name__
method_name = method.__name__
span_type = (
"async_generator" if is_async_gen else "async" if is_async else "sync"
)
span_attributes = {
"class": class_name,
"method": method_name,
"type": span_type,
"args": serialize_value(args),
}
return class_name, method_name, span_attributes
@wraps(method)
async def async_gen_wrapper(
self: Any, *args: Any, **kwargs: Any
) -> AsyncGenerator:
class_name, method_name, span_attributes = create_span_context(
self, *args, **kwargs
)
with tracing.span(f"{class_name}.{method_name}", span_attributes) as span:
try:
count = 0
async for item in method(self, *args, **kwargs):
yield item
count += 1
finally:
span.set_attribute("chunk_count", count)
@wraps(method)
async def async_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
class_name, method_name, span_attributes = create_span_context(
self, *args, **kwargs
)
with tracing.span(f"{class_name}.{method_name}", span_attributes) as span:
try:
result = await method(self, *args, **kwargs)
span.set_attribute("output", serialize_value(result))
return result
except Exception as e:
span.set_attribute("error", str(e))
raise
@wraps(method)
def sync_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
class_name, method_name, span_attributes = create_span_context(
self, *args, **kwargs
)
with tracing.span(f"{class_name}.{method_name}", span_attributes) as span:
try:
result = method(self, *args, **kwargs)
span.set_attribute("output", serialize_value(result))
return result
except Exception as e:
raise
if is_async_gen:
return async_gen_wrapper
elif is_async:
return async_wrapper
else:
return sync_wrapper
original_init_subclass = getattr(cls, "__init_subclass__", None)
def __init_subclass__(cls_child, **kwargs): # noqa: N807
if original_init_subclass:
original_init_subclass(**kwargs)
for name, method in vars(cls_child).items():
if inspect.isfunction(method) and not name.startswith("_"):
setattr(cls_child, name, trace_method(method)) # noqa: B010
cls.__init_subclass__ = classmethod(__init_subclass__)
return cls

View file

@ -1,16 +1,41 @@
# Llama Stack UI
# (Experimental) Llama Stack UI
> [!NOTE]
> This is a work in progress.
## Docker Setup
## Prerequisite
- Start up Llama Stack Server
```
llama stack run
```
:warning: This is a work in progress.
## Running Streamlit App
## Developer Setup
1. Start up Llama Stack API server. More details [here](https://llama-stack.readthedocs.io/en/latest/getting_started/index.html).
```
llama stack build --template together --image-type conda
llama stack run together
```
2. (Optional) Register datasets and eval tasks as resources if you want to run pre-configured evaluation flows (e.g. the Evaluations (Generation + Scoring) page).
```bash
$ llama-stack-client datasets register \
--dataset-id "mmlu" \
--provider-id "huggingface" \
--url "https://huggingface.co/datasets/llamastack/evals" \
--metadata '{"path": "llamastack/evals", "name": "evals__mmlu__details", "split": "train"}' \
--schema '{"input_query": {"type": "string"}, "expected_answer": {"type": "string", "chat_completion_input": {"type": "string"}}}'
```
```bash
$ llama-stack-client eval_tasks register \
--eval-task-id meta-reference-mmlu \
--provider-id meta-reference \
--dataset-id mmlu \
--scoring-functions basic::regex_parser_multiple_choice_answer
```
3. Start Streamlit UI
```bash
cd llama_stack/distribution/ui
pip install -r requirements.txt
streamlit run app.py
```

View file

@ -129,7 +129,7 @@ def application_evaluation_page():
# Display current row results using separate containers
progress_text_container.write(
f"Expand to see current processed result ({i+1}/{len(rows)})"
f"Expand to see current processed result ({i + 1} / {len(rows)})"
)
results_container.json(
score_res.to_json(),

View file

@ -232,7 +232,7 @@ def run_evaluation_3():
output_res[scoring_fn].append(eval_res.scores[scoring_fn].score_rows[0])
progress_text_container.write(
f"Expand to see current processed result ({i+1}/{len(rows)})"
f"Expand to see current processed result ({i + 1} / {len(rows)})"
)
results_container.json(eval_res, expanded=2)

View file

@ -11,7 +11,9 @@ from modules.api import llama_stack_api
with st.sidebar:
st.header("Configuration")
available_models = llama_stack_api.client.models.list()
available_models = [model.identifier for model in available_models]
available_models = [
model.identifier for model in available_models if model.model_type == "llm"
]
selected_model = st.selectbox(
"Choose a model",
available_models,

View file

@ -74,7 +74,9 @@ def rag_chat_page():
]
available_models = llama_stack_api.client.models.list()
available_models = [model.identifier for model in available_models]
available_models = [
model.identifier for model in available_models if model.model_type == "llm"
]
selected_model = st.selectbox(
"Choose a model",
available_models,
@ -116,8 +118,6 @@ def rag_chat_page():
with st.chat_message(message["role"]):
st.markdown(message["content"])
selected_model = llama_stack_api.client.models.list()[0].identifier
agent_config = AgentConfig(
model=selected_model,
instructions=system_prompt,