Merge branch 'meta-llama:main' into feat/litellm_sambanova_usage

This commit is contained in:
Jorge Piedrahita Ortiz 2025-04-10 11:01:51 -05:00 committed by GitHub
commit 13c660f5a5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
57 changed files with 10986 additions and 93 deletions

View file

@ -312,6 +312,11 @@ a default SQLite store will be used.""",
description="Configuration for the HTTP(S) server",
)
external_providers_dir: Optional[str] = Field(
default=None,
description="Path to directory containing external provider implementations. The providers code and dependencies must be installed on the system.",
)
class BuildConfig(BaseModel):
version: str = LLAMA_STACK_BUILD_CONFIG_VERSION

View file

@ -4,12 +4,25 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import glob
import importlib
from typing import Dict, List
import os
from typing import Any, Dict, List
import yaml
from pydantic import BaseModel
from llama_stack.providers.datatypes import Api, ProviderSpec
from llama_stack.distribution.datatypes import StackRunConfig
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import (
AdapterSpec,
Api,
InlineProviderSpec,
ProviderSpec,
remote_provider_spec,
)
logger = get_logger(name=__name__, category="core")
def stack_apis() -> List[Api]:
@ -59,11 +72,116 @@ def providable_apis() -> List[Api]:
return [api for api in Api if api not in routing_table_apis and api != Api.inspect and api != Api.providers]
def get_provider_registry() -> Dict[Api, Dict[str, ProviderSpec]]:
ret = {}
def _load_remote_provider_spec(spec_data: Dict[str, Any], api: Api) -> ProviderSpec:
adapter = AdapterSpec(**spec_data["adapter"])
spec = remote_provider_spec(
api=api,
adapter=adapter,
api_dependencies=[Api(dep) for dep in spec_data.get("api_dependencies", [])],
)
return spec
def _load_inline_provider_spec(spec_data: Dict[str, Any], api: Api, provider_name: str) -> ProviderSpec:
spec = InlineProviderSpec(
api=api,
provider_type=f"inline::{provider_name}",
pip_packages=spec_data.get("pip_packages", []),
module=spec_data["module"],
config_class=spec_data["config_class"],
api_dependencies=[Api(dep) for dep in spec_data.get("api_dependencies", [])],
optional_api_dependencies=[Api(dep) for dep in spec_data.get("optional_api_dependencies", [])],
provider_data_validator=spec_data.get("provider_data_validator"),
container_image=spec_data.get("container_image"),
)
return spec
def get_provider_registry(config: StackRunConfig | None = None) -> Dict[Api, Dict[str, ProviderSpec]]:
"""Get the provider registry, optionally including external providers.
This function loads both built-in providers and external providers from YAML files.
External providers are loaded from a directory structure like:
providers.d/
remote/
inference/
custom_ollama.yaml
vllm.yaml
vector_io/
qdrant.yaml
safety/
llama-guard.yaml
inline/
inference/
custom_ollama.yaml
vllm.yaml
vector_io/
qdrant.yaml
safety/
llama-guard.yaml
Args:
config: Optional StackRunConfig containing the external providers directory path
Returns:
A dictionary mapping APIs to their available providers
Raises:
FileNotFoundError: If the external providers directory doesn't exist
ValueError: If any provider spec is invalid
"""
ret: Dict[Api, Dict[str, ProviderSpec]] = {}
for api in providable_apis():
name = api.name.lower()
module = importlib.import_module(f"llama_stack.providers.registry.{name}")
ret[api] = {a.provider_type: a for a in module.available_providers()}
logger.debug(f"Importing module {name}")
try:
module = importlib.import_module(f"llama_stack.providers.registry.{name}")
ret[api] = {a.provider_type: a for a in module.available_providers()}
except ImportError as e:
logger.warning(f"Failed to import module {name}: {e}")
if config and config.external_providers_dir:
external_providers_dir = os.path.abspath(config.external_providers_dir)
if not os.path.exists(external_providers_dir):
raise FileNotFoundError(f"External providers directory not found: {external_providers_dir}")
logger.info(f"Loading external providers from {external_providers_dir}")
for api in providable_apis():
api_name = api.name.lower()
# Process both remote and inline providers
for provider_type in ["remote", "inline"]:
api_dir = os.path.join(external_providers_dir, provider_type, api_name)
if not os.path.exists(api_dir):
logger.debug(f"No {provider_type} provider directory found for {api_name}")
continue
# Look for provider spec files in the API directory
for spec_path in glob.glob(os.path.join(api_dir, "*.yaml")):
provider_name = os.path.splitext(os.path.basename(spec_path))[0]
logger.info(f"Loading {provider_type} provider spec from {spec_path}")
try:
with open(spec_path) as f:
spec_data = yaml.safe_load(f)
if provider_type == "remote":
spec = _load_remote_provider_spec(spec_data, api)
provider_type_key = f"remote::{provider_name}"
else:
spec = _load_inline_provider_spec(spec_data, api, provider_name)
provider_type_key = f"inline::{provider_name}"
logger.info(f"Loaded {provider_type} provider spec for {provider_type_key} from {spec_path}")
if provider_type_key in ret[api]:
logger.warning(f"Overriding already registered provider {provider_type_key} for {api.name}")
ret[api][provider_type_key] = spec
except yaml.YAMLError as yaml_err:
logger.error(f"Failed to parse YAML file {spec_path}: {yaml_err}")
raise yaml_err
except Exception as e:
logger.error(f"Failed to load provider spec from {spec_path}: {e}")
raise e
return ret

View file

@ -351,6 +351,7 @@ async def instantiate_provider(
if not hasattr(provider_spec, "module"):
raise AttributeError(f"ProviderSpec of type {type(provider_spec)} does not have a 'module' attribute")
logger.debug(f"Instantiating provider {provider.provider_id} from {provider_spec.module}")
module = importlib.import_module(provider_spec.module)
args = []
if isinstance(provider_spec, RemoteProviderSpec):

View file

@ -608,8 +608,8 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
tool_group = await self.get_tool_group(toolgroup_id)
if tool_group is None:
raise ValueError(f"Tool group {toolgroup_id} not found")
tools = (await self.list_tools(toolgroup_id)).data
for tool in tools:
tools = await self.list_tools(toolgroup_id)
for tool in getattr(tools, "data", []):
await self.unregister_object(tool)
await self.unregister_object(tool_group)

View file

@ -218,7 +218,7 @@ async def construct_stack(
run_config: StackRunConfig, provider_registry: Optional[ProviderRegistry] = None
) -> Dict[Api, Any]:
dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name)
impls = await resolve_impls(run_config, provider_registry or get_provider_registry(), dist_registry)
impls = await resolve_impls(run_config, provider_registry or get_provider_registry(run_config), dist_registry)
await register_resources(run_config, impls)
return impls

View file

@ -1,7 +1,7 @@
# More info on playground configuration can be found here:
# https://llama-stack.readthedocs.io/en/latest/playground
FROM python:3.9-slim
FROM python:3.12-slim
WORKDIR /app
COPY . /app/
RUN /usr/local/bin/python -m pip install --upgrade pip && \

View file

@ -36,9 +36,7 @@ llama-stack-client benchmarks register \
3. Start Streamlit UI
```bash
cd llama_stack/distribution/ui
pip install -r requirements.txt
streamlit run app.py
uv run --with ".[ui]" streamlit run llama_stack/distribution/ui/app.py
```
## Environment Variables

View file

@ -24,6 +24,7 @@ def main():
# Playground pages
chat_page = st.Page("page/playground/chat.py", title="Chat", icon="💬", default=True)
rag_page = st.Page("page/playground/rag.py", title="RAG", icon="💬", default=False)
tool_page = st.Page("page/playground/tools.py", title="Tools", icon="🛠", default=False)
# Distribution pages
resources_page = st.Page("page/distribution/resources.py", title="Resources", icon="🔍", default=False)
@ -39,6 +40,7 @@ def main():
"Playground": [
chat_page,
rag_page,
tool_page,
application_evaluation_page,
native_evaluation_page,
],

View file

@ -19,6 +19,7 @@ class LlamaStackApi:
"together_api_key": os.environ.get("TOGETHER_API_KEY", ""),
"sambanova_api_key": os.environ.get("SAMBANOVA_API_KEY", ""),
"openai_api_key": os.environ.get("OPENAI_API_KEY", ""),
"tavily_search_api_key": os.environ.get("TAVILY_SEARCH_API_KEY", ""),
},
)

View file

@ -0,0 +1,116 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import uuid
import streamlit as st
from llama_stack_client import Agent
from llama_stack.distribution.ui.modules.api import llama_stack_api
def tool_chat_page():
st.title("🛠 Tools")
client = llama_stack_api.client
models = client.models.list()
model_list = [model.identifier for model in models if model.api_model_type == "llm"]
tool_groups = client.toolgroups.list()
tool_groups_list = [tool_group.identifier for tool_group in tool_groups]
mcp_tools_list = [tool for tool in tool_groups_list if tool.startswith("mcp::")]
builtin_tools_list = [tool for tool in tool_groups_list if not tool.startswith("mcp::")]
def reset_agent():
st.session_state.clear()
st.cache_resource.clear()
with st.sidebar:
st.subheader("Model")
model = st.selectbox(label="models", options=model_list, on_change=reset_agent)
st.subheader("Builtin Tools")
toolgroup_selection = st.pills(
label="Available ToolGroups", options=builtin_tools_list, selection_mode="multi", on_change=reset_agent
)
st.subheader("MCP Servers")
mcp_selection = st.pills(
label="Available MCP Servers", options=mcp_tools_list, selection_mode="multi", on_change=reset_agent
)
toolgroup_selection.extend(mcp_selection)
active_tool_list = []
for toolgroup_id in toolgroup_selection:
active_tool_list.extend(
[
f"{''.join(toolgroup_id.split('::')[1:])}:{t.identifier}"
for t in client.tools.list(toolgroup_id=toolgroup_id)
]
)
st.subheader(f"Active Tools: 🛠 {len(active_tool_list)}")
st.json(active_tool_list)
@st.cache_resource
def create_agent():
return Agent(
client,
model=model,
instructions="You are a helpful assistant. When you use a tool always respond with a summary of the result.",
tools=toolgroup_selection,
sampling_params={
"strategy": {"type": "greedy"},
},
)
agent = create_agent()
if "agent_session_id" not in st.session_state:
st.session_state["agent_session_id"] = agent.create_session(session_name=f"tool_demo_{uuid.uuid4()}")
session_id = st.session_state["agent_session_id"]
if "messages" not in st.session_state:
st.session_state["messages"] = [{"role": "assistant", "content": "How can I help you?"}]
for msg in st.session_state.messages:
with st.chat_message(msg["role"]):
st.markdown(msg["content"])
if prompt := st.chat_input(placeholder=""):
with st.chat_message("user"):
st.markdown(prompt)
st.session_state.messages.append({"role": "user", "content": prompt})
turn_response = agent.create_turn(
session_id=session_id,
messages=[{"role": "user", "content": prompt}],
stream=True,
)
def response_generator(turn_response):
for response in turn_response:
if hasattr(response.event, "payload"):
print(response.event.payload)
if response.event.payload.event_type == "step_progress":
if hasattr(response.event.payload.delta, "text"):
yield response.event.payload.delta.text
if response.event.payload.event_type == "step_complete":
if response.event.payload.step_details.step_type == "tool_execution":
yield " 🛠 "
else:
yield f"Error occurred in the Llama Stack Cluster: {response}"
with st.chat_message("assistant"):
response = st.write_stream(response_generator(turn_response))
st.session_state.messages.append({"role": "assistant", "content": response})
tool_chat_page()

View file

@ -1,4 +1,5 @@
streamlit
pandas
llama-stack-client>=0.0.55
llama-stack-client>=0.2.1
streamlit-option-menu
llama-stack>=0.2.1

View file

@ -29,6 +29,11 @@ def preserve_contexts_async_generator(
context_var.set(initial_context_values[context_var.name])
item = await gen.__anext__()
# Update our tracked values with any changes made during this iteration
for context_var in context_vars:
initial_context_values[context_var.name] = context_var.get()
yield item
except StopAsyncIteration:

View file

@ -1,155 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
from concurrent.futures import ThreadPoolExecutor
from contextvars import ContextVar
import pytest
from llama_stack.distribution.utils.context import preserve_contexts_async_generator
@pytest.mark.asyncio
async def test_preserve_contexts_with_exception():
# Create context variable
context_var = ContextVar("exception_var", default="initial")
token = context_var.set("start_value")
# Create an async generator that raises an exception
async def exception_generator():
yield context_var.get()
context_var.set("modified")
raise ValueError("Test exception")
yield None # This will never be reached
# Wrap the generator
wrapped_gen = preserve_contexts_async_generator(exception_generator(), [context_var])
# First iteration should work
value = await wrapped_gen.__anext__()
assert value == "start_value"
# Second iteration should raise the exception
with pytest.raises(ValueError, match="Test exception"):
await wrapped_gen.__anext__()
# Clean up
context_var.reset(token)
@pytest.mark.asyncio
async def test_preserve_contexts_empty_generator():
# Create context variable
context_var = ContextVar("empty_var", default="initial")
token = context_var.set("value")
# Create an empty async generator
async def empty_generator():
if False: # This condition ensures the generator yields nothing
yield None
# Wrap the generator
wrapped_gen = preserve_contexts_async_generator(empty_generator(), [context_var])
# The generator should raise StopAsyncIteration immediately
with pytest.raises(StopAsyncIteration):
await wrapped_gen.__anext__()
# Context variable should remain unchanged
assert context_var.get() == "value"
# Clean up
context_var.reset(token)
@pytest.mark.asyncio
async def test_preserve_contexts_across_event_loops():
"""
Test that context variables are preserved across event loop boundaries with nested generators.
This simulates the real-world scenario where:
1. A new event loop is created for each streaming request
2. The async generator runs inside that loop
3. There are multiple levels of nested generators
4. Context needs to be preserved across these boundaries
"""
# Create context variables
request_id = ContextVar("request_id", default=None)
user_id = ContextVar("user_id", default=None)
# Set initial values
# Results container to verify values across thread boundaries
results = []
# Inner-most generator (level 2)
async def inner_generator():
# Should have the context from the outer scope
yield (1, request_id.get(), user_id.get())
# Modify one context variable
user_id.set("user-modified")
# Should reflect the modification
yield (2, request_id.get(), user_id.get())
# Middle generator (level 1)
async def middle_generator():
inner_gen = inner_generator()
# Forward the first yield from inner
item = await inner_gen.__anext__()
yield item
# Forward the second yield from inner
item = await inner_gen.__anext__()
yield item
request_id.set("req-modified")
# Add our own yield with both modified variables
yield (3, request_id.get(), user_id.get())
# Function to run in a separate thread with a new event loop
def run_in_new_loop():
# Create a new event loop for this thread
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
# Outer generator (runs in the new loop)
async def outer_generator():
request_id.set("req-12345")
user_id.set("user-6789")
# Wrap the middle generator
wrapped_gen = preserve_contexts_async_generator(middle_generator(), [request_id, user_id])
# Process all items from the middle generator
async for item in wrapped_gen:
# Store results for verification
results.append(item)
# Run the outer generator in the new loop
loop.run_until_complete(outer_generator())
finally:
loop.close()
# Run the generator chain in a separate thread with a new event loop
with ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(run_in_new_loop)
future.result() # Wait for completion
# Verify the results
assert len(results) == 3
# First yield should have original values
assert results[0] == (1, "req-12345", "user-6789")
# Second yield should have modified user_id
assert results[1] == (2, "req-12345", "user-modified")
# Third yield should have both modified values
assert results[2] == (3, "req-modified", "user-modified")

View file

@ -119,17 +119,16 @@ class Llama3:
torch.set_default_device(device)
else:
print(f"Setting default device to {device}")
torch.set_default_device(device)
if device.type == "cuda":
if torch.cuda.is_bf16_supported():
torch.set_default_dtype(torch.bfloat16)
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
else:
torch.set_default_dtype(torch.half)
torch.set_default_tensor_type(torch.cuda.Float16Tensor)
elif device.type == "xpu":
if torch.xpu.is_bf16_supported():
torch.set_default_dtype(torch.bfloat16)
torch.set_default_tensor_type(torch.xpu.BFloat16Tensor)
else:
torch.set_default_dtype(torch.half)
torch.set_default_tensor_type(torch.xpu.Float16Tensor)
model = build_model()
print("Loading state dict...")

View file

@ -70,6 +70,9 @@ class ModelArgs(BaseModel):
attention_chunk_size: Optional[int] = None
rope_theta: float = 500000
use_scaled_rope: bool = False
rope_scaling_factor: Optional[float] = None
rope_high_freq_factor: Optional[float] = None
nope_layer_interval: Optional[int] = None # No position encoding in every n layers
use_qk_norm: bool = False
# Set to True to enable inference-time temperature tuning (useful for very long context)
@ -92,4 +95,14 @@ class ModelArgs(BaseModel):
f"n_heads ({self.n_heads}) must be divisible by n_kv_heads ({self.n_kv_heads})"
)
assert self.dim % self.n_heads == 0, f"dim ({self.dim}) must be divisible by n_heads ({self.n_heads})"
if self.use_scaled_rope:
# NOTE: ideally these values should have come from params.json. However, we have
# shipped the models everywhere. Only Llama-4-Scout uses scaled rope and needs these
# specific values.
if self.rope_scaling_factor is None:
self.rope_scaling_factor = 16
if self.rope_high_freq_factor is None:
self.rope_high_freq_factor = 1
return self

View file

@ -23,37 +23,25 @@ from .ffn import FeedForward
from .moe import MoE
def rmsnorm(x, eps):
def _norm(y):
return y * torch.rsqrt(y.pow(2).mean(-1, keepdim=True) + eps)
return _norm(x.float()).type_as(x)
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
output = self._norm(x.float()).type_as(x)
return output * self.weight
return rmsnorm(x, self.eps) * self.weight
class L2Norm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
return self._norm(x.float()).type_as(x)
def apply_scaling(freqs: torch.Tensor):
# Values obtained from grid search
scale_factor = 8
def apply_scaling(freqs: torch.Tensor, scale_factor: float, high_freq_factor: float):
low_freq_factor = 1
high_freq_factor = 4
old_context_len = 8192 # original llama3 length
low_freq_wavelen = old_context_len / low_freq_factor
@ -72,11 +60,18 @@ def apply_scaling(freqs: torch.Tensor):
return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False):
def precompute_freqs_cis(
dim: int,
end: int,
theta: float,
use_scaled: bool,
scale_factor: float,
high_freq_factor: float,
):
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device, dtype=torch.float32)
if use_scaled:
freqs = apply_scaling(freqs)
freqs = apply_scaling(freqs, scale_factor, high_freq_factor)
freqs = torch.outer(t, freqs)
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
return freqs_cis
@ -174,9 +169,7 @@ class Attention(nn.Module):
self.head_dim,
)
).cuda()
self.qk_norm = None
if self.use_qk_norm:
self.qk_norm = L2Norm(args.norm_eps)
self.norm_eps = args.norm_eps
self._register_load_state_dict_pre_hook(self.load_hook)
def load_hook(
@ -220,8 +213,8 @@ class Attention(nn.Module):
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
if self.use_qk_norm:
xq = self.qk_norm(xq)
xk = self.qk_norm(xk)
xq = rmsnorm(xq, self.norm_eps)
xk = rmsnorm(xk, self.norm_eps)
# We are applying temperature tuning (https://arxiv.org/abs/2501.19399) to NoPE layers, where
# the inference-time temperature tuning function is customized to not affect short context
@ -362,6 +355,8 @@ class Transformer(nn.Module):
args.max_seq_len * 2,
args.rope_theta,
args.use_scaled_rope,
args.rope_scaling_factor,
args.rope_high_freq_factor,
)
vision_args = self.args.vision_args
if vision_args:

View file

@ -91,7 +91,7 @@ def convert_to_quantized_model(
log_status(f"Rank {rank}: Quantizing int4 weights from bf16")
def apply_quantization(_, weight):
return quantize_int4(weight, fp8_activation_scale_ub, output_device=torch.device("cuda"))
return quantize_int4(weight, output_device=torch.device("cuda"))
else:
fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{rank}.pt")

View file

@ -56,9 +56,11 @@ LLAMA4_TEXT_POST_TRAIN_SPECIAL_TOKENS = [
"<|text_post_train_reserved_special_token_3|>",
"<|text_post_train_reserved_special_token_4|>",
"<|text_post_train_reserved_special_token_5|>",
"<|text_post_train_reserved_special_token_6|>",
"<|text_post_train_reserved_special_token_7|>",
"<|finetune_right_pad|>",
] + get_reserved_special_tokens(
"text_post_train", 61, 6
"text_post_train", 61, 8
) # <|text_post_train_reserved_special_token_6|>, ..., <|text_post_train_reserved_special_token_66|>
# 200080, ..., 201133

View file

@ -65,7 +65,7 @@ class Int4Weights(
Int4ScaledWeights,
collections.namedtuple(
"Int4Weights",
["weight", "scale", "zero_point", "shape", "activation_scale_ub"],
["weight", "scale", "zero_point", "shape"],
),
):
pass
@ -184,20 +184,13 @@ def quantize_fp8(
@torch.inference_mode()
def quantize_int4(
w: Tensor,
fp8_activation_scale_ub: float,
output_device: Optional[torch.device] = None,
) -> Int4Weights:
"""Quantize [n, k/2] weight tensor.
Args:
w (Tensor): [n, k/2] input high precision tensor to quantize.
fp8_activation_scale_ub (float): Upper bound for activation max.
"""
activation_scale_ub = torch.tensor(
[fp8_activation_scale_ub],
dtype=torch.float,
device=output_device,
)
if w.ndim >= 3:
wq, scale, zero_point = zip(*[int4_row_quantize(i) for i in w], strict=False)
wq = torch.stack([pack_int4(i) for i in wq], dim=0)
@ -212,7 +205,6 @@ def quantize_int4(
scale=scale.to(output_device),
zero_point=zero_point.to(output_device),
shape=wq.shape,
activation_scale_ub=activation_scale_ub,
)
@ -247,26 +239,18 @@ def load_int4(
w: Tensor,
scale: Tensor,
zero_point: Tensor,
fp8_activation_scale_ub: float,
output_device: Optional[torch.device] = None,
) -> Int4Weights:
"""Load INT4 [n, k/2] weight tensor.
Args:
w (Tensor): [n, k/2] input INT4.
fp8_activation_scale_ub (float): Upper bound for activation max.
"""
activation_scale_ub = torch.tensor(
[fp8_activation_scale_ub],
dtype=torch.float,
device=output_device,
)
return Int4Weights(
weight=w.to(torch.int8).to(device=output_device),
scale=scale.to(device=output_device),
zero_point=zero_point.to(device=output_device),
shape=w.shape,
activation_scale_ub=activation_scale_ub,
)

View file

@ -89,7 +89,6 @@ class ChatAgent(ShieldRunnerMixin):
self,
agent_id: str,
agent_config: AgentConfig,
tempdir: str,
inference_api: Inference,
safety_api: Safety,
tool_runtime_api: ToolRuntime,
@ -99,7 +98,6 @@ class ChatAgent(ShieldRunnerMixin):
):
self.agent_id = agent_id
self.agent_config = agent_config
self.tempdir = tempdir
self.inference_api = inference_api
self.safety_api = safety_api
self.vector_io_api = vector_io_api

View file

@ -7,7 +7,6 @@
import json
import logging
import shutil
import tempfile
import uuid
from typing import AsyncGenerator, List, Optional, Union
@ -64,7 +63,6 @@ class MetaReferenceAgentsImpl(Agents):
self.tool_groups_api = tool_groups_api
self.in_memory_store = InmemoryKVStoreImpl()
self.tempdir = tempfile.mkdtemp()
async def initialize(self) -> None:
self.persistence_store = await kvstore_impl(self.config.persistence_store)
@ -107,7 +105,6 @@ class MetaReferenceAgentsImpl(Agents):
return ChatAgent(
agent_id=agent_id,
agent_config=agent_config,
tempdir=self.tempdir,
inference_api=self.inference_api,
safety_api=self.safety_api,
vector_io_api=self.vector_io_api,

View file

@ -259,7 +259,7 @@ class Llama3Generator:
temperature, top_p = _infer_sampling_params(sampling_params)
for result in self.inner_generator.generate(
llm_inputs=[self.formatter.encode_content(request.content)],
model_inputs=[self.formatter.encode_content(request.content)],
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
@ -284,7 +284,7 @@ class Llama3Generator:
temperature, top_p = _infer_sampling_params(sampling_params)
for result in self.inner_generator.generate(
llm_inputs=[self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request))],
model_inputs=[self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request))],
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,

View file

@ -307,9 +307,10 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
if model.model_type == ModelType.embedding:
logger.info(f"Pulling embedding model `{model.provider_resource_id}` if necessary...")
await self.client.pull(model.provider_resource_id)
response = await self.client.list()
else:
response = await self.client.ps()
# we use list() here instead of ps() -
# - ps() only lists running models, not available models
# - models not currently running are run by the ollama server as needed
response = await self.client.list()
available_models = [m["model"] for m in response["models"]]
if model.provider_resource_id not in available_models:
raise ValueError(

View file

@ -13,7 +13,7 @@ The `llamastack/distribution-{{ name }}` distribution consists of the following
{{ providers_table }}
You can use this distribution if you have GPUs and want to run an independent vLLM server container for running inference.
You can use this distribution if you want to run an independent vLLM server for inference.
{% if run_config_env_vars %}
### Environment Variables
@ -28,6 +28,83 @@ The following environment variables can be configured:
## Setting up vLLM server
In the following sections, we'll use either AMD and NVIDIA GPUs to serve as hardware accelerators for the vLLM
server, which acts as both the LLM inference provider and the safety provider. Note that vLLM also
[supports many other hardware accelerators](https://docs.vllm.ai/en/latest/getting_started/installation.html) and
that we only use GPUs here for demonstration purposes.
### Setting up vLLM server on AMD GPU
AMD provides two main vLLM container options:
- rocm/vllm: Production-ready container
- rocm/vllm-dev: Development container with the latest vLLM features
Please check the [Blog about ROCm vLLM Usage](https://rocm.blogs.amd.com/software-tools-optimization/vllm-container/README.html) to get more details.
Here is a sample script to start a ROCm vLLM server locally via Docker:
```bash
export INFERENCE_PORT=8000
export INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
export CUDA_VISIBLE_DEVICES=0
export VLLM_DIMG="rocm/vllm-dev:main"
docker run \
--pull always \
--ipc=host \
--privileged \
--shm-size 16g \
--device=/dev/kfd \
--device=/dev/dri \
--group-add video \
--cap-add=SYS_PTRACE \
--cap-add=CAP_SYS_ADMIN \
--security-opt seccomp=unconfined \
--security-opt apparmor=unconfined \
--env "HUGGING_FACE_HUB_TOKEN=$HF_TOKEN" \
--env "HIP_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" \
-p $INFERENCE_PORT:$INFERENCE_PORT \
-v ~/.cache/huggingface:/root/.cache/huggingface \
$VLLM_DIMG \
python -m vllm.entrypoints.openai.api_server \
--model $INFERENCE_MODEL \
--port $INFERENCE_PORT
```
Note that you'll also need to set `--enable-auto-tool-choice` and `--tool-call-parser` to [enable tool calling in vLLM](https://docs.vllm.ai/en/latest/features/tool_calling.html).
If you are using Llama Stack Safety / Shield APIs, then you will need to also run another instance of a vLLM with a corresponding safety model like `meta-llama/Llama-Guard-3-1B` using a script like:
```bash
export SAFETY_PORT=8081
export SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
export CUDA_VISIBLE_DEVICES=1
export VLLM_DIMG="rocm/vllm-dev:main"
docker run \
--pull always \
--ipc=host \
--privileged \
--shm-size 16g \
--device=/dev/kfd \
--device=/dev/dri \
--group-add video \
--cap-add=SYS_PTRACE \
--cap-add=CAP_SYS_ADMIN \
--security-opt seccomp=unconfined \
--security-opt apparmor=unconfined \
--env "HUGGING_FACE_HUB_TOKEN=$HF_TOKEN" \
--env "HIP_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" \
-p $SAFETY_PORT:$SAFETY_PORT \
-v ~/.cache/huggingface:/root/.cache/huggingface \
$VLLM_DIMG \
python -m vllm.entrypoints.openai.api_server \
--model $SAFETY_MODEL \
--port $SAFETY_PORT
```
### Setting up vLLM server on NVIDIA GPU
Please check the [vLLM Documentation](https://docs.vllm.ai/en/v0.5.5/serving/deploying_with_docker.html) to get a vLLM endpoint. Here is a sample script to start a vLLM server locally via Docker:
```bash