Merge conflicts

Commit 5b027d2de5 by Chantal D Gama Rose, 2024-12-03 19:00:27 -08:00
198 changed files with 6140 additions and 3477 deletions

View file

@ -14,15 +14,19 @@ import httpx
from dotenv import load_dotenv
from pydantic import BaseModel
from termcolor import cprint
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.distribution.datatypes import RemoteProviderConfig
from .agents import * # noqa: F403
import logging
from .event_logger import EventLogger
log = logging.getLogger(__name__)
load_dotenv()
@ -93,13 +97,12 @@ class AgentsClient(Agents):
try:
jdata = json.loads(data)
if "error" in jdata:
cprint(data, "red")
log.error(data)
continue
yield AgentTurnResponseStreamChunk(**jdata)
except Exception as e:
print(data)
print(f"Error with parsing or validation: {e}")
log.error(f"Error with parsing or validation: {e}")
async def _nonstream_agent_turn(self, request: AgentTurnCreateRequest):
raise NotImplementedError("Non-streaming not implemented yet")
@ -125,7 +128,7 @@ async def _run_agent(
)
for content in user_prompts:
cprint(f"User> {content}", color="white", attrs=["bold"])
log.info(f"User> {content}", color="white", attrs=["bold"])
iterator = await api.create_agent_turn(
AgentTurnCreateRequest(
agent_id=create_response.agent_id,
@ -138,9 +141,9 @@ async def _run_agent(
)
)
async for event, log in EventLogger().log(iterator):
if log is not None:
log.print()
async for event, logger in EventLogger().log(iterator):
if logger is not None:
log.info(logger)
async def run_llama_3_1(host: str, port: int, model: str = "Llama3.1-8B-Instruct"):

View file

@ -40,7 +40,7 @@ class ModelsClient(Models):
response = await client.post(
f"{self.base_url}/models/register",
json={
"model": json.loads(model.json()),
"model": json.loads(model.model_dump_json()),
},
headers={"Content-Type": "application/json"},
)
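
This hunk is one instance of the Pydantic v1 → v2 migration that recurs throughout the commit (`.json()` → `.model_dump_json()`, `.dict()` → `.model_dump()`, `.copy()` → `.model_copy()`). A minimal sketch of the rename:

```
from pydantic import BaseModel

class Model(BaseModel):
    identifier: str

m = Model(identifier="Llama3.2-3B-Instruct")
print(m.model_dump_json())  # v2 name for the deprecated m.json()
print(m.model_dump())       # v2 name for the deprecated m.dict()
```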

View file

@ -8,7 +8,6 @@ import argparse
from llama_stack.cli.subcommand import Subcommand
from llama_stack.distribution.datatypes import * # noqa: F403
import importlib
import os
import shutil
from functools import lru_cache
@ -17,10 +16,10 @@ from pathlib import Path
import pkg_resources
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.resolver import InvalidProviderError
from llama_stack.distribution.utils.dynamic import instantiate_class_type
TEMPLATES_PATH = Path(os.path.relpath(__file__)).parent.parent.parent / "templates"
TEMPLATES_PATH = Path(__file__).parent.parent.parent / "templates"
@lru_cache()
@ -224,6 +223,10 @@ class StackBuild(Subcommand):
for i, provider_type in enumerate(provider_types):
pid = provider_type.split("::")[-1]
p = provider_registry[Api(api)][provider_type]
if p.deprecation_error:
raise InvalidProviderError(p.deprecation_error)
config_type = instantiate_class_type(
provider_registry[Api(api)][provider_type].config_class
)
@ -258,6 +261,7 @@ class StackBuild(Subcommand):
) -> None:
import json
import os
import re
import yaml
from termcolor import cprint
@ -286,17 +290,19 @@ class StackBuild(Subcommand):
os.makedirs(build_dir, exist_ok=True)
run_config_file = build_dir / f"{build_config.name}-run.yaml"
shutil.copy(template_path, run_config_file)
module_name = f"llama_stack.templates.{template_name}"
module = importlib.import_module(module_name)
distribution_template = module.get_distribution_template()
with open(template_path, "r") as f:
yaml_content = f.read()
# Find all ${env.VARIABLE} patterns
env_vars = set(re.findall(r"\${env\.([A-Za-z0-9_]+)}", yaml_content))
cprint("Build Successful! Next steps: ", color="green")
env_vars = ", ".join(distribution_template.run_config_env_vars.keys())
cprint(
f" 1. Set the environment variables: {env_vars}",
f" 1. Set the environment variables: {list(env_vars)}",
color="green",
)
cprint(
f" 2. `llama stack run {run_config_file}`",
f" 2. Run: `llama stack run {template_name}`",
color="green",
)
else:
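
For reference, the removed regex scanned the copied run.yaml for `${env.VAR}` placeholders, whereas the new code asks the template module for `run_config_env_vars` directly. A self-contained sketch of the old extraction (variable names hypothetical):

```
import re

yaml_content = """
model: ${env.INFERENCE_MODEL}
port: ${env.LLAMASTACK_PORT}
"""
env_vars = set(re.findall(r"\${env\.([A-Za-z0-9_]+)}", yaml_content))
print(sorted(env_vars))  # ['INFERENCE_MODEL', 'LLAMASTACK_PORT']
```

Note the regex would miss placeholders carrying a default, such as `${env.CHECKPOINT_DIR:null}` seen elsewhere in this commit, which is presumably part of the motivation for introspecting the template instead.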

View file

@ -5,9 +5,12 @@
# the root directory of this source tree.
import argparse
from pathlib import Path
from llama_stack.cli.subcommand import Subcommand
REPO_ROOT = Path(__file__).parent.parent.parent.parent
class StackRun(Subcommand):
def __init__(self, subparsers: argparse._SubParsersAction):
@ -48,8 +51,6 @@ class StackRun(Subcommand):
)
def _run_stack_run_cmd(self, args: argparse.Namespace) -> None:
from pathlib import Path
import pkg_resources
import yaml
@ -66,19 +67,27 @@ class StackRun(Subcommand):
return
config_file = Path(args.config)
if not config_file.exists() and not args.config.endswith(".yaml"):
has_yaml_suffix = args.config.endswith(".yaml")
if not config_file.exists() and not has_yaml_suffix:
# check if this is a template
config_file = (
Path(REPO_ROOT) / "llama_stack" / "templates" / args.config / "run.yaml"
)
if not config_file.exists() and not has_yaml_suffix:
# check if it's a build config saved to conda dir
config_file = Path(
BUILDS_BASE_DIR / ImageType.conda.value / f"{args.config}-run.yaml"
)
if not config_file.exists() and not args.config.endswith(".yaml"):
if not config_file.exists() and not has_yaml_suffix:
# check if it's a build config saved to docker dir
config_file = Path(
BUILDS_BASE_DIR / ImageType.docker.value / f"{args.config}-run.yaml"
)
if not config_file.exists() and not args.config.endswith(".yaml"):
if not config_file.exists() and not has_yaml_suffix:
# check if it's a build config saved to ~/.llama dir
config_file = Path(
DISTRIBS_BASE_DIR
@ -92,6 +101,7 @@ class StackRun(Subcommand):
)
return
print(f"Using config file: {config_file}")
config_dict = yaml.safe_load(config_file.read_text())
config = parse_and_maybe_upgrade_config(config_dict)
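
Taken together, `llama stack run` now resolves its argument through a fixed fallback chain; a rough sketch (the `~/.llama` entry is truncated in the diff, so that layout is assumed):

```
from pathlib import Path

def resolve_run_config(arg: str) -> Path:
    # uses REPO_ROOT, BUILDS_BASE_DIR, DISTRIBS_BASE_DIR from the surrounding module
    candidates = [
        Path(arg),                                                   # 1. literal path
        REPO_ROOT / "llama_stack" / "templates" / arg / "run.yaml",  # 2. bundled template
        BUILDS_BASE_DIR / "conda" / f"{arg}-run.yaml",               # 3. conda build
        BUILDS_BASE_DIR / "docker" / f"{arg}-run.yaml",              # 4. docker build
        DISTRIBS_BASE_DIR / arg / f"{arg}-run.yaml",                 # 5. ~/.llama distribution (assumed)
    ]
    for candidate in candidates:
        if candidate.exists():
            return candidate
    raise FileNotFoundError(f"No run config found for {arg!r}")
```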

View file

@ -4,14 +4,13 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from enum import Enum
from typing import List
import pkg_resources
from pydantic import BaseModel
from termcolor import cprint
from llama_stack.distribution.utils.exec import run_with_pty
from llama_stack.distribution.datatypes import * # noqa: F403
@ -22,6 +21,8 @@ from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR
log = logging.getLogger(__name__)
# These are the dependencies needed by the distribution server.
# `llama-stack` is automatically installed by the installation script.
SERVER_DEPENDENCIES = [
@ -93,7 +94,7 @@ def print_pip_install_help(providers: Dict[str, List[Provider]]):
f"Please install needed dependencies using the following commands:\n\n\tpip install {' '.join(normal_deps)}"
)
for special_dep in special_deps:
print(f"\tpip install {special_dep}")
log.info(f"\tpip install {special_dep}")
print()
@ -133,9 +134,8 @@ def build_image(build_config: BuildConfig, build_file_path: Path):
return_code = run_with_pty(args)
if return_code != 0:
cprint(
log.error(
f"Failed to build target {build_config.name} with return code {return_code}",
color="red",
)
return return_code

View file

@ -122,7 +122,7 @@ add_to_docker <<EOF
# This would be good in production but for debugging flexibility let's not add it right now
# We need a more solid production ready entrypoint.sh anyway
#
ENTRYPOINT ["python", "-m", "llama_stack.distribution.server.server"]
ENTRYPOINT ["python", "-m", "llama_stack.distribution.server.server", "--template", "$build_name"]
EOF

View file

@ -3,12 +3,12 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import textwrap
from typing import Any
from llama_stack.distribution.datatypes import * # noqa: F403
from termcolor import cprint
from llama_stack.distribution.distribution import (
builtin_automatically_routed_apis,
@ -22,6 +22,8 @@ from llama_stack.apis.models import * # noqa: F403
from llama_stack.apis.shields import * # noqa: F403
from llama_stack.apis.memory_banks import * # noqa: F403
logger = logging.getLogger(__name__)
def configure_single_provider(
registry: Dict[str, ProviderSpec], provider: Provider
@ -50,7 +52,7 @@ def configure_api_providers(
is_nux = len(config.providers) == 0
if is_nux:
print(
logger.info(
textwrap.dedent(
"""
Llama Stack is composed of several APIs working together. For each API served by the Stack,
@ -76,18 +78,18 @@ def configure_api_providers(
existing_providers = config.providers.get(api_str, [])
if existing_providers:
cprint(
logger.info(
f"Re-configuring existing providers for API `{api_str}`...",
"green",
attrs=["bold"],
)
updated_providers = []
for p in existing_providers:
print(f"> Configuring provider `({p.provider_type})`")
logger.info(f"> Configuring provider `({p.provider_type})`")
updated_providers.append(
configure_single_provider(provider_registry[api], p)
)
print("")
logger.info("")
else:
# we are newly configuring this API
plist = build_spec.providers.get(api_str, [])
@ -96,17 +98,17 @@ def configure_api_providers(
if not plist:
raise ValueError(f"No provider configured for API {api_str}?")
cprint(f"Configuring API `{api_str}`...", "green", attrs=["bold"])
logger.info(f"Configuring API `{api_str}`...", "green", attrs=["bold"])
updated_providers = []
for i, provider_type in enumerate(plist):
if i >= 1:
others = ", ".join(plist[i:])
print(
logger.info(
f"Not configuring other providers ({others}) interactively. Please edit the resulting YAML directly.\n"
)
break
print(f"> Configuring provider `({provider_type})`")
logger.info(f"> Configuring provider `({provider_type})`")
updated_providers.append(
configure_single_provider(
provider_registry[api],
@ -121,7 +123,7 @@ def configure_api_providers(
),
)
)
print("")
logger.info("")
config.providers[api_str] = updated_providers
@ -182,7 +184,7 @@ def parse_and_maybe_upgrade_config(config_dict: Dict[str, Any]) -> StackRunConfi
return StackRunConfig(**config_dict)
if "routing_table" in config_dict:
print("Upgrading config...")
logger.info("Upgrading config...")
config_dict = upgrade_from_routing_table(config_dict)
config_dict["version"] = LLAMA_STACK_RUN_CONFIG_VERSION

View file

@ -5,11 +5,14 @@
# the root directory of this source tree.
import json
import logging
import threading
from typing import Any, Dict
from .utils.dynamic import instantiate_class_type
log = logging.getLogger(__name__)
_THREAD_LOCAL = threading.local()
@ -32,7 +35,7 @@ class NeedsRequestProviderData:
provider_data = validator(**val)
return provider_data
except Exception as e:
print("Error parsing provider data", e)
log.error(f"Error parsing provider data: {e}")
def set_request_provider_data(headers: Dict[str, str]):
@ -51,7 +54,7 @@ def set_request_provider_data(headers: Dict[str, str]):
try:
val = json.loads(val)
except json.JSONDecodeError:
print("Provider data not encoded as a JSON object!", val)
log.error("Provider data not encoded as a JSON object!", val)
return
_THREAD_LOCAL.provider_data_header_value = val

View file

@ -8,11 +8,12 @@ import inspect
from typing import Any, Dict, List, Set
from termcolor import cprint
from llama_stack.providers.datatypes import * # noqa: F403
from llama_stack.distribution.datatypes import * # noqa: F403
import logging
from llama_stack.apis.agents import Agents
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
@ -33,6 +34,8 @@ from llama_stack.distribution.distribution import builtin_automatically_routed_a
from llama_stack.distribution.store import DistributionRegistry
from llama_stack.distribution.utils.dynamic import instantiate_class_type
log = logging.getLogger(__name__)
class InvalidProviderError(Exception):
pass
@ -115,14 +118,12 @@ async def resolve_impls(
p = provider_registry[api][provider.provider_type]
if p.deprecation_error:
cprint(p.deprecation_error, "red", attrs=["bold"])
log.error(p.deprecation_error)
raise InvalidProviderError(p.deprecation_error)
elif p.deprecation_warning:
cprint(
log.warning(
f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}",
"yellow",
attrs=["bold"],
)
p.deps__ = [a.value for a in p.api_dependencies]
spec = ProviderWithSpec(
@ -199,10 +200,10 @@ async def resolve_impls(
)
)
print(f"Resolved {len(sorted_providers)} providers")
log.info(f"Resolved {len(sorted_providers)} providers")
for api_str, provider in sorted_providers:
print(f" {api_str} => {provider.provider_id}")
print("")
log.info(f" {api_str} => {provider.provider_id}")
log.info("")
impls = {}
inner_impls_by_provider_id = {f"inner-{x.value}": {} for x in router_apis}
@ -339,7 +340,7 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
obj_params = set(obj_sig.parameters)
obj_params.discard("self")
if not (proto_params <= obj_params):
print(
log.error(
f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}"
)
missing_methods.append((name, "signature_mismatch"))

View file

@ -170,13 +170,6 @@ class CommonRoutingTableImpl(RoutingTable):
# Get existing objects from registry
existing_obj = await self.dist_registry.get(obj.type, obj.identifier)
# Check for existing registration
if existing_obj and existing_obj.provider_id == obj.provider_id:
print(
f"`{obj.identifier}` already registered with `{existing_obj.provider_id}`"
)
return existing_obj
# if provider_id is not specified, pick an arbitrary one from existing entries
if not obj.provider_id and len(self.impls_by_provider_id) > 0:
obj.provider_id = list(self.impls_by_provider_id.keys())[0]

View file

@ -16,13 +16,12 @@ import traceback
import warnings
from contextlib import asynccontextmanager
from ssl import SSLError
from typing import Any, Dict, Optional
from pathlib import Path
from typing import Any, Union
import httpx
import yaml
from fastapi import Body, FastAPI, HTTPException, Request, Response
from fastapi import Body, FastAPI, HTTPException, Request
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, ValidationError
@ -34,7 +33,6 @@ from llama_stack.distribution.distribution import builtin_automatically_routed_a
from llama_stack.providers.utils.telemetry.tracing import (
end_trace,
setup_logger,
SpanStatus,
start_trace,
)
from llama_stack.distribution.datatypes import * # noqa: F403
@ -45,10 +43,17 @@ from llama_stack.distribution.stack import (
replace_env_vars,
validate_env_pair,
)
from llama_stack.providers.inline.meta_reference.telemetry.console import (
ConsoleConfig,
ConsoleTelemetryImpl,
)
from .endpoints import get_all_api_endpoints
REPO_ROOT = Path(__file__).parent.parent.parent.parent
def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
log = file if hasattr(file, "write") else sys.stderr
traceback.print_stack(file=log)
@ -110,67 +115,6 @@ def translate_exception(exc: Exception) -> Union[HTTPException, RequestValidatio
)
async def passthrough(
request: Request,
downstream_url: str,
downstream_headers: Optional[Dict[str, str]] = None,
):
await start_trace(request.path, {"downstream_url": downstream_url})
headers = dict(request.headers)
headers.pop("host", None)
headers.update(downstream_headers or {})
content = await request.body()
client = httpx.AsyncClient()
erred = False
try:
req = client.build_request(
method=request.method,
url=downstream_url,
headers=headers,
content=content,
params=request.query_params,
)
response = await client.send(req, stream=True)
async def stream_response():
async for chunk in response.aiter_raw(chunk_size=64):
yield chunk
await response.aclose()
await client.aclose()
return StreamingResponse(
stream_response(),
status_code=response.status_code,
headers=dict(response.headers),
media_type=response.headers.get("content-type"),
)
except httpx.ReadTimeout:
erred = True
return Response(content="Downstream server timed out", status_code=504)
except httpx.NetworkError as e:
erred = True
return Response(content=f"Network error: {str(e)}", status_code=502)
except httpx.TooManyRedirects:
erred = True
return Response(content="Too many redirects", status_code=502)
except SSLError as e:
erred = True
return Response(content=f"SSL error: {str(e)}", status_code=502)
except httpx.HTTPStatusError as e:
erred = True
return Response(content=str(e), status_code=e.response.status_code)
except Exception as e:
erred = True
return Response(content=f"Unexpected error: {str(e)}", status_code=500)
finally:
await end_trace(SpanStatus.OK if not erred else SpanStatus.ERROR)
def handle_sigint(app, *args, **kwargs):
print("SIGINT or CTRL-C detected. Exiting gracefully...")
@ -192,7 +136,6 @@ def handle_sigint(app, *args, **kwargs):
async def lifespan(app: FastAPI):
print("Starting up")
yield
print("Shutting down")
for impl in app.__llama_stack_impls__.values():
await impl.shutdown()
@ -227,14 +170,10 @@ async def sse_generator(event_gen):
},
}
)
finally:
await end_trace()
def create_dynamic_typed_route(func: Any, method: str):
async def endpoint(request: Request, **kwargs):
await start_trace(func.__name__)
set_request_provider_data(request.headers)
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
@ -249,8 +188,6 @@ def create_dynamic_typed_route(func: Any, method: str):
except Exception as e:
traceback.print_exception(e)
raise translate_exception(e) from e
finally:
await end_trace()
sig = inspect.signature(func)
new_params = [
@ -274,14 +211,30 @@ def create_dynamic_typed_route(func: Any, method: str):
return endpoint
class TracingMiddleware:
def __init__(self, app):
self.app = app
async def __call__(self, scope, receive, send):
path = scope["path"]
await start_trace(path, {"location": "server"})
try:
return await self.app(scope, receive, send)
finally:
await end_trace()
def main():
"""Start the LlamaStack server."""
parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
parser.add_argument(
"--yaml-config",
default="llamastack-run.yaml",
help="Path to YAML configuration file",
)
parser.add_argument(
"--template",
help="One of the template names in llama_stack/templates (e.g., tgi, fireworks, remote-vllm, etc.)",
)
parser.add_argument("--port", type=int, default=5000, help="Port to listen on")
parser.add_argument(
"--disable-ipv6", action="store_true", help="Whether to disable IPv6 support"
@ -303,11 +256,31 @@ def main():
print(f"Error: {str(e)}")
sys.exit(1)
with open(args.yaml_config, "r") as fp:
if args.yaml_config:
# if the user provided a config file, use it, even if template was specified
config_file = Path(args.yaml_config)
if not config_file.exists():
raise ValueError(f"Config file {config_file} does not exist")
print(f"Using config file: {config_file}")
elif args.template:
config_file = (
Path(REPO_ROOT) / "llama_stack" / "templates" / args.template / "run.yaml"
)
if not config_file.exists():
raise ValueError(f"Template {args.template} does not exist")
print(f"Using template {args.template} config file: {config_file}")
else:
raise ValueError("Either --yaml-config or --template must be provided")
with open(config_file, "r") as fp:
config = replace_env_vars(yaml.safe_load(fp))
config = StackRunConfig(**config)
app = FastAPI()
print("Run configuration:")
print(yaml.dump(config.model_dump(), indent=2))
app = FastAPI(lifespan=lifespan)
app.add_middleware(TracingMiddleware)
try:
impls = asyncio.run(construct_stack(config))
@ -316,6 +289,8 @@ def main():
if Api.telemetry in impls:
setup_logger(impls[Api.telemetry])
else:
setup_logger(ConsoleTelemetryImpl(ConsoleConfig()))
all_endpoints = get_all_api_endpoints()

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import os
from pathlib import Path
from typing import Any, Dict
@ -40,6 +41,8 @@ from llama_stack.distribution.store.registry import create_dist_registry
from llama_stack.providers.datatypes import Api
log = logging.getLogger(__name__)
LLAMA_STACK_API_VERSION = "alpha"
@ -93,11 +96,11 @@ async def register_resources(run_config: StackRunConfig, impls: Dict[Api, Any]):
method = getattr(impls[api], list_method)
for obj in await method():
print(
log.info(
f"{rsrc.capitalize()}: {colored(obj.identifier, 'white', attrs=['bold'])} served by {colored(obj.provider_id, 'white', attrs=['bold'])}",
)
print("")
log.info("")
class EnvVarError(Exception):

View file

@ -0,0 +1,11 @@
# Llama Stack UI

> [!NOTE]
> This is a work in progress.
## Running Streamlit App
```
cd llama_stack/distribution/ui
pip install -r requirements.txt
streamlit run app.py
```
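
The app expects a running Llama Stack server. Per `modules/api.py`, the endpoint defaults to `http://localhost:5000` and can be overridden with the `LLAMA_STACK_ENDPOINT` environment variable; `FIREWORKS_API_KEY`, `TOGETHER_API_KEY`, and `OPENAI_API_KEY` are forwarded as provider data when set.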

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,173 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import pandas as pd
import streamlit as st
from modules.api import LlamaStackEvaluation
from modules.utils import process_dataset
EVALUATION_API = LlamaStackEvaluation()
def main():
# Add collapsible sidebar
with st.sidebar:
# Add collapse button
if "sidebar_state" not in st.session_state:
st.session_state.sidebar_state = True
if st.session_state.sidebar_state:
st.title("Navigation")
page = st.radio(
"Select a Page",
["Application Evaluation"],
index=0,
)
else:
page = "Application Evaluation" # Default page when sidebar is collapsed
# Main content area
st.title("🦙 Llama Stack Evaluations")
if page == "Application Evaluation":
application_evaluation_page()
def application_evaluation_page():
# File uploader
uploaded_file = st.file_uploader("Upload Dataset", type=["csv", "xlsx", "xls"])
if uploaded_file is None:
st.error("No file uploaded")
return
# Process uploaded file
df = process_dataset(uploaded_file)
if df is None:
st.error("Error processing file")
return
# Display dataset information
st.success("Dataset loaded successfully!")
# Display dataframe preview
st.subheader("Dataset Preview")
st.dataframe(df)
# Select Scoring Functions to Run Evaluation On
st.subheader("Select Scoring Functions")
scoring_functions = EVALUATION_API.list_scoring_functions()
scoring_functions = {sf.identifier: sf for sf in scoring_functions}
scoring_functions_names = list(scoring_functions.keys())
selected_scoring_functions = st.multiselect(
"Choose one or more scoring functions",
options=scoring_functions_names,
help="Choose one or more scoring functions.",
)
available_models = EVALUATION_API.list_models()
available_models = [m.identifier for m in available_models]
scoring_params = {}
if selected_scoring_functions:
st.write("Selected:")
for scoring_fn_id in selected_scoring_functions:
scoring_fn = scoring_functions[scoring_fn_id]
st.write(f"- **{scoring_fn_id}**: {scoring_fn.description}")
new_params = None
if scoring_fn.params:
new_params = {}
for param_name, param_value in scoring_fn.params.to_dict().items():
if param_name == "type":
new_params[param_name] = param_value
continue
if param_name == "judge_model":
value = st.selectbox(
f"Select **{param_name}** for {scoring_fn_id}",
options=available_models,
index=0,
key=f"{scoring_fn_id}_{param_name}",
)
new_params[param_name] = value
else:
value = st.text_area(
f"Enter value for **{param_name}** in {scoring_fn_id} in valid JSON format",
value=json.dumps(param_value, indent=2),
height=80,
)
try:
new_params[param_name] = json.loads(value)
except json.JSONDecodeError:
st.error(
f"Invalid JSON for **{param_name}** in {scoring_fn_id}"
)
st.json(new_params)
scoring_params[scoring_fn_id] = new_params
# Add run evaluation button & slider
total_rows = len(df)
num_rows = st.slider("Number of rows to evaluate", 1, total_rows, total_rows)
if st.button("Run Evaluation"):
progress_text = "Running evaluation..."
progress_bar = st.progress(0, text=progress_text)
rows = df.to_dict(orient="records")
if num_rows < total_rows:
rows = rows[:num_rows]
# Create separate containers for progress text and results
progress_text_container = st.empty()
results_container = st.empty()
output_res = {}
for i, r in enumerate(rows):
# Update progress
progress = i / len(rows)
progress_bar.progress(progress, text=progress_text)
# Run evaluation for current row
score_res = EVALUATION_API.run_scoring(
r,
scoring_function_ids=selected_scoring_functions,
scoring_params=scoring_params,
)
for k in r.keys():
if k not in output_res:
output_res[k] = []
output_res[k].append(r[k])
for fn_id in selected_scoring_functions:
if fn_id not in output_res:
output_res[fn_id] = []
output_res[fn_id].append(score_res.results[fn_id].score_rows[0])
# Display current row results using separate containers
progress_text_container.write(
f"Expand to see current processed result ({i+1}/{len(rows)})"
)
results_container.json(
score_res.to_json(),
expanded=2,
)
progress_bar.progress(1.0, text="Evaluation complete!")
# Display results in dataframe
if output_res:
output_df = pd.DataFrame(output_res)
st.subheader("Evaluation Results")
st.dataframe(output_df)
if __name__ == "__main__":
main()

View file

@ -0,0 +1,41 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from typing import Optional
from llama_stack_client import LlamaStackClient
class LlamaStackEvaluation:
def __init__(self):
self.client = LlamaStackClient(
base_url=os.environ.get("LLAMA_STACK_ENDPOINT", "http://localhost:5000"),
provider_data={
"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY", ""),
"together_api_key": os.environ.get("TOGETHER_API_KEY", ""),
"openai_api_key": os.environ.get("OPENAI_API_KEY", ""),
},
)
def list_scoring_functions(self):
"""List all available scoring functions"""
return self.client.scoring_functions.list()
def list_models(self):
"""List all available judge models"""
return self.client.models.list()
def run_scoring(
self, row, scoring_function_ids: list[str], scoring_params: Optional[dict]
):
"""Run scoring on a single row"""
if not scoring_params:
scoring_params = {fn_id: None for fn_id in scoring_function_ids}
return self.client.scoring.score(
input_rows=[row], scoring_functions=scoring_params
)
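
A hypothetical usage of the wrapper above, assuming a server is reachable (the `expected_answer`/`generated_answer` column names follow the convention the Braintrust scorer enforces later in this commit; `input_query` is illustrative):

```
eval_api = LlamaStackEvaluation()

scoring_fns = [fn.identifier for fn in eval_api.list_scoring_functions()]
row = {
    "input_query": "What is 2 + 2?",
    "generated_answer": "4",
    "expected_answer": "4",
}
response = eval_api.run_scoring(
    row, scoring_function_ids=scoring_fns[:1], scoring_params=None
)
```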

View file

@ -0,0 +1,31 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import pandas as pd
import streamlit as st
def process_dataset(file):
if file is None:
return "No file uploaded", None
try:
# Determine file type and read accordingly
file_ext = os.path.splitext(file.name)[1].lower()
if file_ext == ".csv":
df = pd.read_csv(file)
elif file_ext in [".xlsx", ".xls"]:
df = pd.read_excel(file)
else:
return "Unsupported file format. Please upload a CSV or Excel file.", None
return df
except Exception as e:
st.error(f"Error processing file: {str(e)}")
return None

View file

@ -0,0 +1,3 @@
streamlit
pandas
llama-stack-client>=0.0.55

View file

@ -5,6 +5,7 @@
# the root directory of this source tree.
import errno
import logging
import os
import pty
import select
@ -13,7 +14,7 @@ import subprocess
import sys
import termios
from termcolor import cprint
log = logging.getLogger(__name__)
# run a command in a pseudo-terminal, with interrupt handling,
@ -29,7 +30,7 @@ def run_with_pty(command):
def sigint_handler(signum, frame):
nonlocal ctrl_c_pressed
ctrl_c_pressed = True
cprint("\nCtrl-C detected. Aborting...", "white", attrs=["bold"])
log.info("\nCtrl-C detected. Aborting...")
try:
# Set up the signal handler
@ -100,6 +101,6 @@ def run_command(command):
process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
output, error = process.communicate()
if process.returncode != 0:
print(f"Error: {error.decode('utf-8')}")
log.error(f"Error: {error.decode('utf-8')}")
sys.exit(1)
return output.decode("utf-8")

View file

@ -4,11 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from pathlib import Path
from .config_dirs import DEFAULT_CHECKPOINT_DIR
def model_local_dir(descriptor: str) -> str:
path = os.path.join(DEFAULT_CHECKPOINT_DIR, descriptor)
return path.replace(":", "-")
return str(Path(DEFAULT_CHECKPOINT_DIR) / (descriptor.replace(":", "-")))
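
Behavior only differs if the checkpoint root itself contains a `:`; the descriptor's `:` is still mapped to `-`. For example (checkpoint root illustrative):

```
# with DEFAULT_CHECKPOINT_DIR = "~/.llama/checkpoints" (illustrative):
model_local_dir("Llama3.2-3B-Instruct:int4-qlora-eqat")
# -> "~/.llama/checkpoints/Llama3.2-3B-Instruct-int4-qlora-eqat"
```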

View file

@ -6,6 +6,7 @@
import inspect
import json
import logging
from enum import Enum
from typing import Any, get_args, get_origin, List, Literal, Optional, Type, Union
@ -16,6 +17,8 @@ from pydantic_core import PydanticUndefinedType
from typing_extensions import Annotated
log = logging.getLogger(__name__)
def is_list_of_primitives(field_type):
"""Check if a field type is a List of primitive types."""
@ -111,7 +114,7 @@ def prompt_for_discriminated_union(
if discriminator_value in type_map:
chosen_type = type_map[discriminator_value]
print(f"\nConfiguring {chosen_type.__name__}:")
log.info(f"\nConfiguring {chosen_type.__name__}:")
if existing_value and (
getattr(existing_value, discriminator) != discriminator_value
@ -123,7 +126,7 @@ def prompt_for_discriminated_union(
setattr(sub_config, discriminator, discriminator_value)
return sub_config
else:
print(f"Invalid {discriminator}. Please try again.")
log.error(f"Invalid {discriminator}. Please try again.")
# This is somewhat elaborate, but does not purport to be comprehensive in any way.
@ -180,7 +183,7 @@ def prompt_for_config(
config_data[field_name] = validated_value
break
except KeyError:
print(
log.error(
f"Invalid choice. Please choose from: {', '.join(e.name for e in field_type)}"
)
continue
@ -197,7 +200,7 @@ def prompt_for_config(
config_data[field_name] = None
continue
nested_type = get_non_none_type(field_type)
print(f"Entering sub-configuration for {field_name}:")
log.info(f"Entering sub-configuration for {field_name}:")
config_data[field_name] = prompt_for_config(nested_type, existing_value)
elif is_optional(field_type) and is_discriminated_union(
get_non_none_type(field_type)
@ -213,7 +216,7 @@ def prompt_for_config(
existing_value,
)
elif can_recurse(field_type):
print(f"\nEntering sub-configuration for {field_name}:")
log.info(f"\nEntering sub-configuration for {field_name}:")
config_data[field_name] = prompt_for_config(
field_type,
existing_value,
@ -240,7 +243,7 @@ def prompt_for_config(
config_data[field_name] = None
break
else:
print("This field is required. Please provide a value.")
log.error("This field is required. Please provide a value.")
continue
else:
try:
@ -264,12 +267,12 @@ def prompt_for_config(
value = [element_type(item) for item in value]
except json.JSONDecodeError:
print(
log.error(
'Invalid JSON. Please enter a valid JSON-encoded list e.g., ["foo","bar"]'
)
continue
except ValueError as e:
print(f"{str(e)}")
log.error(f"{str(e)}")
continue
elif get_origin(field_type) is dict:
@ -281,7 +284,7 @@ def prompt_for_config(
)
except json.JSONDecodeError:
print(
log.error(
"Invalid JSON. Please enter a valid JSON-encoded dict."
)
continue
@ -298,7 +301,7 @@ def prompt_for_config(
value = field_type(user_input)
except ValueError:
print(
log.error(
f"Invalid input. Expected type: {getattr(field_type, '__name__', str(field_type))}"
)
continue
@ -311,6 +314,6 @@ def prompt_for_config(
config_data[field_name] = validated_value
break
except ValueError as e:
print(f"Validation error: {str(e)}")
log.error(f"Validation error: {str(e)}")
return config_type(**config_data)

View file

@ -6,6 +6,7 @@
import asyncio
import copy
import logging
import os
import re
import secrets
@ -19,7 +20,6 @@ from urllib.parse import urlparse
import httpx
from termcolor import cprint
from llama_stack.apis.agents import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
@ -43,6 +43,8 @@ from .tools.builtin import (
)
from .tools.safety import SafeTool
log = logging.getLogger(__name__)
def make_random_string(length: int = 8):
return "".join(
@ -111,7 +113,7 @@ class ChatAgent(ShieldRunnerMixin):
# May be this should be a parameter of the agentic instance
# that can define its behavior in a custom way
for m in turn.input_messages:
msg = m.copy()
msg = m.model_copy()
if isinstance(msg, UserMessage):
msg.context = None
messages.append(msg)
@ -137,7 +139,6 @@ class ChatAgent(ShieldRunnerMixin):
stop_reason=StopReason.end_of_turn,
)
)
# print_dialog(messages)
return messages
async def create_session(self, name: str) -> str:
@ -185,10 +186,8 @@ class ChatAgent(ShieldRunnerMixin):
stream=request.stream,
):
if isinstance(chunk, CompletionMessage):
cprint(
log.info(
f"{chunk.role.capitalize()}: {chunk.content}",
"white",
attrs=["bold"],
)
output_message = chunk
continue
@ -397,17 +396,11 @@ class ChatAgent(ShieldRunnerMixin):
n_iter = 0
while True:
msg = input_messages[-1]
if msg.role == Role.user.value:
color = "blue"
elif msg.role == Role.ipython.value:
color = "yellow"
else:
color = None
if len(str(msg)) > 1000:
msg_str = f"{str(msg)[:500]}...<more>...{str(msg)[-500:]}"
else:
msg_str = str(msg)
cprint(f"{msg_str}", color=color)
log.info(f"{msg_str}")
step_id = str(uuid.uuid4())
yield AgentTurnResponseStreamChunk(
@ -506,12 +499,12 @@ class ChatAgent(ShieldRunnerMixin):
)
if n_iter >= self.agent_config.max_infer_iters:
cprint("Done with MAX iterations, exiting.")
log.info("Done with MAX iterations, exiting.")
yield message
break
if stop_reason == StopReason.out_of_tokens:
cprint("Out of token budget, exiting.")
log.info("Out of token budget, exiting.")
yield message
break
@ -525,10 +518,10 @@ class ChatAgent(ShieldRunnerMixin):
message.content = [message.content] + attachments
yield message
else:
cprint(f"Partial message: {str(message)}", color="green")
log.info(f"Partial message: {str(message)}")
input_messages = input_messages + [message]
else:
cprint(f"{str(message)}", color="green")
log.info(f"{str(message)}")
try:
tool_call = message.tool_calls[0]
@ -740,9 +733,8 @@ class ChatAgent(ShieldRunnerMixin):
for c in chunks[: memory.max_chunks]:
tokens += c.token_count
if tokens > memory.max_tokens_in_context:
cprint(
log.error(
f"Using {len(picked)} chunks; reached max tokens in context: {tokens}",
"red",
)
break
picked.append(f"id:{c.document_id}; content:{c.content}")
@ -786,7 +778,7 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa
path = urlparse(uri).path
basename = os.path.basename(path)
filepath = f"{tempdir}/{make_random_string() + basename}"
print(f"Downloading {url} -> {filepath}")
log.info(f"Downloading {url} -> {filepath}")
async with httpx.AsyncClient() as client:
r = await client.get(uri)
@ -826,20 +818,3 @@ async def execute_tool_call_maybe(
tool = tools_dict[name]
result_messages = await tool.run(messages)
return result_messages
def print_dialog(messages: List[Message]):
for i, m in enumerate(messages):
if m.role == Role.user.value:
color = "red"
elif m.role == Role.assistant.value:
color = "white"
elif m.role == Role.ipython.value:
color = "yellow"
elif m.role == Role.system.value:
color = "green"
else:
color = "white"
s = str(m)
cprint(f"{i} ::: {s[:100]}...", color=color)

View file

@ -52,7 +52,7 @@ class MetaReferenceAgentsImpl(Agents):
await self.persistence_store.set(
key=f"agent:{agent_id}",
value=agent_config.json(),
value=agent_config.model_dump_json(),
)
return AgentCreateResponse(
agent_id=agent_id,

View file

@ -5,7 +5,7 @@
# the root directory of this source tree.
import json
import logging
import uuid
from datetime import datetime
@ -15,6 +15,8 @@ from pydantic import BaseModel
from llama_stack.providers.utils.kvstore import KVStore
log = logging.getLogger(__name__)
class AgentSessionInfo(BaseModel):
session_id: str
@ -37,7 +39,7 @@ class AgentPersistence:
)
await self.kvstore.set(
key=f"session:{self.agent_id}:{session_id}",
value=session_info.json(),
value=session_info.model_dump_json(),
)
return session_id
@ -58,13 +60,13 @@ class AgentPersistence:
session_info.memory_bank_id = bank_id
await self.kvstore.set(
key=f"session:{self.agent_id}:{session_id}",
value=session_info.json(),
value=session_info.model_dump_json(),
)
async def add_turn_to_session(self, session_id: str, turn: Turn):
await self.kvstore.set(
key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}",
value=turn.json(),
value=turn.model_dump_json(),
)
async def get_session_turns(self, session_id: str) -> List[Turn]:
@ -78,7 +80,7 @@ class AgentPersistence:
turn = Turn(**json.loads(value))
turns.append(turn)
except Exception as e:
print(f"Error parsing turn: {e}")
log.error(f"Error parsing turn: {e}")
continue
turns.sort(key=lambda x: (x.completed_at or datetime.min))
return turns
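
For orientation, the flat key namespace used for agent persistence (keys collected from this file and the `agents.py` hunk above):

```
# agent:{agent_id}                           -> AgentConfig JSON
# session:{agent_id}:{session_id}            -> AgentSessionInfo JSON
# session:{agent_id}:{session_id}:{turn_id}  -> Turn JSON
```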

View file

@ -10,8 +10,6 @@ from jinja2 import Template
from llama_models.llama3.api import * # noqa: F403
from termcolor import cprint # noqa: F401
from llama_stack.apis.agents import (
DefaultMemoryQueryGeneratorConfig,
LLMMemoryQueryGeneratorConfig,
@ -36,7 +34,6 @@ async def generate_rag_query(
query = await llm_rag_query_generator(config, messages, **kwargs)
else:
raise NotImplementedError(f"Unsupported memory query generator {config.type}")
# cprint(f"Generated query >>>: {query}", color="green")
return query

View file

@ -5,14 +5,16 @@
# the root directory of this source tree.
import asyncio
import logging
from typing import List
from llama_models.llama3.api.datatypes import Message
from termcolor import cprint
from llama_stack.apis.safety import * # noqa: F403
log = logging.getLogger(__name__)
class SafetyException(Exception): # noqa: N818
def __init__(self, violation: SafetyViolation):
@ -51,7 +53,4 @@ class ShieldRunnerMixin:
if violation.violation_level == ViolationLevel.ERROR:
raise SafetyException(violation)
elif violation.violation_level == ViolationLevel.WARN:
cprint(
f"[Warn]{identifier} raised a warning",
color="red",
)
log.warning(f"[Warn]{identifier} raised a warning")

View file

@ -5,6 +5,7 @@
# the root directory of this source tree.
import json
import logging
import re
import tempfile
@ -12,7 +13,6 @@ from abc import abstractmethod
from typing import List, Optional
import requests
from termcolor import cprint
from .ipython_tool.code_execution import (
CodeExecutionContext,
@ -27,6 +27,9 @@ from llama_stack.apis.agents import * # noqa: F403
from .base import BaseTool
log = logging.getLogger(__name__)
def interpret_content_as_attachment(content: str) -> Optional[Attachment]:
match = re.search(TOOLS_ATTACHMENT_KEY_REGEX, content)
if match:
@ -383,7 +386,7 @@ class CodeInterpreterTool(BaseTool):
if res_out != "":
pieces.extend([f"[{out_type}]", res_out, f"[/{out_type}]"])
if out_type == "stderr":
cprint(f"ipython tool error: ↓\n{res_out}", color="red")
log.error(f"ipython tool error: ↓\n{res_out}")
message = ToolResponseMessage(
call_id=tool_call.call_id,

View file

@ -11,6 +11,7 @@ A custom Matplotlib backend that overrides the show method to return image bytes
import base64
import io
import json as _json
import logging
import matplotlib
from matplotlib.backend_bases import FigureManagerBase
@ -18,6 +19,8 @@ from matplotlib.backend_bases import FigureManagerBase
# Import necessary components from Matplotlib
from matplotlib.backends.backend_agg import FigureCanvasAgg
log = logging.getLogger(__name__)
class CustomFigureCanvas(FigureCanvasAgg):
def show(self):
@ -80,7 +83,7 @@ def show():
)
req_con.send_bytes(_json_dump.encode("utf-8"))
resp = _json.loads(resp_con.recv_bytes().decode("utf-8"))
print(resp)
log.info(resp)
FigureCanvas = CustomFigureCanvas

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -72,7 +72,7 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
key = f"{EVAL_TASKS_PREFIX}{task_def.identifier}"
await self.kvstore.set(
key=key,
value=task_def.json(),
value=task_def.model_dump_json(),
)
self.eval_tasks[task_def.identifier] = task_def

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Optional
from typing import Any, Dict, Optional
from llama_models.datatypes import * # noqa: F403
from llama_models.sku_list import resolve_model
@ -37,8 +37,10 @@ class MetaReferenceInferenceConfig(BaseModel):
@classmethod
def validate_model(cls, model: str) -> str:
permitted_models = supported_inference_models()
if model not in permitted_models:
model_list = "\n\t".join(permitted_models)
descriptors = [m.descriptor() for m in permitted_models]
repos = [m.huggingface_repo for m in permitted_models]
if model not in (descriptors + repos):
model_list = "\n\t".join(repos)
raise ValueError(
f"Unknown model: `{model}`. Choose from [\n\t{model_list}\n]"
)
@ -54,6 +56,7 @@ class MetaReferenceInferenceConfig(BaseModel):
cls,
model: str = "Llama3.2-3B-Instruct",
checkpoint_dir: str = "${env.CHECKPOINT_DIR:null}",
**kwargs,
) -> Dict[str, Any]:
return {
"model": model,
@ -64,3 +67,16 @@ class MetaReferenceInferenceConfig(BaseModel):
class MetaReferenceQuantizedInferenceConfig(MetaReferenceInferenceConfig):
quantization: QuantizationConfig
@classmethod
def sample_run_config(
cls,
model: str = "Llama3.2-3B-Instruct",
checkpoint_dir: str = "${env.CHECKPOINT_DIR:null}",
**kwargs,
) -> Dict[str, Any]:
config = super().sample_run_config(model, checkpoint_dir, **kwargs)
config["quantization"] = {
"type": "fp8",
}
return config

View file

@ -8,6 +8,7 @@
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import json
import logging
import math
import os
import sys
@ -31,7 +32,6 @@ from llama_models.llama3.reference_impl.multimodal.model import (
)
from llama_models.sku_list import resolve_model
from pydantic import BaseModel
from termcolor import cprint
from llama_stack.apis.inference import * # noqa: F403
@ -50,6 +50,8 @@ from .config import (
MetaReferenceQuantizedInferenceConfig,
)
log = logging.getLogger(__name__)
def model_checkpoint_dir(model) -> str:
checkpoint_dir = Path(model_local_dir(model.descriptor()))
@ -185,7 +187,7 @@ class Llama:
model = Transformer(model_args)
model.load_state_dict(state_dict, strict=False)
print(f"Loaded in {time.time() - start_time:.2f} seconds")
log.info(f"Loaded in {time.time() - start_time:.2f} seconds")
return Llama(model, tokenizer, model_args, llama_model)
def __init__(
@ -221,7 +223,7 @@ class Llama:
self.formatter.vision_token if t == 128256 else t
for t in model_input.tokens
]
cprint("Input to model -> " + self.tokenizer.decode(input_tokens), "red")
log.info("Input to model -> " + self.tokenizer.decode(input_tokens))
prompt_tokens = [model_input.tokens]
bsz = 1
@ -231,9 +233,7 @@ class Llama:
max_prompt_len = max(len(t) for t in prompt_tokens)
if max_prompt_len >= params.max_seq_len:
cprint(
f"Out of token budget {max_prompt_len} vs {params.max_seq_len}", "red"
)
log.error(f"Out of token budget {max_prompt_len} vs {params.max_seq_len}")
return
total_len = min(max_gen_len + max_prompt_len, params.max_seq_len)

View file

@ -5,6 +5,7 @@
# the root directory of this source tree.
import asyncio
import logging
from typing import AsyncGenerator, List
@ -25,6 +26,7 @@ from .config import MetaReferenceInferenceConfig
from .generation import Llama
from .model_parallel import LlamaModelParallelGenerator
log = logging.getLogger(__name__)
# there's a single model parallel process serving the model. for now,
# we don't support multiple concurrent requests to this process.
SEMAPHORE = asyncio.Semaphore(1)
@ -49,7 +51,7 @@ class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolP
# verify that the checkpoint actually is for this model lol
async def initialize(self) -> None:
print(f"Loading model `{self.model.descriptor()}`")
log.info(f"Loading model `{self.model.descriptor()}`")
if self.config.create_distributed_process_group:
self.generator = LlamaModelParallelGenerator(self.config)
self.generator.start()

View file

@ -11,6 +11,7 @@
# the root directory of this source tree.
import json
import logging
import multiprocessing
import os
import tempfile
@ -37,6 +38,8 @@ from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest
from .generation import TokenResult
log = logging.getLogger(__name__)
class ProcessingMessageName(str, Enum):
ready_request = "ready_request"
@ -183,16 +186,16 @@ def retrieve_requests(reply_socket_url: str):
group=get_model_parallel_group(),
)
if isinstance(updates[0], CancelSentinel):
print("quitting generation loop because request was cancelled")
log.info(
"quitting generation loop because request was cancelled"
)
break
if mp_rank_0():
send_obj(EndSentinel())
except Exception as e:
print(f"[debug] got exception {e}")
import traceback
log.exception("exception in generation loop")
traceback.print_exc()
if mp_rank_0():
send_obj(ExceptionResponse(error=str(e)))
@ -252,7 +255,7 @@ def worker_process_entrypoint(
except StopIteration:
break
print("[debug] worker process done")
log.info("[debug] worker process done")
def launch_dist_group(
@ -313,7 +316,7 @@ def start_model_parallel_process(
request_socket.send(encode_msg(ReadyRequest()))
response = request_socket.recv()
print("Loaded model...")
log.info("Loaded model...")
return request_socket, process
@ -361,7 +364,7 @@ class ModelParallelProcessGroup:
break
if isinstance(obj, ExceptionResponse):
print(f"[debug] got exception {obj.error}")
log.error(f"[debug] got exception {obj.error}")
raise Exception(obj.error)
if isinstance(obj, TaskResponse):

View file

@ -8,14 +8,20 @@
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import collections
import logging
from typing import Optional, Type
log = logging.getLogger(__name__)
try:
import fbgemm_gpu.experimental.gen_ai # noqa: F401
print("Using efficient FP8 operators in FBGEMM.")
log.info("Using efficient FP8 operators in FBGEMM.")
except ImportError:
print("No efficient FP8 operators. Please install FBGEMM in fp8_requirements.txt.")
log.error(
"No efficient FP8 operators. Please install FBGEMM in fp8_requirements.txt."
)
raise
import torch

View file

@ -7,6 +7,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import logging
import os
from typing import Any, Dict, List, Optional
@ -21,7 +22,6 @@ from llama_models.llama3.api.args import ModelArgs
from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock
from llama_models.sku_list import resolve_model
from termcolor import cprint
from torch import nn, Tensor
from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
@ -30,6 +30,8 @@ from llama_stack.apis.inference import QuantizationType
from ..config import MetaReferenceQuantizedInferenceConfig
log = logging.getLogger(__name__)
def swiglu_wrapper(
self,
@ -60,7 +62,7 @@ def convert_to_fp8_quantized_model(
# Move weights to GPU with quantization
if llama_model.quantization_format == CheckpointQuantizationFormat.fp8_mixed.value:
cprint("Loading fp8 scales...", "yellow")
log.info("Loading fp8 scales...")
fp8_scales_path = os.path.join(
checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt"
)
@ -85,7 +87,7 @@ def convert_to_fp8_quantized_model(
fp8_activation_scale_ub,
)
else:
cprint("Quantizing fp8 weights from bf16...", "yellow")
log.info("Quantizing fp8 weights from bf16...")
for block in model.layers:
if isinstance(block, TransformerBlock):
if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):

View file

@ -8,6 +8,7 @@
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import json
import logging
import os
import shutil
import sys
@ -22,12 +23,18 @@ from fairscale.nn.model_parallel.initialize import (
initialize_model_parallel,
model_parallel_is_initialized,
)
from fp8.fp8_impls import FfnQuantizeMode, quantize_fp8
from llama.model import ModelArgs, Transformer, TransformerBlock
from llama.tokenizer import Tokenizer
from llama_models.llama3.api.args import ModelArgs
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock
from torch.nn.parameter import Parameter
from llama_stack.providers.inline.inference.meta_reference.quantization.fp8_impls import (
quantize_fp8,
)
log = logging.getLogger(__name__)
def main(
ckpt_dir: str,
@ -36,7 +43,6 @@ def main(
max_seq_len: Optional[int] = 512,
max_batch_size: Optional[int] = 4,
model_parallel_size: Optional[int] = None,
ffn_quantize_mode: Optional[FfnQuantizeMode] = FfnQuantizeMode.FP8_ROWWISE,
fp8_activation_scale_ub: Optional[float] = 1200.0,
seed: int = 1,
):
@ -99,7 +105,7 @@ def main(
else:
torch.set_default_tensor_type(torch.cuda.HalfTensor)
print(ckpt_path)
log.info(ckpt_path)
assert (
quantized_ckpt_dir is not None
), "QUantized checkpoint directory should not be None"
@ -112,7 +118,6 @@ def main(
fp8_weight = quantize_fp8(
block.feed_forward.w1.weight,
fp8_activation_scale_ub,
ffn_quantize_mode,
output_device=torch.device("cpu"),
)
with torch.inference_mode():
@ -124,7 +129,6 @@ def main(
fp8_weight = quantize_fp8(
block.feed_forward.w3.weight,
fp8_activation_scale_ub,
ffn_quantize_mode,
output_device=torch.device("cpu"),
)
with torch.inference_mode():
@ -136,7 +140,6 @@ def main(
fp8_weight = quantize_fp8(
block.feed_forward.w2.weight,
fp8_activation_scale_ub,
ffn_quantize_mode,
output_device=torch.device("cpu"),
)
with torch.inference_mode():

View file

@ -9,7 +9,7 @@
set -euo pipefail
set -x
cd $(git rev-parse --show-toplevel)
cd $(dirname "$(realpath "$0")")
MASTER_HOST=$1
RUN_ID=$2
@ -21,7 +21,7 @@ NPROC=$7
echo $MASTER_HOST, $RUN_ID, $CKPT_DIR, $QUANT_CKPT_DIR
NCCL_NET=Socket NCCL_SOCKET_IFNAME=eth TIKTOKEN_CACHE_DIR="" \
NCCL_NET=Socket NCCL_SOCKET_IFNAME=eth TIKTOKEN_CACHE_DIR="" PYTHONPATH="/home/$USER/llama-models:/home/$USER/llama-stack" \
torchrun \
--nnodes=$NNODES --nproc_per_node=$NPROC \
--rdzv_id=$RUN_ID \

View file

@ -37,19 +37,22 @@ class VLLMConfig(BaseModel):
@classmethod
def sample_run_config(cls):
return {
"model": "${env.VLLM_INFERENCE_MODEL:Llama3.2-3B-Instruct}",
"tensor_parallel_size": "${env.VLLM_TENSOR_PARALLEL_SIZE:1}",
"max_tokens": "${env.VLLM_MAX_TOKENS:4096}",
"enforce_eager": "${env.VLLM_ENFORCE_EAGER:False}",
"gpu_memory_utilization": "${env.VLLM_GPU_MEMORY_UTILIZATION:0.3}",
"model": "${env.INFERENCE_MODEL:Llama3.2-3B-Instruct}",
"tensor_parallel_size": "${env.TENSOR_PARALLEL_SIZE:1}",
"max_tokens": "${env.MAX_TOKENS:4096}",
"enforce_eager": "${env.ENFORCE_EAGER:False}",
"gpu_memory_utilization": "${env.GPU_MEMORY_UTILIZATION:0.7}",
}
@field_validator("model")
@classmethod
def validate_model(cls, model: str) -> str:
permitted_models = supported_inference_models()
if model not in permitted_models:
model_list = "\n\t".join(permitted_models)
descriptors = [m.descriptor() for m in permitted_models]
repos = [m.huggingface_repo for m in permitted_models]
if model not in (descriptors + repos):
model_list = "\n\t".join(repos)
raise ValueError(
f"Unknown model: `{model}`. Choose from [\n\t{model_list}\n]"
)
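
The sample values lean on the `${env.VAR:default}` convention resolved at server startup; a simplified stand-in for `replace_env_vars` (the real implementation lives in `llama_stack.distribution.stack`):

```
import os
import re

def resolve(value: str) -> str:
    # substitute ${env.NAME} / ${env.NAME:default} with the environment value
    def sub(match: re.Match) -> str:
        name, _, default = match.group(1).partition(":")
        return os.environ.get(name, default)
    return re.sub(r"\$\{env\.([^}]+)\}", sub, value)

print(resolve("${env.TENSOR_PARALLEL_SIZE:1}"))  # "1" unless TENSOR_PARALLEL_SIZE is set
```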

View file

@ -80,7 +80,9 @@ class FaissIndex(EmbeddingIndex):
np.savetxt(buffer, np_index)
data = {
"id_by_index": self.id_by_index,
"chunk_by_index": {k: v.json() for k, v in self.chunk_by_index.items()},
"chunk_by_index": {
k: v.model_dump_json() for k, v in self.chunk_by_index.items()
},
"faiss_index": base64.b64encode(buffer.getvalue()).decode("utf-8"),
}
@ -162,7 +164,7 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
key = f"{MEMORY_BANKS_PREFIX}{memory_bank.identifier}"
await self.kvstore.set(
key=key,
value=memory_bank.json(),
value=memory_bank.model_dump_json(),
)
# Store in cache

View file

@ -4,10 +4,18 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel
class LogFormat(Enum):
TEXT = "text"
JSON = "json"
@json_schema_type
class ConsoleConfig(BaseModel): ...
class ConsoleConfig(BaseModel):
log_format: LogFormat = LogFormat.TEXT

View file

@ -4,8 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from typing import Optional
from .config import LogFormat
from llama_stack.apis.telemetry import * # noqa: F403
from .config import ConsoleConfig
@ -38,7 +41,11 @@ class ConsoleTelemetryImpl(Telemetry):
span_name = ".".join(names) if names else None
formatted = format_event(event, span_name)
if self.config.log_format == LogFormat.JSON:
formatted = format_event_json(event, span_name)
else:
formatted = format_event_text(event, span_name)
if formatted:
print(formatted)
@ -69,7 +76,7 @@ SEVERITY_COLORS = {
}
def format_event(event: Event, span_name: str) -> Optional[str]:
def format_event_text(event: Event, span_name: str) -> Optional[str]:
timestamp = event.timestamp.strftime("%H:%M:%S.%f")[:-3]
span = ""
if span_name:
@ -87,3 +94,23 @@ def format_event(event: Event, span_name: str) -> Optional[str]:
return None
return f"Unknown event type: {event}"
def format_event_json(event: Event, span_name: str) -> Optional[str]:
base_data = {
"timestamp": event.timestamp.isoformat(),
"trace_id": event.trace_id,
"span_id": event.span_id,
"span_name": span_name,
}
if isinstance(event, UnstructuredLogEvent):
base_data.update(
{"type": "log", "severity": event.severity.name, "message": event.message}
)
return json.dumps(base_data)
elif isinstance(event, StructuredLogEvent):
return None
return json.dumps({"error": f"Unknown event type: {event}"})

View file

@ -4,16 +4,16 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from typing import Any, Dict, List
from llama_models.llama3.api.datatypes import interleaved_text_media_as_str, Message
from termcolor import cprint
from .config import CodeScannerConfig
from llama_stack.apis.safety import * # noqa: F403
log = logging.getLogger(__name__)
ALLOWED_CODE_SCANNER_MODEL_IDS = [
"CodeScanner",
"CodeShield",
@ -49,7 +49,7 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
from codeshield.cs import CodeShield
text = "\n".join([interleaved_text_media_as_str(m.content) for m in messages])
cprint(f"Running CodeScannerShield on {text[50:]}", color="magenta")
log.info(f"Running CodeScannerShield on {text[50:]}")
result = await CodeShield.scan_code(text)
violation = None

View file

@ -4,10 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from typing import Any, Dict, List
import torch
from termcolor import cprint
from transformers import AutoModelForSequenceClassification, AutoTokenizer
@ -20,6 +20,7 @@ from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from .config import PromptGuardConfig, PromptGuardType
log = logging.getLogger(__name__)
PROMPT_GUARD_MODEL = "Prompt-Guard-86M"
@ -93,9 +94,8 @@ class PromptGuardShield:
probabilities = torch.softmax(logits / self.temperature, dim=-1)
score_embedded = probabilities[0, 1].item()
score_malicious = probabilities[0, 2].item()
cprint(
log.info(
f"Ran PromptGuardShield and got Scores: Embedded: {score_embedded}, Malicious: {score_malicious}",
color="magenta",
)
violation = None

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -5,11 +5,17 @@
# the root directory of this source tree.
from typing import Dict
from pydantic import BaseModel
from llama_stack.distribution.datatypes import Api, ProviderSpec
from .config import BraintrustScoringConfig
class BraintrustProviderDataValidator(BaseModel):
openai_api_key: str
async def get_provider_impl(
config: BraintrustScoringConfig,
deps: Dict[Api, ProviderSpec],

View file

@ -12,9 +12,12 @@ from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403
from llama_stack.apis.datasets import * # noqa: F403
# from .scoring_fn.braintrust_scoring_fn import BraintrustScoringFn
import os
from autoevals.llm import Factuality
from autoevals.ragas import AnswerCorrectness
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_average
@ -24,7 +27,9 @@ from .scoring_fn.fn_defs.answer_correctness import answer_correctness_fn_def
from .scoring_fn.fn_defs.factuality import factuality_fn_def
class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
class BraintrustScoringImpl(
Scoring, ScoringFunctionsProtocolPrivate, NeedsRequestProviderData
):
def __init__(
self,
config: BraintrustScoringConfig,
@@ -79,12 +84,25 @@ class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
f"Dataset {dataset_id} does not have a '{required_column}' column of type 'string'."
)
async def set_api_key(self) -> None:
# api key is in the request headers
if self.config.openai_api_key is None:
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.openai_api_key:
raise ValueError(
'Pass OpenAI API Key in the header X-LlamaStack-ProviderData as { "openai_api_key": <your api key>}'
)
self.config.openai_api_key = provider_data.openai_api_key
os.environ["OPENAI_API_KEY"] = self.config.openai_api_key
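For callers, a hedged sketch of supplying the key per request through that header (the route and body below are illustrative, not the actual API surface):

import json

import httpx

headers = {
    "Content-Type": "application/json",
    # The provider reads the key from X-LlamaStack-ProviderData, as above.
    "X-LlamaStack-ProviderData": json.dumps({"openai_api_key": "sk-..."}),
}
resp = httpx.post(
    "http://localhost:5000/scoring/score_batch",  # illustrative route
    json={
        "dataset_id": "my-dataset",
        "scoring_functions": ["braintrust::answer-correctness"],
    },
    headers=headers,
)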
async def score_batch(
self,
dataset_id: str,
scoring_functions: List[str],
save_results_dataset: bool = False,
) -> ScoreBatchResponse:
await self.set_api_key()
await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id)
all_rows = await self.datasetio_api.get_rows_paginated(
dataset_id=dataset_id,
@@ -105,6 +123,7 @@ class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
async def score_row(
self, input_row: Dict[str, Any], scoring_fn_identifier: Optional[str] = None
) -> ScoringResultRow:
await self.set_api_key()
assert scoring_fn_identifier is not None, "scoring_fn_identifier cannot be None"
expected_answer = input_row["expected_answer"]
generated_answer = input_row["generated_answer"]
@@ -118,6 +137,7 @@ class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
async def score(
self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]
) -> ScoreResponse:
await self.set_api_key()
res = {}
for scoring_fn_id in scoring_functions:
if scoring_fn_id not in self.supported_fn_defs_registry:

View file

@@ -6,4 +6,8 @@
from llama_stack.apis.scoring import * # noqa: F401, F403
class BraintrustScoringConfig(BaseModel): ...
class BraintrustScoringConfig(BaseModel):
openai_api_key: Optional[str] = Field(
default=None,
description="The OpenAI API Key",
)

View file

@@ -10,7 +10,7 @@ from llama_stack.apis.scoring_functions import ScoringFn
answer_correctness_fn_def = ScoringFn(
identifier="braintrust::answer-correctness",
description="Test whether an output is factual, compared to an original (`expected`) value. One of Braintrust LLM basd scorer https://github.com/braintrustdata/autoevals/blob/main/py/autoevals/llm.py",
description="Scores the correctness of the answer based on the ground truth.. One of Braintrust LLM basd scorer https://github.com/braintrustdata/autoevals/blob/main/py/autoevals/llm.py",
params=None,
provider_id="braintrust",
provider_resource_id="answer-correctness",

View file

@@ -84,7 +84,7 @@ llm_as_judge_405b_simpleqa = ScoringFn(
provider_id="llm-as-judge",
provider_resource_id="llm-as-judge-405b-simpleqa",
params=LLMAsJudgeScoringFnParams(
judge_model="Llama3.1-405B-Instruct",
judge_model="meta-llama/Llama-3.1-405B-Instruct",
prompt_template=GRADER_TEMPLATE,
judge_score_regexes=[r"(A|B|C)"],
),

View file

@@ -17,6 +17,16 @@ from llama_stack.distribution.datatypes import (
def available_providers() -> List[ProviderSpec]:
return [
InlineProviderSpec(
api=Api.safety,
provider_type="inline::prompt-guard",
pip_packages=[
"transformers",
"torch --index-url https://download.pytorch.org/whl/cpu",
],
module="llama_stack.providers.inline.safety.prompt_guard",
config_class="llama_stack.providers.inline.safety.prompt_guard.PromptGuardConfig",
),
InlineProviderSpec(
api=Api.safety,
provider_type="inline::meta-reference",
@@ -48,16 +58,6 @@ Provider `inline::meta-reference` for API `safety` does not work with the latest
Api.inference,
],
),
InlineProviderSpec(
api=Api.safety,
provider_type="inline::prompt-guard",
pip_packages=[
"transformers",
"torch --index-url https://download.pytorch.org/whl/cpu",
],
module="llama_stack.providers.inline.safety.prompt_guard",
config_class="llama_stack.providers.inline.safety.prompt_guard.PromptGuardConfig",
),
InlineProviderSpec(
api=Api.safety,
provider_type="inline::code-scanner",

View file

@@ -44,5 +44,6 @@ def available_providers() -> List[ProviderSpec]:
Api.datasetio,
Api.datasets,
],
provider_data_validator="llama_stack.providers.inline.scoring.braintrust.BraintrustProviderDataValidator",
),
]

View file

@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@@ -3,12 +3,13 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
from llama_stack.providers.utils.kvstore.config import (
KVStoreConfig,
SqliteKVStoreConfig,
)
from pydantic import BaseModel
class HuggingfaceDatasetIOConfig(BaseModel):

View file

@@ -9,6 +9,7 @@ from llama_stack.apis.datasetio import * # noqa: F403
import datasets as hf_datasets
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
from llama_stack.providers.utils.kvstore import kvstore_impl

View file

@@ -4,11 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_models.schema_utils import json_schema_type
from llama_stack.providers.utils.bedrock.config import BedrockBaseConfig
@json_schema_type
class BedrockConfig(BedrockBaseConfig):
pass

View file

@@ -4,11 +4,15 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from ._config import NVIDIAConfig
from ._nvidia import NVIDIAInferenceAdapter
from llama_stack.apis.inference import Inference
from .config import NVIDIAConfig
async def get_adapter_impl(config: NVIDIAConfig, _deps) -> NVIDIAInferenceAdapter:
async def get_adapter_impl(config: NVIDIAConfig, _deps) -> Inference:
# import dynamically so `llama stack build` does not fail due to missing dependencies
from .nvidia import NVIDIAInferenceAdapter
if not isinstance(config, NVIDIAConfig):
raise RuntimeError(f"Unexpected config type: {type(config)}")
adapter = NVIDIAInferenceAdapter(config)

View file

@@ -0,0 +1,50 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from typing import Optional
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
@json_schema_type
class NVIDIAConfig(BaseModel):
"""
Configuration for the NVIDIA NIM inference endpoint.
Attributes:
url (str): A base url for accessing the NVIDIA NIM, e.g. http://localhost:8000
api_key (str): The access key for the hosted NIM endpoints
There are two ways to access NVIDIA NIMs -
0. Hosted: Preview APIs hosted at https://integrate.api.nvidia.com
1. Self-hosted: You can run NVIDIA NIMs on your own infrastructure
By default the configuration is set to use the hosted APIs. This requires
an API key which can be obtained from https://ngc.nvidia.com/.
By default the configuration will attempt to read the NVIDIA_API_KEY environment
variable to set the api_key. Please do not put your API key in code.
If you are using a self-hosted NVIDIA NIM, you can set the url to the
URL of your running NVIDIA NIM and do not need to set the api_key.
"""
url: str = Field(
default_factory=lambda: os.getenv(
"NVIDIA_BASE_URL", "https://integrate.api.nvidia.com"
),
description="A base url for accessing the NVIDIA NIM",
)
api_key: Optional[str] = Field(
default_factory=lambda: os.getenv("NVIDIA_API_KEY"),
description="The NVIDIA API key, only needed of using the hosted service",
)
timeout: int = Field(
default=60,
description="Timeout for the HTTP requests",
)
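A short sketch of the two access modes described in the docstring (URL is a placeholder):

# Hosted preview APIs: api_key falls back to the NVIDIA_API_KEY env var.
hosted = NVIDIAConfig()

# Self-hosted NIM: point at your own deployment; no api_key needed.
self_hosted = NVIDIAConfig(url="http://localhost:8000")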

View file

@@ -0,0 +1,183 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import warnings
from typing import AsyncIterator, List, Optional, Union
from llama_models.datatypes import SamplingParams
from llama_models.llama3.api.datatypes import (
InterleavedTextMedia,
Message,
ToolChoice,
ToolDefinition,
ToolPromptFormat,
)
from llama_models.sku_list import CoreModelId
from openai import APIConnectionError, AsyncOpenAI
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseStreamChunk,
CompletionResponse,
CompletionResponseStreamChunk,
EmbeddingsResponse,
Inference,
LogProbConfig,
ResponseFormat,
)
from llama_stack.providers.utils.inference.model_registry import (
build_model_alias,
ModelRegistryHelper,
)
from . import NVIDIAConfig
from .openai_utils import (
convert_chat_completion_request,
convert_openai_chat_completion_choice,
convert_openai_chat_completion_stream,
)
from .utils import _is_nvidia_hosted, check_health
_MODEL_ALIASES = [
build_model_alias(
"meta/llama3-8b-instruct",
CoreModelId.llama3_8b_instruct.value,
),
build_model_alias(
"meta/llama3-70b-instruct",
CoreModelId.llama3_70b_instruct.value,
),
build_model_alias(
"meta/llama-3.1-8b-instruct",
CoreModelId.llama3_1_8b_instruct.value,
),
build_model_alias(
"meta/llama-3.1-70b-instruct",
CoreModelId.llama3_1_70b_instruct.value,
),
build_model_alias(
"meta/llama-3.1-405b-instruct",
CoreModelId.llama3_1_405b_instruct.value,
),
build_model_alias(
"meta/llama-3.2-1b-instruct",
CoreModelId.llama3_2_1b_instruct.value,
),
build_model_alias(
"meta/llama-3.2-3b-instruct",
CoreModelId.llama3_2_3b_instruct.value,
),
build_model_alias(
"meta/llama-3.2-11b-vision-instruct",
CoreModelId.llama3_2_11b_vision_instruct.value,
),
build_model_alias(
"meta/llama-3.2-90b-vision-instruct",
CoreModelId.llama3_2_90b_vision_instruct.value,
),
# TODO(mf): how do we handle Nemotron models?
# "Llama3.1-Nemotron-51B-Instruct" -> "meta/llama-3.1-nemotron-51b-instruct",
]
class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
def __init__(self, config: NVIDIAConfig) -> None:
# TODO(mf): filter by available models
ModelRegistryHelper.__init__(self, model_aliases=_MODEL_ALIASES)
print(f"Initializing NVIDIAInferenceAdapter({config.url})...")
if _is_nvidia_hosted(config):
if not config.api_key:
raise RuntimeError(
"API key is required for hosted NVIDIA NIM. "
"Either provide an API key or use a self-hosted NIM."
)
# elif self._config.api_key:
#
# we don't raise this warning because a user may have deployed their
# self-hosted NIM with an API key requirement.
#
# warnings.warn(
# "API key is not required for self-hosted NVIDIA NIM. "
# "Consider removing the api_key from the configuration."
# )
self._config = config
# make sure the client lives longer than any async calls
self._client = AsyncOpenAI(
base_url=f"{self._config.url}/v1",
api_key=self._config.api_key or "NO KEY",
timeout=self._config.timeout,
)
def completion(
self,
model_id: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
raise NotImplementedError()
async def embeddings(
self,
model_id: str,
contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse:
raise NotImplementedError()
async def chat_completion(
self,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[
ToolPromptFormat
] = None, # API default is ToolPromptFormat.json, we default to None to detect user input
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[
ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
]:
if tool_prompt_format:
warnings.warn("tool_prompt_format is not supported by NVIDIA NIM, ignoring")
await check_health(self._config) # this raises errors
request = convert_chat_completion_request(
request=ChatCompletionRequest(
model=self.get_provider_model_id(model_id),
messages=messages,
sampling_params=sampling_params,
response_format=response_format,
tools=tools,
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
stream=stream,
logprobs=logprobs,
),
n=1,
)
try:
response = await self._client.chat.completions.create(**request)
except APIConnectionError as e:
raise ConnectionError(
f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}"
) from e
if stream:
return convert_openai_chat_completion_stream(response)
else:
# we pass n=1 to get only one completion
return convert_openai_chat_completion_choice(response.choices[0])
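A hedged call sketch, inside an async context (assumes a self-hosted NIM, that model_id is registered against one of the aliases above, and that UserMessage is importable from the Llama Stack datatypes):

adapter = NVIDIAInferenceAdapter(NVIDIAConfig(url="http://localhost:8000"))
response = await adapter.chat_completion(
    model_id="meta-llama/Llama-3.1-8B-Instruct",  # illustrative identifier
    messages=[UserMessage(content="Hello!")],
)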

View file

@@ -0,0 +1,581 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import warnings
from typing import Any, AsyncGenerator, Dict, Generator, List, Optional
from llama_models.llama3.api.datatypes import (
BuiltinTool,
CompletionMessage,
StopReason,
TokenLogProbs,
ToolCall,
ToolDefinition,
)
from openai import AsyncStream
from openai.types.chat import (
ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage,
ChatCompletionChunk as OpenAIChatCompletionChunk,
ChatCompletionMessageParam as OpenAIChatCompletionMessage,
ChatCompletionMessageToolCallParam as OpenAIChatCompletionMessageToolCall,
ChatCompletionSystemMessageParam as OpenAIChatCompletionSystemMessage,
ChatCompletionToolMessageParam as OpenAIChatCompletionToolMessage,
ChatCompletionUserMessageParam as OpenAIChatCompletionUserMessage,
)
from openai.types.chat.chat_completion import (
Choice as OpenAIChoice,
ChoiceLogprobs as OpenAIChoiceLogprobs, # same as chat_completion_chunk ChoiceLogprobs
)
from openai.types.chat.chat_completion_message_tool_call_param import (
Function as OpenAIFunction,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseEvent,
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
JsonSchemaResponseFormat,
Message,
SystemMessage,
ToolCallDelta,
ToolCallParseStatus,
ToolResponseMessage,
UserMessage,
)
def _convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict:
"""
Convert a ToolDefinition to an OpenAI API-compatible dictionary.
ToolDefinition:
tool_name: str | BuiltinTool
description: Optional[str]
parameters: Optional[Dict[str, ToolParamDefinition]]
ToolParamDefinition:
param_type: str
description: Optional[str]
required: Optional[bool]
default: Optional[Any]
OpenAI spec -
{
"type": "function",
"function": {
"name": tool_name,
"description": description,
"parameters": {
"type": "object",
"properties": {
param_name: {
"type": param_type,
"description": description,
"default": default,
},
...
},
"required": [param_name, ...],
},
},
}
"""
out = {
"type": "function",
"function": {},
}
function = out["function"]
if isinstance(tool.tool_name, BuiltinTool):
function.update(name=tool.tool_name.value) # TODO(mf): is this sufficient?
else:
function.update(name=tool.tool_name)
if tool.description:
function.update(description=tool.description)
if tool.parameters:
parameters = {
"type": "object",
"properties": {},
}
properties = parameters["properties"]
required = []
for param_name, param in tool.parameters.items():
properties[param_name] = {"type": param.param_type}
if param.description:
properties[param_name].update(description=param.description)
if param.default:
properties[param_name].update(default=param.default)
if param.required:
required.append(param_name)
if required:
parameters.update(required=required)
function.update(parameters=parameters)
return out
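A worked example of the conversion (assumes ToolParamDefinition is importable from the same datatypes module as ToolDefinition):

tool = ToolDefinition(
    tool_name="get_weather",
    description="Get the current weather for a city",
    parameters={
        "city": ToolParamDefinition(
            param_type="string",
            description="Name of the city",
            required=True,
        ),
    },
)
openai_tool = _convert_tooldef_to_openai_tool(tool)
# -> {"type": "function",
#     "function": {
#         "name": "get_weather",
#         "description": "Get the current weather for a city",
#         "parameters": {
#             "type": "object",
#             "properties": {"city": {"type": "string", "description": "Name of the city"}},
#             "required": ["city"],
#         },
#     }}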
def _convert_message(message: Message | Dict) -> OpenAIChatCompletionMessage:
"""
Convert a Message to an OpenAI API-compatible dictionary.
"""
# users can supply a dict instead of a Message object, we'll
# convert it to a Message object and proceed with some type safety.
if isinstance(message, dict):
if "role" not in message:
raise ValueError("role is required in message")
if message["role"] == "user":
message = UserMessage(**message)
elif message["role"] == "assistant":
message = CompletionMessage(**message)
elif message["role"] == "ipython":
message = ToolResponseMessage(**message)
elif message["role"] == "system":
message = SystemMessage(**message)
else:
raise ValueError(f"Unsupported message role: {message['role']}")
out: OpenAIChatCompletionMessage = None
if isinstance(message, UserMessage):
out = OpenAIChatCompletionUserMessage(
role="user",
content=message.content, # TODO(mf): handle image content
)
elif isinstance(message, CompletionMessage):
out = OpenAIChatCompletionAssistantMessage(
role="assistant",
content=message.content,
tool_calls=[
OpenAIChatCompletionMessageToolCall(
id=tool.call_id,
function=OpenAIFunction(
name=tool.tool_name,
arguments=json.dumps(tool.arguments),
),
type="function",
)
for tool in message.tool_calls
],
)
elif isinstance(message, ToolResponseMessage):
out = OpenAIChatCompletionToolMessage(
role="tool",
tool_call_id=message.call_id,
content=message.content,
)
elif isinstance(message, SystemMessage):
out = OpenAIChatCompletionSystemMessage(
role="system",
content=message.content,
)
else:
raise ValueError(f"Unsupported message type: {type(message)}")
return out
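For instance, a plain dict is coerced to the typed message and then converted (illustrative):

_convert_message({"role": "user", "content": "What is the capital of France?"})
# -> {"role": "user", "content": "What is the capital of France?"}
#    (an OpenAIChatCompletionUserMessage)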
def convert_chat_completion_request(
request: ChatCompletionRequest,
n: int = 1,
) -> dict:
"""
Convert a ChatCompletionRequest to an OpenAI API-compatible dictionary.
"""
# model -> model
# messages -> messages
# sampling_params TODO(mattf): review strategy
# strategy=greedy -> nvext.top_k = -1, temperature = temperature
# strategy=top_p -> nvext.top_k = -1, top_p = top_p
# strategy=top_k -> nvext.top_k = top_k
# temperature -> temperature
# top_p -> top_p
# top_k -> nvext.top_k
# max_tokens -> max_tokens
# repetition_penalty -> nvext.repetition_penalty
# response_format -> GrammarResponseFormat TODO(mf)
# response_format -> JsonSchemaResponseFormat: response_format = "json_object" & nvext["guided_json"] = json_schema
# tools -> tools
# tool_choice ("auto", "required") -> tool_choice
# tool_prompt_format -> TBD
# stream -> stream
# logprobs -> logprobs
if request.response_format and not isinstance(
request.response_format, JsonSchemaResponseFormat
):
raise ValueError(
f"Unsupported response format: {request.response_format}. "
"Only JsonSchemaResponseFormat is supported."
)
nvext = {}
payload: Dict[str, Any] = dict(
model=request.model,
messages=[_convert_message(message) for message in request.messages],
stream=request.stream,
n=n,
extra_body=dict(nvext=nvext),
extra_headers={
b"User-Agent": b"llama-stack: nvidia-inference-adapter",
},
)
if request.response_format:
# server bug - setting guided_json changes the behavior of response_format resulting in an error
# payload.update(response_format="json_object")
nvext.update(guided_json=request.response_format.json_schema)
if request.tools:
payload.update(
tools=[_convert_tooldef_to_openai_tool(tool) for tool in request.tools]
)
if request.tool_choice:
payload.update(
tool_choice=request.tool_choice.value
) # we cannot include tool_choice w/o tools, server will complain
if request.logprobs:
payload.update(logprobs=True)
payload.update(top_logprobs=request.logprobs.top_k)
if request.sampling_params:
nvext.update(repetition_penalty=request.sampling_params.repetition_penalty)
if request.sampling_params.max_tokens:
payload.update(max_tokens=request.sampling_params.max_tokens)
if request.sampling_params.strategy == "top_p":
nvext.update(top_k=-1)
payload.update(top_p=request.sampling_params.top_p)
elif request.sampling_params.strategy == "top_k":
if (
request.sampling_params.top_k != -1
and request.sampling_params.top_k < 1
):
warnings.warn("top_k must be -1 or >= 1")
nvext.update(top_k=request.sampling_params.top_k)
elif request.sampling_params.strategy == "greedy":
nvext.update(top_k=-1)
payload.update(temperature=request.sampling_params.temperature)
return payload
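Schematically, a request whose sampling strategy is top_p with max_tokens set maps to the following payload (other fields elided; values illustrative):

# request.sampling_params = SamplingParams(strategy="top_p", top_p=0.9,
#                                          repetition_penalty=1.0, max_tokens=256)
# payload ~ {
#     "model": ..., "messages": [...], "stream": False, "n": 1,
#     "max_tokens": 256, "top_p": 0.9,
#     "extra_body": {"nvext": {"repetition_penalty": 1.0, "top_k": -1}},
#     "extra_headers": {b"User-Agent": b"llama-stack: nvidia-inference-adapter"},
# }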
def _convert_openai_finish_reason(finish_reason: str) -> StopReason:
"""
Convert an OpenAI chat completion finish_reason to a StopReason.
finish_reason: Literal["stop", "length", "tool_calls", ...]
- stop: model hit a natural stop point or a provided stop sequence
- length: maximum number of tokens specified in the request was reached
- tool_calls: model called a tool
->
class StopReason(Enum):
end_of_turn = "end_of_turn"
end_of_message = "end_of_message"
out_of_tokens = "out_of_tokens"
"""
# TODO(mf): are end_of_turn and end_of_message semantics correct?
return {
"stop": StopReason.end_of_turn,
"length": StopReason.out_of_tokens,
"tool_calls": StopReason.end_of_message,
}.get(finish_reason, StopReason.end_of_turn)
def _convert_openai_tool_calls(
tool_calls: List[OpenAIChatCompletionMessageToolCall],
) -> List[ToolCall]:
"""
Convert an OpenAI ChatCompletionMessageToolCall list into a list of ToolCall.
OpenAI ChatCompletionMessageToolCall:
id: str
function: Function
type: Literal["function"]
OpenAI Function:
arguments: str
name: str
->
ToolCall:
call_id: str
tool_name: str
arguments: Dict[str, ...]
"""
if not tool_calls:
return [] # CompletionMessage tool_calls is not optional
return [
ToolCall(
call_id=call.id,
tool_name=call.function.name,
arguments=json.loads(call.function.arguments),
)
for call in tool_calls
]
def _convert_openai_logprobs(
logprobs: OpenAIChoiceLogprobs,
) -> Optional[List[TokenLogProbs]]:
"""
Convert an OpenAI ChoiceLogprobs into a list of TokenLogProbs.
OpenAI ChoiceLogprobs:
content: Optional[List[ChatCompletionTokenLogprob]]
OpenAI ChatCompletionTokenLogprob:
token: str
logprob: float
top_logprobs: List[TopLogprob]
OpenAI TopLogprob:
token: str
logprob: float
->
TokenLogProbs:
logprobs_by_token: Dict[str, float]
- token, logprob
"""
if not logprobs:
return None
return [
TokenLogProbs(
logprobs_by_token={
top_logprob.token: top_logprob.logprob for top_logprob in content.top_logprobs
}
)
for content in logprobs.content
]
def convert_openai_chat_completion_choice(
choice: OpenAIChoice,
) -> ChatCompletionResponse:
"""
Convert an OpenAI Choice into a ChatCompletionResponse.
OpenAI Choice:
message: ChatCompletionMessage
finish_reason: str
logprobs: Optional[ChoiceLogprobs]
OpenAI ChatCompletionMessage:
role: Literal["assistant"]
content: Optional[str]
tool_calls: Optional[List[ChatCompletionMessageToolCall]]
->
ChatCompletionResponse:
completion_message: CompletionMessage
logprobs: Optional[List[TokenLogProbs]]
CompletionMessage:
role: Literal["assistant"]
content: str | ImageMedia | List[str | ImageMedia]
stop_reason: StopReason
tool_calls: List[ToolCall]
class StopReason(Enum):
end_of_turn = "end_of_turn"
end_of_message = "end_of_message"
out_of_tokens = "out_of_tokens"
"""
assert (
hasattr(choice, "message") and choice.message
), "error in server response: message not found"
assert (
hasattr(choice, "finish_reason") and choice.finish_reason
), "error in server response: finish_reason not found"
return ChatCompletionResponse(
completion_message=CompletionMessage(
content=choice.message.content
or "", # CompletionMessage content is not optional
stop_reason=_convert_openai_finish_reason(choice.finish_reason),
tool_calls=_convert_openai_tool_calls(choice.message.tool_calls),
),
logprobs=_convert_openai_logprobs(choice.logprobs),
)
async def convert_openai_chat_completion_stream(
stream: AsyncStream[OpenAIChatCompletionChunk],
) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
"""
Convert a stream of OpenAI chat completion chunks into a stream
of ChatCompletionResponseStreamChunk.
OpenAI ChatCompletionChunk:
choices: List[Choice]
OpenAI Choice: # different from the non-streamed Choice
delta: ChoiceDelta
finish_reason: Optional[Literal["stop", "length", "tool_calls", "content_filter", "function_call"]]
logprobs: Optional[ChoiceLogprobs]
OpenAI ChoiceDelta:
content: Optional[str]
role: Optional[Literal["system", "user", "assistant", "tool"]]
tool_calls: Optional[List[ChoiceDeltaToolCall]]
OpenAI ChoiceDeltaToolCall:
index: int
id: Optional[str]
function: Optional[ChoiceDeltaToolCallFunction]
type: Optional[Literal["function"]]
OpenAI ChoiceDeltaToolCallFunction:
name: Optional[str]
arguments: Optional[str]
->
ChatCompletionResponseStreamChunk:
event: ChatCompletionResponseEvent
ChatCompletionResponseEvent:
event_type: ChatCompletionResponseEventType
delta: Union[str, ToolCallDelta]
logprobs: Optional[List[TokenLogProbs]]
stop_reason: Optional[StopReason]
ChatCompletionResponseEventType:
start = "start"
progress = "progress"
complete = "complete"
ToolCallDelta:
content: Union[str, ToolCall]
parse_status: ToolCallParseStatus
ToolCall:
call_id: str
tool_name: str
arguments: str
ToolCallParseStatus:
started = "started"
in_progress = "in_progress"
failure = "failure"
success = "success"
TokenLogProbs:
logprobs_by_token: Dict[str, float]
- token, logprob
StopReason:
end_of_turn = "end_of_turn"
end_of_message = "end_of_message"
out_of_tokens = "out_of_tokens"
"""
# generate a stream of ChatCompletionResponseEventType: start -> progress -> progress -> ...
def _event_type_generator() -> (
Generator[ChatCompletionResponseEventType, None, None]
):
yield ChatCompletionResponseEventType.start
while True:
yield ChatCompletionResponseEventType.progress
event_type = _event_type_generator()
# we implement NIM specific semantics, the main difference from OpenAI
# is that tool_calls are always produced as a complete call. there is no
# intermediate / partial tool call streamed. because of this, we can
# simplify the logic and not concern ourselves with parse_status of
# started/in_progress/failed. we can always assume success.
#
# a stream of ChatCompletionResponseStreamChunk consists of
# 0. a start event
# 1. zero or more progress events
# - each progress event has a delta
# - each progress event may have a stop_reason
# - each progress event may have logprobs
# - each progress event may have tool_calls
# if a progress event has tool_calls,
# it is fully formed and
# can be emitted with a parse_status of success
# 2. a complete event
stop_reason = None
async for chunk in stream:
choice = chunk.choices[0] # assuming only one choice per chunk
# we assume there's only one finish_reason in the stream
stop_reason = _convert_openai_finish_reason(choice.finish_reason) if choice.finish_reason else stop_reason
# if there's a tool call, emit an event for each tool in the list
# if tool call and content, emit both separately
if choice.delta.tool_calls:
# the call may have content and a tool call. ChatCompletionResponseEvent
# does not support both, so we emit the content first
if choice.delta.content:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=next(event_type),
delta=choice.delta.content,
logprobs=_convert_openai_logprobs(choice.logprobs),
)
)
# it is possible to have parallel tool calls in stream, but
# ChatCompletionResponseEvent only supports one tool call per event
if len(choice.delta.tool_calls) > 1:
warnings.warn(
"multiple tool calls found in a single delta, using the first, ignoring the rest"
)
# NIM only produces fully formed tool calls, so we can assume success
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=next(event_type),
delta=ToolCallDelta(
content=_convert_openai_tool_calls(choice.delta.tool_calls)[0],
parse_status=ToolCallParseStatus.success,
),
logprobs=_convert_openai_logprobs(choice.logprobs),
)
)
else:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=next(event_type),
delta=choice.delta.content or "", # content is not optional
logprobs=_convert_openai_logprobs(choice.logprobs),
)
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta="",
stop_reason=stop_reason,
)
)
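A hedged consumer sketch for the converted stream (names are illustrative):

async def collect_text(openai_stream) -> str:
    # Accumulate text deltas until the complete event arrives.
    text = ""
    async for chunk in convert_openai_chat_completion_stream(openai_stream):
        event = chunk.event
        if event.event_type == ChatCompletionResponseEventType.complete:
            break
        if isinstance(event.delta, str):
            text += event.delta
    return text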

View file

@@ -0,0 +1,54 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Tuple
import httpx
from . import NVIDIAConfig
def _is_nvidia_hosted(config: NVIDIAConfig) -> bool:
return "integrate.api.nvidia.com" in config.url
async def _get_health(url: str) -> Tuple[bool, bool]:
"""
Query {url}/v1/health/{live,ready} to check if the server is running and ready
Args:
url (str): URL of the server
Returns:
Tuple[bool, bool]: (is_live, is_ready)
"""
async with httpx.AsyncClient() as client:
live = await client.get(f"{url}/v1/health/live")
ready = await client.get(f"{url}/v1/health/ready")
return live.status_code == 200, ready.status_code == 200
async def check_health(config: NVIDIAConfig) -> None:
"""
Check if the server is running and ready
Args:
config (NVIDIAConfig): The configuration of the NVIDIA NIM to check
Raises:
ConnectionError: If the server is not running or ready
"""
if not _is_nvidia_hosted(config):
print("Checking NVIDIA NIM health...")
try:
is_live, is_ready = await _get_health(config.url)
if not is_live:
raise ConnectionError("NVIDIA NIM is not running")
if not is_ready:
raise ConnectionError("NVIDIA NIM is not ready")
# TODO(mf): should we wait for the server to be ready?
except httpx.ConnectError as e:
raise ConnectionError(f"Failed to connect to NVIDIA NIM: {e}") from e
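A usage sketch mirroring the inference adapter above, which awaits this check before each request:

async def ensure_ready(config: NVIDIAConfig) -> None:
    # Fail fast with a clearer error before issuing inference calls.
    try:
        await check_health(config)
    except ConnectionError as e:
        raise RuntimeError(f"NVIDIA NIM unavailable: {e}") from e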

View file

@@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from typing import AsyncGenerator
import httpx
@@ -39,6 +40,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
request_has_media,
)
log = logging.getLogger(__name__)
model_aliases = [
build_model_alias(
@@ -57,18 +59,26 @@ model_aliases = [
"llama3.1:70b",
CoreModelId.llama3_1_70b_instruct.value,
),
build_model_alias(
"llama3.1:405b-instruct-fp16",
CoreModelId.llama3_1_405b_instruct.value,
),
build_model_alias_with_just_provider_model_id(
"llama3.1:405b",
CoreModelId.llama3_1_405b_instruct.value,
),
build_model_alias(
"llama3.2:1b-instruct-fp16",
CoreModelId.llama3_2_1b_instruct.value,
),
build_model_alias_with_just_provider_model_id(
"llama3.2:1b",
CoreModelId.llama3_2_1b_instruct.value,
),
build_model_alias(
"llama3.2:3b-instruct-fp16",
CoreModelId.llama3_2_3b_instruct.value,
),
build_model_alias_with_just_provider_model_id(
"llama3.2:1b",
CoreModelId.llama3_2_1b_instruct.value,
),
build_model_alias_with_just_provider_model_id(
"llama3.2:3b",
CoreModelId.llama3_2_3b_instruct.value,
@@ -81,6 +91,14 @@ model_aliases = [
"llama3.2-vision",
CoreModelId.llama3_2_11b_vision_instruct.value,
),
build_model_alias(
"llama3.2-vision:90b-instruct-fp16",
CoreModelId.llama3_2_90b_vision_instruct.value,
),
build_model_alias_with_just_provider_model_id(
"llama3.2-vision:90b",
CoreModelId.llama3_2_90b_vision_instruct.value,
),
# The Llama Guard models don't have their full fp16 versions
# so we are going to alias their default version to the canonical SKU
build_model_alias(
@@ -105,7 +123,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
return AsyncClient(host=self.url)
async def initialize(self) -> None:
print(f"checking connectivity to Ollama at `{self.url}`...")
log.info(f"checking connectivity to Ollama at `{self.url}`...")
try:
await self.client.ps()
except httpx.ConnectError as e:

View file

@@ -37,6 +37,18 @@ class InferenceEndpointImplConfig(BaseModel):
description="Your Hugging Face user access token (will default to locally saved token if not provided)",
)
@classmethod
def sample_run_config(
cls,
endpoint_name: str = "${env.INFERENCE_ENDPOINT_NAME}",
api_token: str = "${env.HF_API_TOKEN}",
**kwargs,
):
return {
"endpoint_name": endpoint_name,
"api_token": api_token,
}
@json_schema_type
class InferenceAPIImplConfig(BaseModel):
@@ -47,3 +59,15 @@ class InferenceAPIImplConfig(BaseModel):
default=None,
description="Your Hugging Face user access token (will default to locally saved token if not provided)",
)
@classmethod
def sample_run_config(
cls,
repo: str = "${env.INFERENCE_MODEL}",
api_token: str = "${env.HF_API_TOKEN}",
**kwargs,
):
return {
"huggingface_repo": repo,
"api_token": api_token,
}

View file

@@ -17,6 +17,10 @@ from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.models import * # noqa: F403
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
from llama_stack.providers.utils.inference.model_registry import (
build_model_alias,
ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_options,
@@ -34,7 +38,18 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
logger = logging.getLogger(__name__)
log = logging.getLogger(__name__)
def build_model_aliases():
return [
build_model_alias(
model.huggingface_repo,
model.descriptor(),
)
for model in all_registered_models()
if model.huggingface_repo
]
class _HfAdapter(Inference, ModelsProtocolPrivate):
@@ -44,45 +59,39 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
def __init__(self) -> None:
self.formatter = ChatFormat(Tokenizer.get_instance())
self.register_helper = ModelRegistryHelper(build_model_aliases())
self.huggingface_repo_to_llama_model_id = {
model.huggingface_repo: model.descriptor()
for model in all_registered_models()
if model.huggingface_repo
}
async def register_model(self, model: Model) -> None:
pass
async def list_models(self) -> List[Model]:
repo = self.model_id
identifier = self.huggingface_repo_to_llama_model_id[repo]
return [
Model(
identifier=identifier,
llama_model=identifier,
metadata={
"huggingface_repo": repo,
},
)
]
async def shutdown(self) -> None:
pass
async def register_model(self, model: Model) -> None:
model = await self.register_helper.register_model(model)
if model.provider_resource_id != self.model_id:
raise ValueError(
f"Model {model.provider_resource_id} does not match the model {self.model_id} served by TGI."
)
return model
async def unregister_model(self, model_id: str) -> None:
pass
async def completion(
self,
model: str,
model_id: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
model = await self.model_store.get_model(model_id)
request = CompletionRequest(
model=model,
model=model.provider_resource_id,
content=content,
sampling_params=sampling_params,
response_format=response_format,
@@ -176,7 +185,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
async def chat_completion(
self,
model: str,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
tools: Optional[List[ToolDefinition]] = None,
@@ -186,8 +195,9 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest(
model=model,
model=model.provider_resource_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
@@ -241,7 +251,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
def _get_params(self, request: ChatCompletionRequest) -> dict:
prompt, input_tokens = chat_completion_request_to_model_input_info(
request, self.formatter
request, self.register_helper.get_llama_model(request.model), self.formatter
)
return dict(
prompt=prompt,
@@ -256,7 +266,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
async def embeddings(
self,
model: str,
model_id: str,
contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse:
raise NotImplementedError()
@@ -264,7 +274,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
class TGIAdapter(_HfAdapter):
async def initialize(self, config: TGIImplConfig) -> None:
print(f"Initializing TGI client with url={config.url}")
log.info(f"Initializing TGI client with url={config.url}")
self.client = AsyncInferenceClient(model=config.url, token=config.api_token)
endpoint_info = await self.client.get_endpoint_info()
self.max_tokens = endpoint_info["max_total_tokens"]

View file

@@ -3,6 +3,8 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from typing import AsyncGenerator
from llama_models.llama3.api.chat_format import ChatFormat
@@ -34,6 +36,9 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import VLLMInferenceAdapterConfig
log = logging.getLogger(__name__)
def build_model_aliases():
return [
build_model_alias(
@@ -53,7 +58,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
self.client = None
async def initialize(self) -> None:
print(f"Initializing VLLM client with base_url={self.config.url}")
log.info(f"Initializing VLLM client with base_url={self.config.url}")
self.client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
async def shutdown(self) -> None:

View file

@@ -5,6 +5,7 @@
# the root directory of this source tree.
import json
import logging
from typing import List
from urllib.parse import urlparse
@@ -21,6 +22,8 @@ from llama_stack.providers.utils.memory.vector_store import (
EmbeddingIndex,
)
log = logging.getLogger(__name__)
class ChromaIndex(EmbeddingIndex):
def __init__(self, client: chromadb.AsyncHttpClient, collection):
@@ -56,10 +59,7 @@ class ChromaIndex(EmbeddingIndex):
doc = json.loads(doc)
chunk = Chunk(**doc)
except Exception:
import traceback
traceback.print_exc()
print(f"Failed to parse document: {doc}")
log.exception(f"Failed to parse document: {doc}")
continue
chunks.append(chunk)
@@ -73,7 +73,7 @@
class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
def __init__(self, url: str) -> None:
print(f"Initializing ChromaMemoryAdapter with url: {url}")
log.info(f"Initializing ChromaMemoryAdapter with url: {url}")
url = url.rstrip("/")
parsed = urlparse(url)
@@ -88,12 +88,10 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
async def initialize(self) -> None:
try:
print(f"Connecting to Chroma server at: {self.host}:{self.port}")
log.info(f"Connecting to Chroma server at: {self.host}:{self.port}")
self.client = await chromadb.AsyncHttpClient(host=self.host, port=self.port)
except Exception as e:
import traceback
traceback.print_exc()
log.exception("Could not connect to Chroma server")
raise RuntimeError("Could not connect to Chroma server") from e
async def shutdown(self) -> None:
@@ -109,7 +107,7 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
collection = await self.client.get_or_create_collection(
name=memory_bank.identifier,
metadata={"bank": memory_bank.json()},
metadata={"bank": memory_bank.model_dump_json()},
)
bank_index = BankWithIndex(
bank=memory_bank, index=ChromaIndex(self.client, collection)
@@ -123,10 +121,7 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
data = json.loads(collection.metadata["bank"])
bank = parse_obj_as(VectorMemoryBank, data)
except Exception:
import traceback
traceback.print_exc()
print(f"Failed to parse bank: {collection.metadata}")
log.exception(f"Failed to parse bank: {collection.metadata}")
continue
index = BankWithIndex(
@@ -147,9 +142,7 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
documents: List[MemoryBankDocument],
ttl_seconds: Optional[int] = None,
) -> None:
index = self.cache.get(bank_id, None)
if not index:
raise ValueError(f"Bank {bank_id} not found")
index = await self._get_and_cache_bank_index(bank_id)
await index.insert_documents(documents)
@@ -159,8 +152,20 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
query: InterleavedTextMedia,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse:
index = self.cache.get(bank_id, None)
if not index:
raise ValueError(f"Bank {bank_id} not found")
index = await self._get_and_cache_bank_index(bank_id)
return await index.query_documents(query, params)
async def _get_and_cache_bank_index(self, bank_id: str) -> BankWithIndex:
if bank_id in self.cache:
return self.cache[bank_id]
bank = await self.memory_bank_store.get_memory_bank(bank_id)
if not bank:
raise ValueError(f"Bank {bank_id} not found in Llama Stack")
collection = await self.client.get_collection(bank_id)
if not collection:
raise ValueError(f"Bank {bank_id} not found in Chroma")
index = BankWithIndex(bank=bank, index=ChromaIndex(self.client, collection))
self.cache[bank_id] = index
return index

View file

@@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from typing import List, Tuple
import psycopg2
@@ -24,6 +25,8 @@ from llama_stack.providers.utils.memory.vector_store import (
from .config import PGVectorConfig
log = logging.getLogger(__name__)
def check_extension_version(cur):
cur.execute("SELECT extversion FROM pg_extension WHERE extname = 'vector'")
@@ -124,7 +127,7 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
self.cache = {}
async def initialize(self) -> None:
print(f"Initializing PGVector memory adapter with config: {self.config}")
log.info(f"Initializing PGVector memory adapter with config: {self.config}")
try:
self.conn = psycopg2.connect(
host=self.config.host,
@@ -138,7 +141,7 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
version = check_extension_version(self.cursor)
if version:
print(f"Vector extension version: {version}")
log.info(f"Vector extension version: {version}")
else:
raise RuntimeError("Vector extension is not installed.")
@@ -151,9 +154,7 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
"""
)
except Exception as e:
import traceback
traceback.print_exc()
log.exception("Could not connect to PGVector database server")
raise RuntimeError("Could not connect to PGVector database server") from e
async def shutdown(self) -> None:
@@ -201,10 +202,7 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
documents: List[MemoryBankDocument],
ttl_seconds: Optional[int] = None,
) -> None:
index = self.cache.get(bank_id, None)
if not index:
raise ValueError(f"Bank {bank_id} not found")
index = await self._get_and_cache_bank_index(bank_id)
await index.insert_documents(documents)
async def query_documents(
@@ -213,8 +211,17 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
query: InterleavedTextMedia,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse:
index = self.cache.get(bank_id, None)
if not index:
raise ValueError(f"Bank {bank_id} not found")
index = await self._get_and_cache_bank_index(bank_id)
return await index.query_documents(query, params)
async def _get_and_cache_bank_index(self, bank_id: str) -> BankWithIndex:
if bank_id in self.cache:
return self.cache[bank_id]
bank = await self.memory_bank_store.get_memory_bank(bank_id)
index = BankWithIndex(
bank=bank,
index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
)
self.cache[bank_id] = index
return index

View file

@@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import traceback
import logging
import uuid
from typing import Any, Dict, List
@@ -23,6 +23,7 @@ from llama_stack.providers.utils.memory.vector_store import (
EmbeddingIndex,
)
log = logging.getLogger(__name__)
CHUNK_ID_KEY = "_chunk_id"
@@ -90,7 +91,7 @@ class QdrantIndex(EmbeddingIndex):
try:
chunk = Chunk(**point.payload["chunk_content"])
except Exception:
traceback.print_exc()
log.exception("Failed to parse chunk")
continue
chunks.append(chunk)

View file

@@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import logging
from typing import Any, Dict, List, Optional
@@ -22,6 +23,8 @@ from llama_stack.providers.utils.memory.vector_store import (
from .config import WeaviateConfig, WeaviateRequestProviderData
log = logging.getLogger(__name__)
class WeaviateIndex(EmbeddingIndex):
def __init__(self, client: weaviate.Client, collection_name: str):
@@ -69,10 +72,7 @@ class WeaviateIndex(EmbeddingIndex):
chunk_dict = json.loads(chunk_json)
chunk = Chunk(**chunk_dict)
except Exception:
import traceback
traceback.print_exc()
print(f"Failed to parse document: {chunk_json}")
log.exception(f"Failed to parse document: {chunk_json}")
continue
chunks.append(chunk)

View file

@@ -4,9 +4,24 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
from typing import Any, Dict
from pydantic import BaseModel, Field
class OpenTelemetryConfig(BaseModel):
jaeger_host: str = "localhost"
jaeger_port: int = 6831
otel_endpoint: str = Field(
default="http://localhost:4318/v1/traces",
description="The OpenTelemetry collector endpoint URL",
)
service_name: str = Field(
default="llama-stack",
description="The service name to use for telemetry",
)
@classmethod
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
return {
"otel_endpoint": "${env.OTEL_ENDPOINT:http://localhost:4318/v1/traces}",
"service_name": "${env.OTEL_SERVICE_NAME:llama-stack}",
}
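The ${env.NAME:default} placeholders are substituted from the environment when the run config is loaded; a minimal resolver sketch, assuming it mirrors the stack's substitution behavior:

import os
import re

def resolve_env_placeholders(value: str) -> str:
    # Replace ${env.NAME:default} with the env value, else the default.
    def repl(m: re.Match) -> str:
        name, default = m.group(1), m.group(2)
        return os.environ.get(name, default or "")
    return re.sub(r"\$\{env\.([A-Za-z0-9_]+)(?::([^}]*))?\}", repl, value)

resolve_env_placeholders("${env.OTEL_ENDPOINT:http://localhost:4318/v1/traces}")
# -> the OTEL_ENDPOINT env value, or the default URL above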

View file

@@ -4,24 +4,31 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from datetime import datetime
import threading
from opentelemetry import metrics, trace
from opentelemetry.exporter.jaeger.thrift import JaegerExporter
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.sdk.metrics.export import (
ConsoleMetricExporter,
PeriodicExportingMetricReader,
)
from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.semconv.resource import ResourceAttributes
from llama_stack.apis.telemetry import * # noqa: F403
from .config import OpenTelemetryConfig
_GLOBAL_STORAGE = {
"active_spans": {},
"counters": {},
"gauges": {},
"up_down_counters": {},
}
_global_lock = threading.Lock()
def string_to_trace_id(s: str) -> int:
# Convert the string to bytes and then to an integer
@@ -42,33 +49,37 @@ class OpenTelemetryAdapter(Telemetry):
def __init__(self, config: OpenTelemetryConfig):
self.config = config
self.resource = Resource.create(
{ResourceAttributes.SERVICE_NAME: "foobar-service"}
resource = Resource.create(
{
ResourceAttributes.SERVICE_NAME: self.config.service_name,
}
)
# Set up tracing with Jaeger exporter
jaeger_exporter = JaegerExporter(
agent_host_name=self.config.jaeger_host,
agent_port=self.config.jaeger_port,
provider = TracerProvider(resource=resource)
trace.set_tracer_provider(provider)
otlp_exporter = OTLPSpanExporter(
endpoint=self.config.otel_endpoint,
)
trace_provider = TracerProvider(resource=self.resource)
trace_processor = BatchSpanProcessor(jaeger_exporter)
trace_provider.add_span_processor(trace_processor)
trace.set_tracer_provider(trace_provider)
self.tracer = trace.get_tracer(__name__)
span_processor = BatchSpanProcessor(otlp_exporter)
trace.get_tracer_provider().add_span_processor(span_processor)
# Set up metrics
metric_reader = PeriodicExportingMetricReader(ConsoleMetricExporter())
metric_reader = PeriodicExportingMetricReader(
OTLPMetricExporter(
endpoint=self.config.otel_endpoint,
)
)
metric_provider = MeterProvider(
resource=self.resource, metric_readers=[metric_reader]
resource=resource, metric_readers=[metric_reader]
)
metrics.set_meter_provider(metric_provider)
self.meter = metrics.get_meter(__name__)
self._lock = _global_lock
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
trace.get_tracer_provider().force_flush()
trace.get_tracer_provider().shutdown()
metrics.get_meter_provider().shutdown()
@@ -81,121 +92,117 @@
self._log_structured(event)
def _log_unstructured(self, event: UnstructuredLogEvent) -> None:
span = trace.get_current_span()
span.add_event(
name=event.message,
attributes={"severity": event.severity.value, **event.attributes},
timestamp=event.timestamp,
)
with self._lock:
# Use global storage instead of instance storage
span_id = string_to_span_id(event.span_id)
span = _GLOBAL_STORAGE["active_spans"].get(span_id)
if span:
timestamp_ns = int(event.timestamp.timestamp() * 1e9)
span.add_event(
name=event.type,
attributes={
"message": event.message,
"severity": event.severity.value,
**event.attributes,
},
timestamp=timestamp_ns,
)
else:
print(
f"Warning: No active span found for span_id {span_id}. Dropping event: {event}"
)
def _get_or_create_counter(self, name: str, unit: str) -> metrics.Counter:
if name not in _GLOBAL_STORAGE["counters"]:
_GLOBAL_STORAGE["counters"][name] = self.meter.create_counter(
name=name,
unit=unit,
description=f"Counter for {name}",
)
return _GLOBAL_STORAGE["counters"][name]
def _get_or_create_gauge(self, name: str, unit: str) -> metrics.ObservableGauge:
if name not in _GLOBAL_STORAGE["gauges"]:
_GLOBAL_STORAGE["gauges"][name] = self.meter.create_gauge(
name=name,
unit=unit,
description=f"Gauge for {name}",
)
return _GLOBAL_STORAGE["gauges"][name]
def _log_metric(self, event: MetricEvent) -> None:
if isinstance(event.value, int):
self.meter.create_counter(
name=event.metric,
unit=event.unit,
description=f"Counter for {event.metric}",
).add(event.value, attributes=event.attributes)
counter = self._get_or_create_counter(event.metric, event.unit)
counter.add(event.value, attributes=event.attributes)
elif isinstance(event.value, float):
self.meter.create_gauge(
name=event.metric,
unit=event.unit,
description=f"Gauge for {event.metric}",
).set(event.value, attributes=event.attributes)
up_down_counter = self._get_or_create_up_down_counter(
event.metric, event.unit
)
up_down_counter.add(event.value, attributes=event.attributes)
def _get_or_create_up_down_counter(
self, name: str, unit: str
) -> metrics.UpDownCounter:
if name not in _GLOBAL_STORAGE["up_down_counters"]:
_GLOBAL_STORAGE["up_down_counters"][name] = (
self.meter.create_up_down_counter(
name=name,
unit=unit,
description=f"UpDownCounter for {name}",
)
)
return _GLOBAL_STORAGE["up_down_counters"][name]
def _log_structured(self, event: StructuredLogEvent) -> None:
if isinstance(event.payload, SpanStartPayload):
context = trace.set_span_in_context(
trace.NonRecordingSpan(
trace.SpanContext(
trace_id=string_to_trace_id(event.trace_id),
span_id=string_to_span_id(event.span_id),
is_remote=True,
)
)
)
span = self.tracer.start_span(
name=event.payload.name,
kind=trace.SpanKind.INTERNAL,
context=context,
attributes=event.attributes,
)
with self._lock:
span_id = string_to_span_id(event.span_id)
trace_id = string_to_trace_id(event.trace_id)
tracer = trace.get_tracer(__name__)
if event.payload.parent_span_id:
span.set_parent(
trace.SpanContext(
trace_id=string_to_trace_id(event.trace_id),
span_id=string_to_span_id(event.payload.parent_span_id),
is_remote=True,
if isinstance(event.payload, SpanStartPayload):
# Check if span already exists to prevent duplicates
if span_id in _GLOBAL_STORAGE["active_spans"]:
return
parent_span = None
if event.payload.parent_span_id:
parent_span_id = string_to_span_id(event.payload.parent_span_id)
parent_span = _GLOBAL_STORAGE["active_spans"].get(parent_span_id)
# Create a new trace context with the trace_id
context = trace.Context(trace_id=trace_id)
if parent_span:
context = trace.set_span_in_context(parent_span, context)
span = tracer.start_span(
name=event.payload.name,
context=context,
attributes=event.attributes or {},
start_time=int(event.timestamp.timestamp() * 1e9),
)
_GLOBAL_STORAGE["active_spans"][span_id] = span
# Set as current span using context manager
with trace.use_span(span, end_on_exit=False):
pass # Let the span continue beyond this block
elif isinstance(event.payload, SpanEndPayload):
span = _GLOBAL_STORAGE["active_spans"].get(span_id)
if span:
if event.attributes:
span.set_attributes(event.attributes)
status = (
trace.Status(status_code=trace.StatusCode.OK)
if event.payload.status == SpanStatus.OK
else trace.Status(status_code=trace.StatusCode.ERROR)
)
)
elif isinstance(event.payload, SpanEndPayload):
span = trace.get_current_span()
span.set_status(
trace.Status(
trace.StatusCode.OK
if event.payload.status == SpanStatus.OK
else trace.StatusCode.ERROR
)
)
span.end(end_time=event.timestamp)
span.set_status(status)
span.end(end_time=int(event.timestamp.timestamp() * 1e9))
# Remove from active spans
_GLOBAL_STORAGE["active_spans"].pop(span_id, None)
async def get_trace(self, trace_id: str) -> Trace:
# we need to look up the root span id
raise NotImplementedError("not yet no")
# Usage example
async def main():
telemetry = OpenTelemetryTelemetry("my-service")
await telemetry.initialize()
# Log an unstructured event
await telemetry.log_event(
UnstructuredLogEvent(
trace_id="trace123",
span_id="span456",
timestamp=datetime.now(),
message="This is a log message",
severity=LogSeverity.INFO,
)
)
# Log a metric event
await telemetry.log_event(
MetricEvent(
trace_id="trace123",
span_id="span456",
timestamp=datetime.now(),
metric="my_metric",
value=42,
unit="count",
)
)
# Log a structured event (span start)
await telemetry.log_event(
StructuredLogEvent(
trace_id="trace123",
span_id="span789",
timestamp=datetime.now(),
payload=SpanStartPayload(name="my_operation"),
)
)
# Log a structured event (span end)
await telemetry.log_event(
StructuredLogEvent(
trace_id="trace123",
span_id="span789",
timestamp=datetime.now(),
payload=SpanEndPayload(status=SpanStatus.OK),
)
)
await telemetry.shutdown()
if __name__ == "__main__":
import asyncio
asyncio.run(main())
raise NotImplementedError("Trace retrieval not implemented yet")

View file

@@ -6,10 +6,14 @@
import pytest
from ..agents.fixtures import AGENTS_FIXTURES
from ..conftest import get_provider_fixture_overrides
from ..datasetio.fixtures import DATASETIO_FIXTURES
from ..inference.fixtures import INFERENCE_FIXTURES
from ..memory.fixtures import MEMORY_FIXTURES
from ..safety.fixtures import SAFETY_FIXTURES
from ..scoring.fixtures import SCORING_FIXTURES
from .fixtures import EVAL_FIXTURES
@@ -20,6 +24,9 @@ DEFAULT_PROVIDER_COMBINATIONS = [
"scoring": "basic",
"datasetio": "localfs",
"inference": "fireworks",
"agents": "meta_reference",
"safety": "llama_guard",
"memory": "faiss",
},
id="meta_reference_eval_fireworks_inference",
marks=pytest.mark.meta_reference_eval_fireworks_inference,
@@ -30,6 +37,9 @@ DEFAULT_PROVIDER_COMBINATIONS = [
"scoring": "basic",
"datasetio": "localfs",
"inference": "together",
"agents": "meta_reference",
"safety": "llama_guard",
"memory": "faiss",
},
id="meta_reference_eval_together_inference",
marks=pytest.mark.meta_reference_eval_together_inference,
@@ -40,6 +50,9 @@ DEFAULT_PROVIDER_COMBINATIONS = [
"scoring": "basic",
"datasetio": "huggingface",
"inference": "together",
"agents": "meta_reference",
"safety": "llama_guard",
"memory": "faiss",
},
id="meta_reference_eval_together_inference_huggingface_datasetio",
marks=pytest.mark.meta_reference_eval_together_inference_huggingface_datasetio,
@@ -75,6 +88,9 @@ def pytest_generate_tests(metafunc):
"scoring": SCORING_FIXTURES,
"datasetio": DATASETIO_FIXTURES,
"inference": INFERENCE_FIXTURES,
"agents": AGENTS_FIXTURES,
"safety": SAFETY_FIXTURES,
"memory": MEMORY_FIXTURES,
}
combinations = (
get_provider_fixture_overrides(metafunc.config, available_fixtures)

View file

@@ -40,14 +40,30 @@ async def eval_stack(request):
providers = {}
provider_data = {}
for key in ["datasetio", "eval", "scoring", "inference"]:
for key in [
"datasetio",
"eval",
"scoring",
"inference",
"agents",
"safety",
"memory",
]:
fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
providers[key] = fixture.providers
if fixture.provider_data:
provider_data.update(fixture.provider_data)
test_stack = await construct_stack_for_test(
[Api.eval, Api.datasetio, Api.inference, Api.scoring],
[
Api.eval,
Api.datasetio,
Api.inference,
Api.scoring,
Api.agents,
Api.safety,
Api.memory,
],
providers,
provider_data,
)

View file

@@ -6,6 +6,8 @@
import pytest
from ..conftest import get_provider_fixture_overrides
from .fixtures import INFERENCE_FIXTURES
@@ -67,11 +69,12 @@ def pytest_generate_tests(metafunc):
indirect=True,
)
if "inference_stack" in metafunc.fixturenames:
metafunc.parametrize(
"inference_stack",
[
pytest.param(fixture_name, marks=getattr(pytest.mark, fixture_name))
for fixture_name in INFERENCE_FIXTURES
],
indirect=True,
)
fixtures = INFERENCE_FIXTURES
if filtered_stacks := get_provider_fixture_overrides(
metafunc.config,
{
"inference": INFERENCE_FIXTURES,
},
):
fixtures = [stack.values[0]["inference"] for stack in filtered_stacks]
metafunc.parametrize("inference_stack", fixtures, indirect=True)

View file

@@ -18,7 +18,9 @@ from llama_stack.providers.inline.inference.meta_reference import (
from llama_stack.providers.remote.inference.bedrock import BedrockConfig
from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig
from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
from llama_stack.providers.remote.inference.tgi import TGIImplConfig
from llama_stack.providers.remote.inference.together import TogetherImplConfig
from llama_stack.providers.remote.inference.vllm import VLLMInferenceAdapterConfig
from llama_stack.providers.tests.resolver import construct_stack_for_test
@@ -142,6 +144,35 @@ def inference_bedrock() -> ProviderFixture:
)
@pytest.fixture(scope="session")
def inference_nvidia() -> ProviderFixture:
    return ProviderFixture(
        providers=[
            Provider(
                provider_id="nvidia",
                provider_type="remote::nvidia",
                config=NVIDIAConfig().model_dump(),
            )
        ],
    )


@pytest.fixture(scope="session")
def inference_tgi() -> ProviderFixture:
    return ProviderFixture(
        providers=[
            Provider(
                provider_id="tgi",
                provider_type="remote::tgi",
                config=TGIImplConfig(
                    url=get_env_or_fail("TGI_URL"),
                    api_token=os.getenv("TGI_API_TOKEN", None),
                ).model_dump(),
            )
        ],
    )
def get_model_short_name(model_name: str) -> str:
"""Convert model name to a short test identifier.
@@ -175,6 +206,8 @@ INFERENCE_FIXTURES = [
"vllm_remote",
"remote",
"bedrock",
"nvidia",
"tgi",
]
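The TGI fixture above reads its endpoint from the environment via `get_env_or_fail` (imported from `..env` elsewhere in this diff). The helper itself is not shown; a minimal sketch of its assumed behavior:

```python
import os


# assumed behavior of the get_env_or_fail helper used by the TGI fixture
def get_env_or_fail(key: str) -> str:
    value = os.environ.get(key)
    if not value:
        raise ValueError(f"Environment variable {key} must be set to run this fixture")
    return value
```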

View file

@@ -11,7 +11,6 @@ import pytest
#
# pytest -v -s llama_stack/providers/tests/inference/test_model_registration.py
# -m "meta_reference"
# --env TOGETHER_API_KEY=<your_api_key>
class TestModelRegistration:

View file

@@ -89,7 +89,7 @@ class TestInference:
provider = inference_impl.routing_table.get_provider_impl(inference_model)
if provider.__provider_spec__.provider_type not in (
"meta-reference",
"inline::meta-reference",
"remote::ollama",
"remote::tgi",
"remote::together",
@@ -135,7 +135,7 @@ class TestInference:
provider = inference_impl.routing_table.get_provider_impl(inference_model)
if provider.__provider_spec__.provider_type not in (
"meta-reference",
"inline::meta-reference",
"remote::tgi",
"remote::together",
"remote::fireworks",
@@ -194,10 +194,11 @@ class TestInference:
provider = inference_impl.routing_table.get_provider_impl(inference_model)
if provider.__provider_spec__.provider_type not in (
"meta-reference",
"inline::meta-reference",
"remote::fireworks",
"remote::tgi",
"remote::together",
"remote::nvidia",
):
pytest.skip("Other inference providers don't support structured output yet")
@@ -210,7 +211,15 @@ class TestInference:
response = await inference_impl.chat_completion(
model_id=inference_model,
messages=[
SystemMessage(content="You are a helpful assistant."),
# we include context about Michael Jordan in the prompt so that the test is
# focused on the functionality of the model and not on the information embedded
# in the model. Llama 3.2 3B Instruct tends to think MJ played for 14 seasons.
SystemMessage(
content=(
"You are a helpful assistant.\n\n"
"Michael Jordan was born in 1963. He played basketball for the Chicago Bulls for 15 seasons."
)
),
UserMessage(content="Please give me information about Michael Jordan."),
],
stream=False,
@@ -361,7 +370,10 @@ class TestInference:
for chunk in grouped[ChatCompletionResponseEventType.progress]
)
first = grouped[ChatCompletionResponseEventType.progress][0]
assert first.event.delta.parse_status == ToolCallParseStatus.started
if not isinstance(
first.event.delta.content, ToolCall
): # first chunk may contain entire call
assert first.event.delta.parse_status == ToolCallParseStatus.started
last = grouped[ChatCompletionResponseEventType.progress][-1]
# assert last.event.stop_reason == expected_stop_reason

View file

@@ -44,7 +44,7 @@ class TestVisionModelInference:
provider = inference_impl.routing_table.get_provider_impl(inference_model)
if provider.__provider_spec__.provider_type not in (
"meta-reference",
"inline::meta-reference",
"remote::together",
"remote::fireworks",
"remote::ollama",
@@ -78,7 +78,7 @@ class TestVisionModelInference:
provider = inference_impl.routing_table.get_provider_impl(inference_model)
if provider.__provider_spec__.provider_type not in (
"meta-reference",
"inline::meta-reference",
"remote::together",
"remote::fireworks",
"remote::ollama",

View file

@@ -10,9 +10,10 @@ import pytest_asyncio
from llama_stack.apis.models import ModelInput
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.inline.scoring.braintrust import BraintrustScoringConfig
from llama_stack.providers.tests.resolver import construct_stack_for_test
from ..conftest import ProviderFixture, remote_stack_fixture
from ..env import get_env_or_fail
@pytest.fixture(scope="session")
@@ -40,7 +41,9 @@ def scoring_braintrust() -> ProviderFixture:
Provider(
provider_id="braintrust",
provider_type="inline::braintrust",
config={},
config=BraintrustScoringConfig(
openai_api_key=get_env_or_fail("OPENAI_API_KEY"),
).model_dump(),
)
],
)

View file

@@ -5,11 +5,9 @@
# the root directory of this source tree.
from typing import Optional
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
@json_schema_type
class BedrockBaseConfig(BaseModel):
aws_access_key_id: Optional[str] = Field(
default=None,
@@ -57,3 +55,7 @@ class BedrockBaseConfig(BaseModel):
default=3600,
description="The time in seconds till a session expires. The default is 3600 seconds (1 hour).",
)
@classmethod
def sample_run_config(cls, **kwargs):
return {}

View file

@@ -22,9 +22,9 @@ def is_supported_safety_model(model: Model) -> bool:
]
def supported_inference_models() -> List[str]:
def supported_inference_models() -> List[Model]:
return [
m.descriptor()
m
for m in all_registered_models()
if (
m.model_family in {ModelFamily.llama3_1, ModelFamily.llama3_2}

View file

@@ -7,14 +7,13 @@
import base64
import io
import json
import logging
from typing import Tuple
import httpx
from llama_models.llama3.api.chat_format import ChatFormat
from PIL import Image as PIL_Image
from termcolor import cprint
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from llama_models.datatypes import ModelFamily
@@ -29,6 +28,8 @@ from llama_models.sku_list import resolve_model
from llama_stack.providers.utils.inference import supported_inference_models
log = logging.getLogger(__name__)
def content_has_media(content: InterleavedTextMedia):
def _has_media_content(c):
@@ -175,11 +176,13 @@ def chat_completion_request_to_messages(
"""
model = resolve_model(llama_model)
if model is None:
cprint(f"Could not resolve model {llama_model}", color="red")
log.error(f"Could not resolve model {llama_model}")
return request.messages
if model.descriptor() not in supported_inference_models():
cprint(f"Unsupported inference model? {model.descriptor()}", color="red")
allowed_models = supported_inference_models()
descriptors = [m.descriptor() for m in allowed_models]
if model.descriptor() not in descriptors:
log.error(f"Unsupported inference model? {model.descriptor()}")
return request.messages
if model.model_family == ModelFamily.llama3_1 or (

View file

@@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from datetime import datetime
from typing import List, Optional
@@ -13,6 +14,8 @@ from psycopg2.extras import DictCursor
from ..api import KVStore
from ..config import PostgresKVStoreConfig
log = logging.getLogger(__name__)
class PostgresKVStoreImpl(KVStore):
def __init__(self, config: PostgresKVStoreConfig):
@@ -43,9 +46,8 @@ class PostgresKVStoreImpl(KVStore):
"""
)
except Exception as e:
import traceback
traceback.print_exc()
log.exception("Could not connect to PostgreSQL database server")
raise RuntimeError("Could not connect to PostgreSQL database server") from e
def _namespaced_key(self, key: str) -> str:

View file

@@ -5,6 +5,7 @@
# the root directory of this source tree.
import base64
import io
import logging
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass
@@ -16,13 +17,14 @@ import httpx
import numpy as np
from numpy.typing import NDArray
from pypdf import PdfReader
from termcolor import cprint
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_stack.apis.memory import * # noqa: F403
log = logging.getLogger(__name__)
ALL_MINILM_L6_V2_DIMENSION = 384
EMBEDDING_MODELS = {}
@@ -35,7 +37,7 @@ def get_embedding_model(model: str) -> "SentenceTransformer":
if loaded_model is not None:
return loaded_model
print(f"Loading sentence transformer for {model}...")
log.info(f"Loading sentence transformer for {model}...")
from sentence_transformers import SentenceTransformer
loaded_model = SentenceTransformer(model)
@@ -92,7 +94,7 @@ def content_from_data(data_url: str) -> str:
return "\n".join([page.extract_text() for page in pdf_reader.pages])
else:
cprint("Could not extract content from data_url properly.", color="red")
log.error("Could not extract content from data_url properly.")
return ""

View file

@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@@ -17,8 +17,10 @@ from typing import Any, Callable, Dict, List
from llama_stack.apis.telemetry import * # noqa: F403
log = logging.getLogger(__name__)
def generate_short_uuid(len: int = 12):
def generate_short_uuid(len: int = 8):
full_uuid = uuid.uuid4()
uuid_bytes = full_uuid.bytes
encoded = base64.urlsafe_b64encode(uuid_bytes)
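The hunk cuts off before the helper returns. A sketch of the full function, assuming the URL-safe base64 string is simply truncated to the requested length:

```python
import base64
import uuid


def generate_short_uuid(len: int = 8) -> str:
    # encode a random UUID as URL-safe base64, strip padding, keep `len` chars
    full_uuid = uuid.uuid4()
    uuid_bytes = full_uuid.bytes
    encoded = base64.urlsafe_b64encode(uuid_bytes)
    return encoded.rstrip(b"=").decode("ascii")[:len]
```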
@@ -40,7 +42,7 @@ class BackgroundLogger:
try:
self.log_queue.put_nowait(event)
except queue.Full:
print("Log queue is full, dropping event")
log.error("Log queue is full, dropping event")
def _process_logs(self):
while True:
@@ -121,18 +123,19 @@ def setup_logger(api: Telemetry, level: int = logging.INFO):
logger.addHandler(TelemetryHandler())
async def start_trace(name: str, attributes: Dict[str, Any] = None):
async def start_trace(name: str, attributes: Dict[str, Any] = None) -> TraceContext:
global CURRENT_TRACE_CONTEXT, BACKGROUND_LOGGER
if BACKGROUND_LOGGER is None:
print("No Telemetry implementation set. Skipping trace initialization...")
log.info("No Telemetry implementation set. Skipping trace initialization...")
return
trace_id = generate_short_uuid()
trace_id = generate_short_uuid(16)
context = TraceContext(BACKGROUND_LOGGER, trace_id)
context.push_span(name, {"__root__": True, **(attributes or {})})
CURRENT_TRACE_CONTEXT = context
return context
async def end_trace(status: SpanStatus = SpanStatus.OK):
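With `start_trace` now returning the `TraceContext`, callers can hold on to it explicitly. A hypothetical caller; the import paths are assumed, not shown in this diff:

```python
from llama_stack.apis.telemetry import SpanStatus  # path assumed
from llama_stack.providers.utils.telemetry.tracing import (  # path assumed
    end_trace,
    start_trace,
)


async def handle_request() -> None:
    # start_trace hands back the TraceContext (None when telemetry is unset)
    context = await start_trace("chat_completion", {"model": "Llama3.1-8B-Instruct"})
    try:
        pass  # actual request handling goes here
    finally:
        if context is not None:
            await end_trace(SpanStatus.OK)
```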

View file

@@ -50,7 +50,7 @@ def process_template(template_dir: Path, progress) -> None:
template.save_distribution(
yaml_output_dir=REPO_ROOT / "llama_stack" / "templates" / template.name,
doc_output_dir=REPO_ROOT
/ "docs/source/getting_started/distributions"
/ "docs/source/distributions"
/ f"{template.distro_type}_distro",
)
else:
@@ -103,7 +103,7 @@ def generate_dependencies_file():
deps_file = REPO_ROOT / "distributions" / "dependencies.json"
with open(deps_file, "w") as f:
json.dump(distribution_deps, f, indent=2)
f.write(json.dumps(distribution_deps, indent=2) + "\n")
def main():

View file

@@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .bedrock import get_distribution_template # noqa: F401

View file

@@ -0,0 +1,38 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pathlib import Path
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
def get_distribution_template() -> DistributionTemplate:
    providers = {
        "inference": ["remote::bedrock"],
        "memory": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
        "safety": ["remote::bedrock"],
        "agents": ["inline::meta-reference"],
        "telemetry": ["inline::meta-reference"],
    }

    return DistributionTemplate(
        name="bedrock",
        distro_type="self_hosted",
        description="Use AWS Bedrock for running LLM inference and safety",
        docker_image=None,
        template_path=Path(__file__).parent / "doc_template.md",
        providers=providers,
        default_models=[],
        run_configs={
            "run.yaml": RunConfigSettings(),
        },
        run_config_env_vars={
            "LLAMASTACK_PORT": (
                "5001",
                "Port for the Llama Stack distribution server",
            ),
        },
    )
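A quick, illustrative way to sanity-check the new template, relying only on the import exposed by the `__init__.py` above:

```python
from llama_stack.templates.bedrock import get_distribution_template

template = get_distribution_template()
print(template.name)  # bedrock
print(template.providers["inference"])  # ['remote::bedrock']
```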

View file

@@ -1,9 +1,19 @@
version: '2'
name: bedrock
distribution_spec:
description: Use Amazon Bedrock APIs.
description: Use AWS Bedrock for running LLM inference and safety
docker_image: null
providers:
inference: remote::bedrock
memory: inline::faiss
safety: inline::llama-guard
agents: inline::meta-reference
telemetry: inline::meta-reference
inference:
- remote::bedrock
memory:
- inline::faiss
- remote::chromadb
- remote::pgvector
safety:
- remote::bedrock
agents:
- inline::meta-reference
telemetry:
- inline::meta-reference
image_type: conda

View file

@@ -0,0 +1,70 @@
# Bedrock Distribution
```{toctree}
:maxdepth: 2
:hidden:
self
```
The `llamastack/distribution-{{ name }}` distribution consists of the following provider configurations:
{{ providers_table }}
{% if run_config_env_vars %}
### Environment Variables
The following environment variables can be configured:
{% for var, (default_value, description) in run_config_env_vars.items() %}
- `{{ var }}`: {{ description }} (default: `{{ default_value }}`)
{% endfor %}
{% endif %}
{% if default_models %}
### Models
The following models are available by default:
{% for model in default_models %}
- `{{ model.model_id }} ({{ model.provider_model_id }})`
{% endfor %}
{% endif %}
### Prerequisite: API Keys
Make sure you have access to an AWS Bedrock API key. You can get one by visiting [AWS Bedrock](https://aws.amazon.com/bedrock/).
## Running Llama Stack with AWS Bedrock
You can do this via Conda (build code) or Docker, which has a pre-built image.
### Via Docker
This method allows you to get started quickly without having to build the distribution code.
```bash
LLAMA_STACK_PORT=5001
docker run \
-it \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
llamastack/distribution-{{ name }} \
--port $LLAMA_STACK_PORT \
--env AWS_ACCESS_KEY_ID=$AWS_ACCESS_KEY_ID \
--env AWS_SECRET_ACCESS_KEY=$AWS_SECRET_ACCESS_KEY \
--env AWS_SESSION_TOKEN=$AWS_SESSION_TOKEN
```
### Via Conda
```bash
llama stack build --template {{ name }} --image-type conda
llama stack run ./run.yaml \
--port $LLAMA_STACK_PORT \
--env AWS_ACCESS_KEY_ID=$AWS_ACCESS_KEY_ID \
--env AWS_SECRET_ACCESS_KEY=$AWS_SECRET_ACCESS_KEY \
--env AWS_SESSION_TOKEN=$AWS_SESSION_TOKEN
```
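Once the server is running via either method, you can verify it from Python. A minimal check, assuming the `llama-stack-client` package is installed:

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:5001")
for model in client.models.list():
    print(model.identifier)
```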

View file

@@ -0,0 +1,49 @@
version: '2'
image_name: bedrock
docker_image: null
conda_env: bedrock
apis:
- agents
- inference
- memory
- safety
- telemetry
providers:
  inference:
  - provider_id: bedrock
    provider_type: remote::bedrock
    config: {}
  memory:
  - provider_id: faiss
    provider_type: inline::faiss
    config:
      kvstore:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/bedrock}/faiss_store.db
  safety:
  - provider_id: bedrock
    provider_type: remote::bedrock
    config: {}
  agents:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      persistence_store:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/bedrock}/agents_store.db
  telemetry:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config: {}
metadata_store:
  namespace: null
  type: sqlite
  db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/bedrock}/registry.db
models: []
shields: []
memory_banks: []
datasets: []
scoring_fns: []
eval_tasks: []
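The `${env.VAR:default}` references above are resolved when the server loads the config. A rough sketch of that substitution; the colon-default semantics are assumed from the values in this file:

```python
import os
import re

# pattern and default handling assumed from the ${env.VAR:default} usage above
_ENV_REF = re.compile(r"\$\{env\.([A-Za-z0-9_]+)(?::([^}]*))?\}")


def substitute_env_vars(text: str) -> str:
    def _replace(match: re.Match) -> str:
        name, default = match.group(1), match.group(2)
        value = os.environ.get(name, default)
        if value is None:
            raise ValueError(f"Environment variable {name} is not set and has no default")
        return value

    return _ENV_REF.sub(_replace, text)


print(substitute_env_vars("${env.SQLITE_STORE_DIR:~/.llama/distributions/bedrock}/registry.db"))
```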

View file

@@ -1,9 +0,0 @@
name: databricks
distribution_spec:
  description: Use Databricks for running LLM inference
  providers:
    inference: remote::databricks
    memory: inline::faiss
    safety: inline::llama-guard
    agents: meta-reference
    telemetry: meta-reference

View file

@@ -1,5 +1,15 @@
---
orphan: true
---
# Fireworks Distribution
```{toctree}
:maxdepth: 2
:hidden:
self
```
The `llamastack/distribution-{{ name }}` distribution consists of the following provider configurations.
{{ providers_table }}
@@ -43,9 +53,7 @@ LLAMA_STACK_PORT=5001
docker run \
-it \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ./run.yaml:/root/my-run.yaml \
llamastack/distribution-{{ name }} \
--yaml-config /root/my-run.yaml \
--port $LLAMA_STACK_PORT \
--env FIREWORKS_API_KEY=$FIREWORKS_API_KEY
```
@@ -55,6 +63,6 @@ docker run \
```bash
llama stack build --template fireworks --image-type conda
llama stack run ./run.yaml \
--port 5001 \
--port $LLAMA_STACK_PORT \
--env FIREWORKS_API_KEY=$FIREWORKS_API_KEY
```

View file

@@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .hf_endpoint import get_distribution_template # noqa: F401

View file

@@ -1,9 +1,19 @@
version: '2'
name: hf-endpoint
distribution_spec:
description: "Like local, but use Hugging Face Inference Endpoints for running LLM inference.\nSee https://hf.co/docs/api-endpoints."
description: Use (an external) Hugging Face Inference Endpoint for running LLM inference
docker_image: null
providers:
inference: remote::hf::endpoint
memory: inline::faiss
safety: inline::llama-guard
agents: inline::meta-reference
telemetry: inline::meta-reference
inference:
- remote::hf::endpoint
memory:
- inline::faiss
- remote::chromadb
- remote::pgvector
safety:
- inline::llama-guard
agents:
- inline::meta-reference
telemetry:
- inline::meta-reference
image_type: conda

View file

@@ -0,0 +1,97 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput
from llama_stack.providers.remote.inference.tgi import InferenceEndpointImplConfig
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
def get_distribution_template() -> DistributionTemplate:
    providers = {
        "inference": ["remote::hf::endpoint"],
        "memory": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
        "safety": ["inline::llama-guard"],
        "agents": ["inline::meta-reference"],
        "telemetry": ["inline::meta-reference"],
    }

    inference_provider = Provider(
        provider_id="hf-endpoint",
        provider_type="remote::hf::endpoint",
        config=InferenceEndpointImplConfig.sample_run_config(),
    )

    inference_model = ModelInput(
        model_id="${env.INFERENCE_MODEL}",
        provider_id="hf-endpoint",
    )
    safety_model = ModelInput(
        model_id="${env.SAFETY_MODEL}",
        provider_id="hf-endpoint-safety",
    )

    return DistributionTemplate(
        name="hf-endpoint",
        distro_type="self_hosted",
        description="Use (an external) Hugging Face Inference Endpoint for running LLM inference",
        docker_image=None,
        template_path=None,
        providers=providers,
        default_models=[inference_model, safety_model],
        run_configs={
            "run.yaml": RunConfigSettings(
                provider_overrides={
                    "inference": [inference_provider],
                },
                default_models=[inference_model],
            ),
            "run-with-safety.yaml": RunConfigSettings(
                provider_overrides={
                    "inference": [
                        inference_provider,
                        Provider(
                            provider_id="hf-endpoint-safety",
                            provider_type="remote::hf::endpoint",
                            config=InferenceEndpointImplConfig.sample_run_config(
                                endpoint_name="${env.SAFETY_INFERENCE_ENDPOINT_NAME}",
                            ),
                        ),
                    ]
                },
                default_models=[
                    inference_model,
                    safety_model,
                ],
                default_shields=[ShieldInput(shield_id="${env.SAFETY_MODEL}")],
            ),
        },
        run_config_env_vars={
            "LLAMASTACK_PORT": (
                "5001",
                "Port for the Llama Stack distribution server",
            ),
            "HF_API_TOKEN": (
                "hf_...",
                "Hugging Face API token",
            ),
            "INFERENCE_ENDPOINT_NAME": (
                "",
                "HF Inference endpoint name for the main inference model",
            ),
            "SAFETY_INFERENCE_ENDPOINT_NAME": (
                "",
                "HF Inference endpoint for the safety model",
            ),
            "INFERENCE_MODEL": (
                "meta-llama/Llama-3.2-3B-Instruct",
                "Inference model served by the HF Inference Endpoint",
            ),
            "SAFETY_MODEL": (
                "meta-llama/Llama-Guard-3-1B",
                "Safety model served by the HF Inference Endpoint",
            ),
        },
    )
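A hypothetical check that the safety variant wires up both endpoints, using only the names defined above:

```python
from llama_stack.templates.hf_endpoint import get_distribution_template

template = get_distribution_template()
settings = template.run_configs["run-with-safety.yaml"]
print([p.provider_id for p in settings.provider_overrides["inference"]])
# ['hf-endpoint', 'hf-endpoint-safety']
```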

View file

@@ -0,0 +1,68 @@
version: '2'
image_name: hf-endpoint
docker_image: null
conda_env: hf-endpoint
apis:
- agents
- inference
- memory
- safety
- telemetry
providers:
  inference:
  - provider_id: hf-endpoint
    provider_type: remote::hf::endpoint
    config:
      endpoint_name: ${env.INFERENCE_ENDPOINT_NAME}
      api_token: ${env.HF_API_TOKEN}
  - provider_id: hf-endpoint-safety
    provider_type: remote::hf::endpoint
    config:
      endpoint_name: ${env.SAFETY_INFERENCE_ENDPOINT_NAME}
      api_token: ${env.HF_API_TOKEN}
  memory:
  - provider_id: faiss
    provider_type: inline::faiss
    config:
      kvstore:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/faiss_store.db
  safety:
  - provider_id: llama-guard
    provider_type: inline::llama-guard
    config: {}
  agents:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      persistence_store:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/agents_store.db
  telemetry:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config: {}
metadata_store:
  namespace: null
  type: sqlite
  db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/registry.db
models:
- metadata: {}
  model_id: ${env.INFERENCE_MODEL}
  provider_id: hf-endpoint
  provider_model_id: null
- metadata: {}
  model_id: ${env.SAFETY_MODEL}
  provider_id: hf-endpoint-safety
  provider_model_id: null
shields:
- params: null
  shield_id: ${env.SAFETY_MODEL}
  provider_id: null
  provider_shield_id: null
memory_banks: []
datasets: []
scoring_fns: []
eval_tasks: []

View file

@@ -0,0 +1,55 @@
version: '2'
image_name: hf-endpoint
docker_image: null
conda_env: hf-endpoint
apis:
- agents
- inference
- memory
- safety
- telemetry
providers:
  inference:
  - provider_id: hf-endpoint
    provider_type: remote::hf::endpoint
    config:
      endpoint_name: ${env.INFERENCE_ENDPOINT_NAME}
      api_token: ${env.HF_API_TOKEN}
  memory:
  - provider_id: faiss
    provider_type: inline::faiss
    config:
      kvstore:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/faiss_store.db
  safety:
  - provider_id: llama-guard
    provider_type: inline::llama-guard
    config: {}
  agents:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      persistence_store:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/agents_store.db
  telemetry:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config: {}
metadata_store:
  namespace: null
  type: sqlite
  db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/registry.db
models:
- metadata: {}
  model_id: ${env.INFERENCE_MODEL}
  provider_id: hf-endpoint
  provider_model_id: null
shields: []
memory_banks: []
datasets: []
scoring_fns: []
eval_tasks: []

View file

@@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .hf_serverless import get_distribution_template # noqa: F401

Some files were not shown because too many files have changed in this diff.