Merge branch 'main' into register_custom_model

This commit is contained in:
Rashmi Pawar 2025-04-16 13:47:33 +05:30
commit 8000b0287f
242 changed files with 221047 additions and 8397 deletions

View file

@ -320,7 +320,7 @@ jobs:
- name: "PR - Update comment"
id: pr_update_comment
if: github.event_name == 'pull_request_target'
uses: thollander/actions-comment-pull-request@65f9e5c9a1f2cd378bd74b2e057c9736982a8e74 # v3.0.1
uses: thollander/actions-comment-pull-request@24bffb9b452ba05a4f3f77933840a6a841d1b32b # v3.0.1
with:
filePath: test-summary.md

View file

@ -0,0 +1,93 @@
name: Test External Providers
on:
push:
branches: [ main ]
pull_request:
branches: [ main ]
jobs:
test-external-providers:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Install uv
uses: astral-sh/setup-uv@v5
with:
python-version: "3.10"
- name: Install Ollama
run: |
curl -fsSL https://ollama.com/install.sh | sh
- name: Pull Ollama image
run: |
ollama pull llama3.2:3b-instruct-fp16
- name: Start Ollama in background
run: |
nohup ollama run llama3.2:3b-instruct-fp16 --keepalive=30m > ollama.log 2>&1 &
- name: Set Up Environment and Install Dependencies
run: |
uv sync --extra dev --extra test
uv pip install -e .
- name: Install Ollama custom provider
run: |
mkdir -p tests/external-provider/llama-stack-provider-ollama/src/
cp -a llama_stack/providers/remote/inference/ollama/ tests/external-provider/llama-stack-provider-ollama/src/llama_stack_provider_ollama
uv pip install tests/external-provider/llama-stack-provider-ollama
- name: Create provider configuration
run: |
mkdir -p /tmp/providers.d/remote/inference
cp tests/external-provider/llama-stack-provider-ollama/custom_ollama.yaml /tmp/providers.d/remote/inference/custom_ollama.yaml
- name: Wait for Ollama to start
run: |
echo "Waiting for Ollama..."
for i in {1..30}; do
if curl -s http://localhost:11434 | grep -q "Ollama is running"; then
echo "Ollama is running!"
exit 0
fi
sleep 1
done
echo "Ollama failed to start"
ollama ps
ollama.log
exit 1
- name: Start Llama Stack server in background
env:
INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct"
run: |
source .venv/bin/activate
nohup uv run llama stack run tests/external-provider/llama-stack-provider-ollama/run.yaml --image-type venv > server.log 2>&1 &
- name: Wait for Llama Stack server to be ready
run: |
echo "Waiting for Llama Stack server..."
for i in {1..30}; do
if curl -s http://localhost:8321/v1/health | grep -q "OK"; then
echo "Llama Stack server is up!"
if grep -q "remote::custom_ollama from /tmp/providers.d/remote/inference/custom_ollama.yaml" server.log; then
echo "Llama Stack server is using custom Ollama provider"
exit 0
else
echo "Llama Stack server is not using custom Ollama provider"
exit 1
fi
fi
sleep 1
done
echo "Llama Stack server failed to start"
cat server.log
exit 1
- name: run inference tests
run: |
uv run pytest -v tests/integration/inference/test_text_inference.py --stack-config="http://localhost:8321" --text-model="meta-llama/Llama-3.2-3B-Instruct" --embedding-model=all-MiniLM-L6-v2

View file

@ -1,5 +1,42 @@
# Changelog
# v0.2.1
Published on: 2025-04-05T23:13:00Z
---
# v0.2.0
Published on: 2025-04-05T19:04:29Z
## Llama 4 Support
Checkout more at https://www.llama.com
---
# v0.1.9
Published on: 2025-03-29T00:52:23Z
### Build and Test Agents
* Agents: Entire document context with attachments
* RAG: Documentation with sqlite-vec faiss comparison
* Getting started: Fixes to getting started notebook.
### Agent Evals and Model Customization
* (**New**) Post-training: Add nemo customizer
### Better Engineering
* Moved sqlite-vec to non-blocking calls
* Don't return a payload on file delete
---
# v0.1.8
Published on: 2025-03-24T01:28:50Z

View file

@ -1,8 +1,10 @@
include pyproject.toml
include llama_stack/templates/dependencies.json
include llama_stack/models/llama/llama3/tokenizer.model
include llama_stack/models/llama/llama4/tokenizer.model
include llama_stack/distribution/*.sh
include llama_stack/cli/scripts/*.sh
include llama_stack/templates/*/*.yaml
include llama_stack/providers/tests/test_cases/inference/*.json
include llama_stack/models/llama/*/*.md
include llama_stack/tests/integration/*.jpg

View file

@ -3,12 +3,72 @@
[![PyPI version](https://img.shields.io/pypi/v/llama_stack.svg)](https://pypi.org/project/llama_stack/)
[![PyPI - Downloads](https://img.shields.io/pypi/dm/llama-stack)](https://pypi.org/project/llama-stack/)
[![License](https://img.shields.io/pypi/l/llama_stack.svg)](https://github.com/meta-llama/llama-stack/blob/main/LICENSE)
[![Discord](https://img.shields.io/discord/1257833999603335178)](https://discord.gg/llama-stack)
[![Discord](https://img.shields.io/discord/1257833999603335178?color=6A7EC2&logo=discord&logoColor=ffffff)](https://discord.gg/llama-stack)
[![Unit Tests](https://github.com/meta-llama/llama-stack/actions/workflows/unit-tests.yml/badge.svg?branch=main)](https://github.com/meta-llama/llama-stack/actions/workflows/unit-tests.yml?query=branch%3Amain)
[![Integration Tests](https://github.com/meta-llama/llama-stack/actions/workflows/integration-tests.yml/badge.svg?branch=main)](https://github.com/meta-llama/llama-stack/actions/workflows/integration-tests.yml?query=branch%3Amain)
[**Quick Start**](https://llama-stack.readthedocs.io/en/latest/getting_started/index.html) | [**Documentation**](https://llama-stack.readthedocs.io/en/latest/index.html) | [**Colab Notebook**](./docs/getting_started.ipynb)
### ✨🎉 Llama 4 Support 🎉✨
We released [Version 0.2.0](https://github.com/meta-llama/llama-stack/releases/tag/v0.2.0) with support for the Llama 4 herd of models released by Meta.
You can now run Llama 4 models on Llama Stack.
*Note you need 8xH100 GPU-host to run these models*
```bash
pip install -U llama_stack
MODEL="Llama-4-Scout-17B-16E-Instruct"
# get meta url from llama.com
llama model download --source meta --model-id $MODEL --meta-url <META_URL>
# start a llama stack server
INFERENCE_MODEL=meta-llama/$MODEL llama stack build --run --template meta-reference-gpu
# install client to interact with the server
pip install llama-stack-client
```
### CLI
```bash
# Run a chat completion
llama-stack-client --endpoint http://localhost:8321 \
inference chat-completion \
--model-id meta-llama/$MODEL \
--message "write a haiku for meta's llama 4 models"
ChatCompletionResponse(
completion_message=CompletionMessage(content="Whispers in code born\nLlama's gentle, wise heartbeat\nFuture's soft unfold", role='assistant', stop_reason='end_of_turn', tool_calls=[]),
logprobs=None,
metrics=[Metric(metric='prompt_tokens', value=21.0, unit=None), Metric(metric='completion_tokens', value=28.0, unit=None), Metric(metric='total_tokens', value=49.0, unit=None)]
)
```
### Python SDK
```python
from llama_stack_client import LlamaStackClient
client = LlamaStackClient(base_url=f"http://localhost:8321")
model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
prompt = "Write a haiku about coding"
print(f"User> {prompt}")
response = client.inference.chat_completion(
model_id=model_id,
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": prompt},
],
)
print(f"Assistant> {response.completion_message.content}")
```
As more providers start supporting Llama 4, you can use them in Llama Stack as well. We are adding to the list. Stay tuned!
### Overview
Llama Stack standardizes the core building blocks that simplify AI application development. It codifies best practices across the Llama ecosystem. More specifically, it provides
- **Unified API layer** for Inference, RAG, Agents, Tools, Safety, Evals, and Telemetry.

View file

@ -16,3 +16,14 @@
.hide-title h1 {
display: none;
}
h2, h3, h4 {
font-weight: normal;
}
html[data-theme="dark"] .rst-content div[class^="highlight"] {
background-color: #0b0b0b;
}
pre {
white-space: pre-wrap !important;
word-break: break-all;
}

9
docs/_static/js/detect_theme.js vendored Normal file
View file

@ -0,0 +1,9 @@
document.addEventListener("DOMContentLoaded", function () {
const prefersDark = window.matchMedia("(prefers-color-scheme: dark)").matches;
const htmlElement = document.documentElement;
if (prefersDark) {
htmlElement.setAttribute("data-theme", "dark");
} else {
htmlElement.setAttribute("data-theme", "light");
}
});

File diff suppressed because it is too large Load diff

View file

@ -40,7 +40,7 @@ paths:
schema:
$ref: '#/components/schemas/AppendRowsRequest'
required: true
/v1/batch-inference/chat-completion:
/v1/inference/batch-chat-completion:
post:
responses:
'200':
@ -60,7 +60,7 @@ paths:
default:
$ref: '#/components/responses/DefaultError'
tags:
- BatchInference (Coming Soon)
- Inference
description: ''
parameters: []
requestBody:
@ -69,7 +69,7 @@ paths:
schema:
$ref: '#/components/schemas/BatchChatCompletionRequest'
required: true
/v1/batch-inference/completion:
/v1/inference/batch-completion:
post:
responses:
'200':
@ -89,7 +89,7 @@ paths:
default:
$ref: '#/components/responses/DefaultError'
tags:
- BatchInference (Coming Soon)
- Inference
description: ''
parameters: []
requestBody:
@ -148,7 +148,7 @@ paths:
default:
$ref: '#/components/responses/DefaultError'
tags:
- Inference
- BatchInference (Coming Soon)
description: >-
Generate a chat completion for the given messages using the specified model.
parameters: []
@ -183,7 +183,7 @@ paths:
default:
$ref: '#/components/responses/DefaultError'
tags:
- Inference
- BatchInference (Coming Soon)
description: >-
Generate a completion for the given content using the specified model.
parameters: []
@ -2131,6 +2131,91 @@ paths:
schema:
$ref: '#/components/schemas/LogEventRequest'
required: true
/v1/openai/v1/chat/completions:
post:
responses:
'200':
description: OK
content:
application/json:
schema:
$ref: '#/components/schemas/OpenAIChatCompletion'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- Inference
description: >-
Generate an OpenAI-compatible chat completion for the given messages using
the specified model.
parameters: []
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/OpenaiChatCompletionRequest'
required: true
/v1/openai/v1/completions:
post:
responses:
'200':
description: OK
content:
application/json:
schema:
$ref: '#/components/schemas/OpenAICompletion'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- Inference
description: >-
Generate an OpenAI-compatible completion for the given prompt using the specified
model.
parameters: []
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/OpenaiCompletionRequest'
required: true
/v1/openai/v1/models:
get:
responses:
'200':
description: OK
content:
application/json:
schema:
$ref: '#/components/schemas/OpenAIListModelsResponse'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- Models
description: ''
parameters: []
/v1/post-training/preference-optimize:
post:
responses:
@ -2924,6 +3009,54 @@ components:
- tool_name
- arguments
title: ToolCall
ToolConfig:
type: object
properties:
tool_choice:
oneOf:
- type: string
enum:
- auto
- required
- none
title: ToolChoice
description: >-
Whether tool use is required or automatic. This is a hint to the model
which may not be followed. It depends on the Instruction Following
capabilities of the model.
- type: string
default: auto
description: >-
(Optional) Whether tool use is automatic, required, or none. Can also
specify a tool name to use a specific tool. Defaults to ToolChoice.auto.
tool_prompt_format:
type: string
enum:
- json
- function_tag
- python_list
description: >-
(Optional) Instructs the model how to format tool calls. By default, Llama
Stack will attempt to use a format that is best adapted to the model.
- `ToolPromptFormat.json`: The tool calls are formatted as a JSON object.
- `ToolPromptFormat.function_tag`: The tool calls are enclosed in a <function=function_name>
tag. - `ToolPromptFormat.python_list`: The tool calls are output as Python
syntax -- a list of function calls.
system_message_behavior:
type: string
enum:
- append
- replace
description: >-
(Optional) Config for how to override the default system prompt. - `SystemMessageBehavior.append`:
Appends the provided system message to the default system prompt. - `SystemMessageBehavior.replace`:
Replaces the default system prompt with the provided system message. The
system message can include the string '{{function_definitions}}' to indicate
where the function definitions should be inserted.
default: append
additionalProperties: false
title: ToolConfig
description: Configuration for tool use.
ToolDefinition:
type: object
properties:
@ -3060,7 +3193,7 @@ components:
BatchChatCompletionRequest:
type: object
properties:
model:
model_id:
type: string
messages_batch:
type: array
@ -3074,26 +3207,8 @@ components:
type: array
items:
$ref: '#/components/schemas/ToolDefinition'
tool_choice:
type: string
enum:
- auto
- required
- none
title: ToolChoice
description: >-
Whether tool use is required or automatic. This is a hint to the model
which may not be followed. It depends on the Instruction Following capabilities
of the model.
tool_prompt_format:
type: string
enum:
- json
- function_tag
- python_list
title: ToolPromptFormat
description: >-
Prompt format for calling custom / zero shot tools.
tool_config:
$ref: '#/components/schemas/ToolConfig'
response_format:
$ref: '#/components/schemas/ResponseFormat'
logprobs:
@ -3108,7 +3223,7 @@ components:
title: LogProbConfig
additionalProperties: false
required:
- model
- model_id
- messages_batch
title: BatchChatCompletionRequest
BatchChatCompletionResponse:
@ -3176,7 +3291,7 @@ components:
BatchCompletionRequest:
type: object
properties:
model:
model_id:
type: string
content_batch:
type: array
@ -3198,7 +3313,7 @@ components:
title: LogProbConfig
additionalProperties: false
required:
- model
- model_id
- content_batch
title: BatchCompletionRequest
BatchCompletionResponse:
@ -3250,54 +3365,6 @@ components:
required:
- job_uuid
title: CancelTrainingJobRequest
ToolConfig:
type: object
properties:
tool_choice:
oneOf:
- type: string
enum:
- auto
- required
- none
title: ToolChoice
description: >-
Whether tool use is required or automatic. This is a hint to the model
which may not be followed. It depends on the Instruction Following
capabilities of the model.
- type: string
default: auto
description: >-
(Optional) Whether tool use is automatic, required, or none. Can also
specify a tool name to use a specific tool. Defaults to ToolChoice.auto.
tool_prompt_format:
type: string
enum:
- json
- function_tag
- python_list
description: >-
(Optional) Instructs the model how to format tool calls. By default, Llama
Stack will attempt to use a format that is best adapted to the model.
- `ToolPromptFormat.json`: The tool calls are formatted as a JSON object.
- `ToolPromptFormat.function_tag`: The tool calls are enclosed in a <function=function_name>
tag. - `ToolPromptFormat.python_list`: The tool calls are output as Python
syntax -- a list of function calls.
system_message_behavior:
type: string
enum:
- append
- replace
description: >-
(Optional) Config for how to override the default system prompt. - `SystemMessageBehavior.append`:
Appends the provided system message to the default system prompt. - `SystemMessageBehavior.replace`:
Replaces the default system prompt with the provided system message. The
system message can include the string '{{function_definitions}}' to indicate
where the function definitions should be inserted.
default: append
additionalProperties: false
title: ToolConfig
description: Configuration for tool use.
ChatCompletionRequest:
type: object
properties:
@ -5980,6 +6047,586 @@ components:
- event
- ttl_seconds
title: LogEventRequest
OpenAIAssistantMessageParam:
type: object
properties:
role:
type: string
const: assistant
default: assistant
description: >-
Must be "assistant" to identify this as the model's response
content:
$ref: '#/components/schemas/InterleavedContent'
description: The content of the model's response
name:
type: string
description: >-
(Optional) The name of the assistant message participant.
tool_calls:
type: array
items:
$ref: '#/components/schemas/ToolCall'
description: >-
List of tool calls. Each tool call is a ToolCall object.
additionalProperties: false
required:
- role
- content
title: OpenAIAssistantMessageParam
description: >-
A message containing the model's (assistant) response in an OpenAI-compatible
chat completion request.
OpenAIDeveloperMessageParam:
type: object
properties:
role:
type: string
const: developer
default: developer
description: >-
Must be "developer" to identify this as a developer message
content:
$ref: '#/components/schemas/InterleavedContent'
description: The content of the developer message
name:
type: string
description: >-
(Optional) The name of the developer message participant.
additionalProperties: false
required:
- role
- content
title: OpenAIDeveloperMessageParam
description: >-
A message from the developer in an OpenAI-compatible chat completion request.
OpenAIMessageParam:
oneOf:
- $ref: '#/components/schemas/OpenAIUserMessageParam'
- $ref: '#/components/schemas/OpenAISystemMessageParam'
- $ref: '#/components/schemas/OpenAIAssistantMessageParam'
- $ref: '#/components/schemas/OpenAIToolMessageParam'
- $ref: '#/components/schemas/OpenAIDeveloperMessageParam'
discriminator:
propertyName: role
mapping:
user: '#/components/schemas/OpenAIUserMessageParam'
system: '#/components/schemas/OpenAISystemMessageParam'
assistant: '#/components/schemas/OpenAIAssistantMessageParam'
tool: '#/components/schemas/OpenAIToolMessageParam'
developer: '#/components/schemas/OpenAIDeveloperMessageParam'
OpenAISystemMessageParam:
type: object
properties:
role:
type: string
const: system
default: system
description: >-
Must be "system" to identify this as a system message
content:
$ref: '#/components/schemas/InterleavedContent'
description: >-
The content of the "system prompt". If multiple system messages are provided,
they are concatenated. The underlying Llama Stack code may also add other
system messages (for example, for formatting tool definitions).
name:
type: string
description: >-
(Optional) The name of the system message participant.
additionalProperties: false
required:
- role
- content
title: OpenAISystemMessageParam
description: >-
A system message providing instructions or context to the model.
OpenAIToolMessageParam:
type: object
properties:
role:
type: string
const: tool
default: tool
description: >-
Must be "tool" to identify this as a tool response
tool_call_id:
type: string
description: >-
Unique identifier for the tool call this response is for
content:
$ref: '#/components/schemas/InterleavedContent'
description: The response content from the tool
additionalProperties: false
required:
- role
- tool_call_id
- content
title: OpenAIToolMessageParam
description: >-
A message representing the result of a tool invocation in an OpenAI-compatible
chat completion request.
OpenAIUserMessageParam:
type: object
properties:
role:
type: string
const: user
default: user
description: >-
Must be "user" to identify this as a user message
content:
$ref: '#/components/schemas/InterleavedContent'
description: >-
The content of the message, which can include text and other media
name:
type: string
description: >-
(Optional) The name of the user message participant.
additionalProperties: false
required:
- role
- content
title: OpenAIUserMessageParam
description: >-
A message from the user in an OpenAI-compatible chat completion request.
OpenaiChatCompletionRequest:
type: object
properties:
model:
type: string
description: >-
The identifier of the model to use. The model must be registered with
Llama Stack and available via the /models endpoint.
messages:
type: array
items:
$ref: '#/components/schemas/OpenAIMessageParam'
description: List of messages in the conversation
frequency_penalty:
type: number
description: >-
(Optional) The penalty for repeated tokens
function_call:
oneOf:
- type: string
- type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: (Optional) The function call to use
functions:
type: array
items:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: (Optional) List of functions to use
logit_bias:
type: object
additionalProperties:
type: number
description: (Optional) The logit bias to use
logprobs:
type: boolean
description: (Optional) The log probabilities to use
max_completion_tokens:
type: integer
description: >-
(Optional) The maximum number of tokens to generate
max_tokens:
type: integer
description: >-
(Optional) The maximum number of tokens to generate
n:
type: integer
description: >-
(Optional) The number of completions to generate
parallel_tool_calls:
type: boolean
description: >-
(Optional) Whether to parallelize tool calls
presence_penalty:
type: number
description: >-
(Optional) The penalty for repeated tokens
response_format:
type: object
additionalProperties:
type: string
description: (Optional) The response format to use
seed:
type: integer
description: (Optional) The seed to use
stop:
oneOf:
- type: string
- type: array
items:
type: string
description: (Optional) The stop tokens to use
stream:
type: boolean
description: >-
(Optional) Whether to stream the response
stream_options:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: (Optional) The stream options to use
temperature:
type: number
description: (Optional) The temperature to use
tool_choice:
oneOf:
- type: string
- type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: (Optional) The tool choice to use
tools:
type: array
items:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: (Optional) The tools to use
top_logprobs:
type: integer
description: >-
(Optional) The top log probabilities to use
top_p:
type: number
description: (Optional) The top p to use
user:
type: string
description: (Optional) The user to use
additionalProperties: false
required:
- model
- messages
title: OpenaiChatCompletionRequest
OpenAIChatCompletion:
type: object
properties:
id:
type: string
description: The ID of the chat completion
choices:
type: array
items:
$ref: '#/components/schemas/OpenAIChoice'
description: List of choices
object:
type: string
const: chat.completion
default: chat.completion
description: >-
The object type, which will be "chat.completion"
created:
type: integer
description: >-
The Unix timestamp in seconds when the chat completion was created
model:
type: string
description: >-
The model that was used to generate the chat completion
additionalProperties: false
required:
- id
- choices
- object
- created
- model
title: OpenAIChatCompletion
description: >-
Response from an OpenAI-compatible chat completion request.
OpenAIChoice:
type: object
properties:
message:
$ref: '#/components/schemas/OpenAIMessageParam'
description: The message from the model
finish_reason:
type: string
description: The reason the model stopped generating
index:
type: integer
logprobs:
$ref: '#/components/schemas/OpenAIChoiceLogprobs'
additionalProperties: false
required:
- message
- finish_reason
- index
title: OpenAIChoice
description: >-
A choice from an OpenAI-compatible chat completion response.
OpenAIChoiceLogprobs:
type: object
properties:
content:
type: array
items:
$ref: '#/components/schemas/OpenAITokenLogProb'
refusal:
type: array
items:
$ref: '#/components/schemas/OpenAITokenLogProb'
additionalProperties: false
title: OpenAIChoiceLogprobs
description: >-
The log probabilities for the tokens in the message from an OpenAI-compatible
chat completion response.
OpenAITokenLogProb:
type: object
properties:
token:
type: string
bytes:
type: array
items:
type: integer
logprob:
type: number
top_logprobs:
type: array
items:
$ref: '#/components/schemas/OpenAITopLogProb'
additionalProperties: false
required:
- token
- logprob
- top_logprobs
title: OpenAITokenLogProb
description: >-
The log probability for a token from an OpenAI-compatible chat completion
response.
OpenAITopLogProb:
type: object
properties:
token:
type: string
bytes:
type: array
items:
type: integer
logprob:
type: number
additionalProperties: false
required:
- token
- logprob
title: OpenAITopLogProb
description: >-
The top log probability for a token from an OpenAI-compatible chat completion
response.
OpenaiCompletionRequest:
type: object
properties:
model:
type: string
description: >-
The identifier of the model to use. The model must be registered with
Llama Stack and available via the /models endpoint.
prompt:
oneOf:
- type: string
- type: array
items:
type: string
- type: array
items:
type: integer
- type: array
items:
type: array
items:
type: integer
description: The prompt to generate a completion for
best_of:
type: integer
description: >-
(Optional) The number of completions to generate
echo:
type: boolean
description: (Optional) Whether to echo the prompt
frequency_penalty:
type: number
description: >-
(Optional) The penalty for repeated tokens
logit_bias:
type: object
additionalProperties:
type: number
description: (Optional) The logit bias to use
logprobs:
type: boolean
description: (Optional) The log probabilities to use
max_tokens:
type: integer
description: >-
(Optional) The maximum number of tokens to generate
n:
type: integer
description: >-
(Optional) The number of completions to generate
presence_penalty:
type: number
description: >-
(Optional) The penalty for repeated tokens
seed:
type: integer
description: (Optional) The seed to use
stop:
oneOf:
- type: string
- type: array
items:
type: string
description: (Optional) The stop tokens to use
stream:
type: boolean
description: >-
(Optional) Whether to stream the response
stream_options:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: (Optional) The stream options to use
temperature:
type: number
description: (Optional) The temperature to use
top_p:
type: number
description: (Optional) The top p to use
user:
type: string
description: (Optional) The user to use
guided_choice:
type: array
items:
type: string
prompt_logprobs:
type: integer
additionalProperties: false
required:
- model
- prompt
title: OpenaiCompletionRequest
OpenAICompletion:
type: object
properties:
id:
type: string
choices:
type: array
items:
$ref: '#/components/schemas/OpenAICompletionChoice'
created:
type: integer
model:
type: string
object:
type: string
const: text_completion
default: text_completion
additionalProperties: false
required:
- id
- choices
- created
- model
- object
title: OpenAICompletion
description: >-
Response from an OpenAI-compatible completion request.
OpenAICompletionChoice:
type: object
properties:
finish_reason:
type: string
text:
type: string
index:
type: integer
logprobs:
$ref: '#/components/schemas/OpenAIChoiceLogprobs'
additionalProperties: false
required:
- finish_reason
- text
- index
title: OpenAICompletionChoice
description: >-
A choice from an OpenAI-compatible completion response.
OpenAIModel:
type: object
properties:
id:
type: string
object:
type: string
const: model
default: model
created:
type: integer
owned_by:
type: string
additionalProperties: false
required:
- id
- object
- created
- owned_by
title: OpenAIModel
description: A model from OpenAI.
OpenAIListModelsResponse:
type: object
properties:
data:
type: array
items:
$ref: '#/components/schemas/OpenAIModel'
additionalProperties: false
required:
- data
title: OpenAIListModelsResponse
DPOAlignmentConfig:
type: object
properties:
@ -6079,10 +6726,13 @@ components:
type: integer
max_steps_per_epoch:
type: integer
default: 1
gradient_accumulation_steps:
type: integer
default: 1
max_validation_steps:
type: integer
default: 1
data_config:
$ref: '#/components/schemas/DataConfig'
optimizer_config:
@ -6097,9 +6747,6 @@ components:
- n_epochs
- max_steps_per_epoch
- gradient_accumulation_steps
- max_validation_steps
- data_config
- optimizer_config
title: TrainingConfig
PreferenceOptimizeRequest:
type: object
@ -6833,7 +7480,6 @@ components:
- training_config
- hyperparam_search_config
- logger_config
- model
title: SupervisedFineTuneRequest
SyntheticDataGenerateRequest:
type: object
@ -6968,6 +7614,17 @@ tags:
x-displayName: >-
Agents API for creating and interacting with agentic systems.
- name: BatchInference (Coming Soon)
description: >-
This is an asynchronous API. If the request is successful, the response will
be a job which can be polled for completion.
NOTE: This API is not yet implemented and is subject to change in concert with
other asynchronous APIs
including (post-training, evals, etc).
x-displayName: >-
Batch inference API for generating completions and chat completions.
- name: Benchmarks
- name: DatasetIO
- name: Datasets

File diff suppressed because it is too large Load diff

File diff suppressed because one or more lines are too long

View file

@ -51,6 +51,7 @@ def main(output_dir: str):
"Converting the spec to YAML (openapi.yaml) and HTML (openapi.html) at " + now
)
print("")
spec = Specification(
LlamaStack,
Options(

View file

@ -519,7 +519,7 @@ class Generator:
)
def _build_extra_tag_groups(
self, extra_types: Dict[str, List[type]]
self, extra_types: Dict[str, Dict[str, type]]
) -> Dict[str, List[Tag]]:
"""
Creates a dictionary of tag group captions as keys, and tag lists as values.
@ -532,9 +532,8 @@ class Generator:
for category_name, category_items in extra_types.items():
tag_list: List[Tag] = []
for extra_type in category_items:
name = python_type_to_name(extra_type)
schema = self.schema_builder.classdef_to_named_schema(name, extra_type)
for name, extra_type in category_items.items():
schema = self.schema_builder.classdef_to_schema(extra_type)
tag_list.append(self._build_type_tag(name, schema))
if tag_list:
@ -863,7 +862,7 @@ class Generator:
for caption, extra_tag_group in extra_tag_groups.items():
tag_groups.append(
TagGroup(
name=self.options.map(caption),
name=caption,
tags=sorted(tag.name for tag in extra_tag_group),
)
)

View file

@ -2,6 +2,14 @@
Here's a collection of comprehensive guides, examples, and resources for building AI applications with Llama Stack. For the complete documentation, visit our [ReadTheDocs page](https://llama-stack.readthedocs.io/en/latest/index.html).
## Render locally
```bash
pip install -r requirements.txt
cd docs
python -m sphinx_autobuild source _build
```
You can open up the docs in your browser at http://localhost:8000
## Content
Try out Llama Stack's capabilities through our detailed Jupyter notebooks:

View file

@ -3,10 +3,12 @@ myst-parser
linkify
-e git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
sphinx-rtd-theme>=1.0.0
sphinx-pdj-theme
sphinx_autobuild
sphinx-copybutton
sphinx-tabs
sphinx-design
sphinx-pdj-theme
sphinx_rtd_dark_mode
sphinx-tabs
sphinxcontrib-openapi
sphinxcontrib-redoc
sphinxcontrib-mermaid

View file

@ -1,6 +1,9 @@
# Llama Stack Agent Framework
# Agents
The Llama Stack agent framework is built on a modular architecture that allows for flexible and powerful AI applications. This document explains the key components and how they work together.
An Agent in Llama Stack is a powerful abstraction that allows you to build complex AI applications.
The Llama Stack agent framework is built on a modular architecture that allows for flexible and powerful AI
applications. This document explains the key components and how they work together.
## Core Concepts

View file

@ -1,6 +1,10 @@
## Agent Execution Loop
Agents are the heart of complex AI applications. They combine inference, memory, safety, and tool usage into coherent workflows. At its core, an agent follows a sophisticated execution loop that enables multi-step reasoning, tool usage, and safety checks.
Agents are the heart of Llama Stack applications. They combine inference, memory, safety, and tool usage into coherent
workflows. At its core, an agent follows a sophisticated execution loop that enables multi-step reasoning, tool usage,
and safety checks.
### Steps in the Agent Workflow
Each agent turn follows these key steps:
@ -64,7 +68,10 @@ sequenceDiagram
S->>U: 5. Final Response
```
Each step in this process can be monitored and controlled through configurations. Here's an example that demonstrates monitoring the agent's execution:
Each step in this process can be monitored and controlled through configurations.
### Agent Execution Loop Example
Here's an example that demonstrates monitoring the agent's execution:
```python
from llama_stack_client import LlamaStackClient, Agent, AgentEventLogger

View file

@ -8,9 +8,9 @@ The best way to get started is to look at this notebook which walks through the
Here are some key topics that will help you build effective agents:
- **[RAG (Retrieval-Augmented Generation)](rag)**: Learn how to enhance your agents with external knowledge through retrieval mechanisms.
- **[Agent](agent)**: Understand the components and design patterns of the Llama Stack agent framework.
- **[Agent Execution Loop](agent_execution_loop)**: Understand how agents process information, make decisions, and execute actions in a continuous loop.
- **[RAG (Retrieval-Augmented Generation)](rag)**: Learn how to enhance your agents with external knowledge through retrieval mechanisms.
- **[Tools](tools)**: Extend your agents' capabilities by integrating with external tools and APIs.
- **[Evals](evals)**: Evaluate your agents' effectiveness and identify areas for improvement.
- **[Telemetry](telemetry)**: Monitor and analyze your agents' performance and behavior.
@ -20,12 +20,11 @@ Here are some key topics that will help you build effective agents:
:hidden:
:maxdepth: 1
rag
agent
agent_execution_loop
rag
tools
telemetry
evals
advanced_agent_patterns
telemetry
safety
```

View file

@ -3,9 +3,9 @@
RAG enables your applications to reference and recall information from previous interactions or external documents.
Llama Stack organizes the APIs that enable RAG into three layers:
- the lowermost APIs deal with raw storage and retrieval. These include Vector IO, KeyValue IO (coming soon) and Relational IO (also coming soon.)
- next is the "Rag Tool", a first-class tool as part of the Tools API that allows you to ingest documents (from URLs, files, etc) with various chunking strategies and query them smartly.
- finally, it all comes together with the top-level "Agents" API that allows you to create agents that can use the tools to answer questions, perform tasks, and more.
1. The lowermost APIs deal with raw storage and retrieval. These include Vector IO, KeyValue IO (coming soon) and Relational IO (also coming soon.).
2. The next is the "Rag Tool", a first-class tool as part of the [Tools API](tools.md) that allows you to ingest documents (from URLs, files, etc) with various chunking strategies and query them smartly.
3. Finally, it all comes together with the top-level ["Agents" API](agent.md) that allows you to create agents that can use the tools to answer questions, perform tasks, and more.
<img src="rag.png" alt="RAG System" width="50%">
@ -17,14 +17,19 @@ We may add more storage types like Graph IO in the future.
### Setting up Vector DBs
For this guide, we will use [Ollama](https://ollama.com/) as the inference provider.
Ollama is an LLM runtime that allows you to run Llama models locally.
Here's how to set up a vector database for RAG:
```python
# Create http client
import os
from llama_stack_client import LlamaStackClient
client = LlamaStackClient(base_url=f"http://localhost:{os.environ['LLAMA_STACK_PORT']}")
# Register a vector db
vector_db_id = "my_documents"
response = client.vector_dbs.register(
@ -33,17 +38,27 @@ response = client.vector_dbs.register(
embedding_dimension=384,
provider_id="faiss",
)
```
### Ingesting Documents
You can ingest documents into the vector database using two methods: directly inserting pre-chunked
documents or using the RAG Tool.
```python
# You can insert a pre-chunked document directly into the vector db
chunks = [
{
"document_id": "doc1",
"content": "Your document text here",
"mime_type": "text/plain",
"metadata": {
"document_id": "doc1",
},
},
]
client.vector_io.insert(vector_db_id=vector_db_id, chunks=chunks)
```
### Retrieval
You can query the vector database to retrieve documents based on their embeddings.
```python
# You can then query for these chunks
chunks_response = client.vector_io.query(
vector_db_id=vector_db_id, query="What do you know about..."
@ -52,7 +67,8 @@ chunks_response = client.vector_io.query(
### Using the RAG Tool
A better way to ingest documents is to use the RAG Tool. This tool allows you to ingest documents from URLs, files, etc. and automatically chunks them into smaller pieces.
A better way to ingest documents is to use the RAG Tool. This tool allows you to ingest documents from URLs, files, etc.
and automatically chunks them into smaller pieces.
```python
from llama_stack_client import RAGDocument

View file

@ -12,11 +12,12 @@
# -- Project information -----------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
from docutils import nodes
from pathlib import Path
import requests
import json
from datetime import datetime
from pathlib import Path
import requests
from docutils import nodes
# Read version from pyproject.toml
with Path(__file__).parent.parent.parent.joinpath("pyproject.toml").open("rb") as f:
@ -25,7 +26,9 @@ with Path(__file__).parent.parent.parent.joinpath("pyproject.toml").open("rb") a
print(f"{version_tag=}")
# generate the full link including text and url here
llama_stack_version_url = f"https://github.com/meta-llama/llama-stack/releases/tag/v{version_tag}"
llama_stack_version_url = (
f"https://github.com/meta-llama/llama-stack/releases/tag/v{version_tag}"
)
llama_stack_version_link = f"<a href='{llama_stack_version_url}'>release notes</a>"
project = "llama-stack"
@ -37,11 +40,11 @@ author = "Meta"
extensions = [
"myst_parser",
"sphinx_copybutton",
"sphinx_design",
"sphinx_rtd_theme",
"sphinx_rtd_dark_mode",
"sphinx_copybutton",
"sphinx_tabs.tabs",
"sphinx_design",
"sphinxcontrib.redoc",
"sphinxcontrib.mermaid",
"sphinxcontrib.video",
@ -85,7 +88,7 @@ myst_substitutions = {
"llama_stack_version_link": llama_stack_version_link,
}
suppress_warnings = ['myst.header']
suppress_warnings = ["myst.header"]
# Copy button settings
copybutton_prompt_text = "$ " # for bash prompts
@ -105,17 +108,21 @@ source_suffix = {
# html_theme = "alabaster"
html_theme_options = {
"canonical_url": "https://github.com/meta-llama/llama-stack",
'collapse_navigation': False,
"collapse_navigation": False,
# "style_nav_header_background": "#c3c9d4",
}
default_dark_mode = False
html_static_path = ["../_static"]
# html_logo = "../_static/llama-stack-logo.png"
# html_style = "../_static/css/my_theme.css"
def setup(app):
app.add_css_file("css/my_theme.css")
app.add_js_file("js/detect_theme.js")
def dockerhub_role(name, rawtext, text, lineno, inliner, options={}, content=[]):
url = f"https://hub.docker.com/r/llamastack/{text}"
node = nodes.reference(rawtext, text, refuri=url, **options)

View file

@ -231,7 +231,7 @@ options:
-h, --help show this help message and exit
--port PORT Port to run the server on. It can also be passed via the env var LLAMA_STACK_PORT. (default: 8321)
--image-name IMAGE_NAME
Name of the image to run. Defaults to the current conda environment (default: None)
Name of the image to run. Defaults to the current environment (default: None)
--disable-ipv6 Disable IPv6 support (default: False)
--env KEY=VALUE Environment variables to pass to the server in KEY=VALUE format. Can be specified multiple times. (default: [])
--tls-keyfile TLS_KEYFILE

View file

@ -2,7 +2,7 @@
The Llama Stack runtime configuration is specified as a YAML file. Here is a simplified version of an example configuration file for the Ollama distribution:
```{dropdown} Sample Configuration File
```{dropdown} 👋 Click here for a Sample Configuration File
```yaml
version: 2

View file

@ -17,7 +17,7 @@ client = LlamaStackAsLibraryClient(
# provider_data is optional, but if you need to pass in any provider specific data, you can do so here.
provider_data={"tavily_search_api_key": os.environ["TAVILY_SEARCH_API_KEY"]},
)
await client.initialize()
client.initialize()
```
This will parse your config and set up any inline implementations and remote clients needed for your implementation.

View file

@ -7,13 +7,18 @@ In this guide, we'll use a local [Kind](https://kind.sigs.k8s.io/) cluster and a
First, create a local Kubernetes cluster via Kind:
```bash
```
kind create cluster --image kindest/node:v1.32.0 --name llama-stack-test
```
First, create a Kubernetes PVC and Secret for downloading and storing Hugging Face model:
First set your hugging face token as an environment variable.
```
export HF_TOKEN=$(echo -n "your-hf-token" | base64)
```
```bash
Now create a Kubernetes PVC and Secret for downloading and storing Hugging Face model:
```
cat <<EOF |kubectl apply -f -
apiVersion: v1
kind: PersistentVolumeClaim
@ -33,13 +38,14 @@ metadata:
name: hf-token-secret
type: Opaque
data:
token: $(HF_TOKEN)
token: $HF_TOKEN
EOF
```
Next, start the vLLM server as a Kubernetes Deployment and Service:
```bash
```
cat <<EOF |kubectl apply -f -
apiVersion: apps/v1
kind: Deployment
@ -95,7 +101,7 @@ EOF
We can verify that the vLLM server has started successfully via the logs (this might take a couple of minutes to download the model):
```bash
```
$ kubectl logs -l app.kubernetes.io/name=vllm
...
INFO: Started server process [1]
@ -119,8 +125,8 @@ providers:
Once we have defined the run configuration for Llama Stack, we can build an image with that configuration and the server source code:
```bash
cat >/tmp/test-vllm-llama-stack/Containerfile.llama-stack-run-k8s <<EOF
```
tmp_dir=$(mktemp -d) && cat >$tmp_dir/Containerfile.llama-stack-run-k8s <<EOF
FROM distribution-myenv:dev
RUN apt-get update && apt-get install -y git
@ -128,14 +134,14 @@ RUN git clone https://github.com/meta-llama/llama-stack.git /app/llama-stack-sou
ADD ./vllm-llama-stack-run-k8s.yaml /app/config.yaml
EOF
podman build -f /tmp/test-vllm-llama-stack/Containerfile.llama-stack-run-k8s -t llama-stack-run-k8s /tmp/test-vllm-llama-stack
podman build -f $tmp_dir/Containerfile.llama-stack-run-k8s -t llama-stack-run-k8s $tmp_dir
```
### Deploying Llama Stack Server in Kubernetes
We can then start the Llama Stack server by deploying a Kubernetes Pod and Service:
```bash
```
cat <<EOF |kubectl apply -f -
apiVersion: v1
kind: PersistentVolumeClaim
@ -195,7 +201,7 @@ EOF
### Verifying the Deployment
We can check that the LlamaStack server has started:
```bash
```
$ kubectl logs -l app.kubernetes.io/name=llama-stack
...
INFO: Started server process [1]
@ -207,7 +213,7 @@ INFO: Uvicorn running on http://['::', '0.0.0.0']:5000 (Press CTRL+C to quit
Finally, we forward the Kubernetes service to a local port and test some inference requests against it via the Llama Stack Client:
```bash
```
kubectl port-forward service/llama-stack-service 5000:5000
llama-stack-client --endpoint http://localhost:5000 inference chat-completion --message "hello, what model are you?"
```

View file

@ -1,88 +0,0 @@
<!-- This file was auto-generated by distro_codegen.py, please edit source -->
# NVIDIA Distribution
The `llamastack/distribution-nvidia` distribution consists of the following provider configurations.
| API | Provider(s) |
|-----|-------------|
| agents | `inline::meta-reference` |
| datasetio | `inline::localfs` |
| eval | `inline::meta-reference` |
| inference | `remote::nvidia` |
| post_training | `remote::nvidia` |
| safety | `remote::nvidia` |
| scoring | `inline::basic` |
| telemetry | `inline::meta-reference` |
| tool_runtime | `inline::rag-runtime` |
| vector_io | `inline::faiss` |
### Environment Variables
The following environment variables can be configured:
- `NVIDIA_API_KEY`: NVIDIA API Key (default: ``)
- `NVIDIA_USER_ID`: NVIDIA User ID (default: `llama-stack-user`)
- `NVIDIA_DATASET_NAMESPACE`: NVIDIA Dataset Namespace (default: `default`)
- `NVIDIA_ACCESS_POLICIES`: NVIDIA Access Policies (default: `{}`)
- `NVIDIA_PROJECT_ID`: NVIDIA Project ID (default: `test-project`)
- `NVIDIA_CUSTOMIZER_URL`: NVIDIA Customizer URL (default: `https://customizer.api.nvidia.com`)
- `NVIDIA_OUTPUT_MODEL_DIR`: NVIDIA Output Model Directory (default: `test-example-model@v1`)
- `GUARDRAILS_SERVICE_URL`: URL for the NeMo Guardrails Service (default: `http://0.0.0.0:7331`)
- `INFERENCE_MODEL`: Inference model (default: `Llama3.1-8B-Instruct`)
- `SAFETY_MODEL`: Name of the model to use for safety (default: `meta/llama-3.1-8b-instruct`)
### Models
The following models are available by default:
- `meta/llama3-8b-instruct (aliases: meta-llama/Llama-3-8B-Instruct)`
- `meta/llama3-70b-instruct (aliases: meta-llama/Llama-3-70B-Instruct)`
- `meta/llama-3.1-8b-instruct (aliases: meta-llama/Llama-3.1-8B-Instruct)`
- `meta/llama-3.1-70b-instruct (aliases: meta-llama/Llama-3.1-70B-Instruct)`
- `meta/llama-3.1-405b-instruct (aliases: meta-llama/Llama-3.1-405B-Instruct-FP8)`
- `meta/llama-3.2-1b-instruct (aliases: meta-llama/Llama-3.2-1B-Instruct)`
- `meta/llama-3.2-3b-instruct (aliases: meta-llama/Llama-3.2-3B-Instruct)`
- `meta/llama-3.2-11b-vision-instruct (aliases: meta-llama/Llama-3.2-11B-Vision-Instruct)`
- `meta/llama-3.2-90b-vision-instruct (aliases: meta-llama/Llama-3.2-90B-Vision-Instruct)`
- `nvidia/llama-3.2-nv-embedqa-1b-v2 `
- `nvidia/nv-embedqa-e5-v5 `
- `nvidia/nv-embedqa-mistral-7b-v2 `
- `snowflake/arctic-embed-l `
### Prerequisite: API Keys
Make sure you have access to a NVIDIA API Key. You can get one by visiting [https://build.nvidia.com/](https://build.nvidia.com/).
## Running Llama Stack with NVIDIA
You can do this via Conda (build code) or Docker which has a pre-built image.
### Via Docker
This method allows you to get started quickly without having to build the distribution code.
```bash
LLAMA_STACK_PORT=8321
docker run \
-it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ./run.yaml:/root/my-run.yaml \
llamastack/distribution-nvidia \
--yaml-config /root/my-run.yaml \
--port $LLAMA_STACK_PORT \
--env NVIDIA_API_KEY=$NVIDIA_API_KEY
```
### Via Conda
```bash
llama stack build --template nvidia --image-type conda
llama stack run ./run.yaml \
--port 8321 \
--env NVIDIA_API_KEY=$NVIDIA_API_KEY
--env INFERENCE_MODEL=$INFERENCE_MODEL
```

View file

@ -46,6 +46,8 @@ The following models are available by default:
- `accounts/fireworks/models/llama-v3p3-70b-instruct (aliases: meta-llama/Llama-3.3-70B-Instruct)`
- `accounts/fireworks/models/llama-guard-3-8b (aliases: meta-llama/Llama-Guard-3-8B)`
- `accounts/fireworks/models/llama-guard-3-11b-vision (aliases: meta-llama/Llama-Guard-3-11B-Vision)`
- `accounts/fireworks/models/llama4-scout-instruct-basic (aliases: meta-llama/Llama-4-Scout-17B-16E-Instruct)`
- `accounts/fireworks/models/llama4-maverick-instruct-basic (aliases: meta-llama/Llama-4-Maverick-17B-128E-Instruct)`
- `nomic-ai/nomic-embed-text-v1.5 `

View file

@ -42,6 +42,8 @@ The following models are available by default:
- `groq/llama3-70b-8192 (aliases: meta-llama/Llama-3-70B-Instruct)`
- `groq/llama-3.3-70b-versatile (aliases: meta-llama/Llama-3.3-70B-Instruct)`
- `groq/llama-3.2-3b-preview (aliases: meta-llama/Llama-3.2-3B-Instruct)`
- `groq/llama-4-scout-17b-16e-instruct (aliases: meta-llama/Llama-4-Scout-17B-16E-Instruct)`
- `groq/llama-4-maverick-17b-128e-instruct (aliases: meta-llama/Llama-4-Maverick-17B-128E-Instruct)`
### Prerequisite: API Keys

View file

@ -1,3 +1,4 @@
<!-- This file was auto-generated by distro_codegen.py, please edit source -->
# NVIDIA Distribution
The `llamastack/distribution-nvidia` distribution consists of the following provider configurations.
@ -5,24 +6,49 @@ The `llamastack/distribution-nvidia` distribution consists of the following prov
| API | Provider(s) |
|-----|-------------|
| agents | `inline::meta-reference` |
| datasetio | `inline::localfs` |
| eval | `inline::meta-reference` |
| inference | `remote::nvidia` |
| memory | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
| safety | `inline::llama-guard` |
| post_training | `remote::nvidia` |
| safety | `remote::nvidia` |
| scoring | `inline::basic` |
| telemetry | `inline::meta-reference` |
| tool_runtime | `inline::rag-runtime` |
| vector_io | `inline::faiss` |
### Environment Variables
The following environment variables can be configured:
- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
- `NVIDIA_API_KEY`: NVIDIA API Key (default: ``)
- `NVIDIA_USER_ID`: NVIDIA User ID (default: `llama-stack-user`)
- `NVIDIA_DATASET_NAMESPACE`: NVIDIA Dataset Namespace (default: `default`)
- `NVIDIA_ACCESS_POLICIES`: NVIDIA Access Policies (default: `{}`)
- `NVIDIA_PROJECT_ID`: NVIDIA Project ID (default: `test-project`)
- `NVIDIA_CUSTOMIZER_URL`: NVIDIA Customizer URL (default: `https://customizer.api.nvidia.com`)
- `NVIDIA_OUTPUT_MODEL_DIR`: NVIDIA Output Model Directory (default: `test-example-model@v1`)
- `GUARDRAILS_SERVICE_URL`: URL for the NeMo Guardrails Service (default: `http://0.0.0.0:7331`)
- `INFERENCE_MODEL`: Inference model (default: `Llama3.1-8B-Instruct`)
- `SAFETY_MODEL`: Name of the model to use for safety (default: `meta/llama-3.1-8b-instruct`)
### Models
The following models are available by default:
- `${env.INFERENCE_MODEL} (None)`
- `meta/llama3-8b-instruct (aliases: meta-llama/Llama-3-8B-Instruct)`
- `meta/llama3-70b-instruct (aliases: meta-llama/Llama-3-70B-Instruct)`
- `meta/llama-3.1-8b-instruct (aliases: meta-llama/Llama-3.1-8B-Instruct)`
- `meta/llama-3.1-70b-instruct (aliases: meta-llama/Llama-3.1-70B-Instruct)`
- `meta/llama-3.1-405b-instruct (aliases: meta-llama/Llama-3.1-405B-Instruct-FP8)`
- `meta/llama-3.2-1b-instruct (aliases: meta-llama/Llama-3.2-1B-Instruct)`
- `meta/llama-3.2-3b-instruct (aliases: meta-llama/Llama-3.2-3B-Instruct)`
- `meta/llama-3.2-11b-vision-instruct (aliases: meta-llama/Llama-3.2-11B-Vision-Instruct)`
- `meta/llama-3.2-90b-vision-instruct (aliases: meta-llama/Llama-3.2-90B-Vision-Instruct)`
- `nvidia/llama-3.2-nv-embedqa-1b-v2 `
- `nvidia/nv-embedqa-e5-v5 `
- `nvidia/nv-embedqa-mistral-7b-v2 `
- `snowflake/arctic-embed-l `
### Prerequisite: API Keys
@ -58,4 +84,5 @@ llama stack build --template nvidia --image-type conda
llama stack run ./run.yaml \
--port 8321 \
--env NVIDIA_API_KEY=$NVIDIA_API_KEY
--env INFERENCE_MODEL=$INFERENCE_MODEL
```

View file

@ -25,7 +25,7 @@ The `llamastack/distribution-remote-vllm` distribution consists of the following
| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
You can use this distribution if you have GPUs and want to run an independent vLLM server container for running inference.
You can use this distribution if you want to run an independent vLLM server for inference.
### Environment Variables
@ -41,6 +41,83 @@ The following environment variables can be configured:
## Setting up vLLM server
In the following sections, we'll use either AMD and NVIDIA GPUs to serve as hardware accelerators for the vLLM
server, which acts as both the LLM inference provider and the safety provider. Note that vLLM also
[supports many other hardware accelerators](https://docs.vllm.ai/en/latest/getting_started/installation.html) and
that we only use GPUs here for demonstration purposes.
### Setting up vLLM server on AMD GPU
AMD provides two main vLLM container options:
- rocm/vllm: Production-ready container
- rocm/vllm-dev: Development container with the latest vLLM features
Please check the [Blog about ROCm vLLM Usage](https://rocm.blogs.amd.com/software-tools-optimization/vllm-container/README.html) to get more details.
Here is a sample script to start a ROCm vLLM server locally via Docker:
```bash
export INFERENCE_PORT=8000
export INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
export CUDA_VISIBLE_DEVICES=0
export VLLM_DIMG="rocm/vllm-dev:main"
docker run \
--pull always \
--ipc=host \
--privileged \
--shm-size 16g \
--device=/dev/kfd \
--device=/dev/dri \
--group-add video \
--cap-add=SYS_PTRACE \
--cap-add=CAP_SYS_ADMIN \
--security-opt seccomp=unconfined \
--security-opt apparmor=unconfined \
--env "HUGGING_FACE_HUB_TOKEN=$HF_TOKEN" \
--env "HIP_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" \
-p $INFERENCE_PORT:$INFERENCE_PORT \
-v ~/.cache/huggingface:/root/.cache/huggingface \
$VLLM_DIMG \
python -m vllm.entrypoints.openai.api_server \
--model $INFERENCE_MODEL \
--port $INFERENCE_PORT
```
Note that you'll also need to set `--enable-auto-tool-choice` and `--tool-call-parser` to [enable tool calling in vLLM](https://docs.vllm.ai/en/latest/features/tool_calling.html).
If you are using Llama Stack Safety / Shield APIs, then you will need to also run another instance of a vLLM with a corresponding safety model like `meta-llama/Llama-Guard-3-1B` using a script like:
```bash
export SAFETY_PORT=8081
export SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
export CUDA_VISIBLE_DEVICES=1
export VLLM_DIMG="rocm/vllm-dev:main"
docker run \
--pull always \
--ipc=host \
--privileged \
--shm-size 16g \
--device=/dev/kfd \
--device=/dev/dri \
--group-add video \
--cap-add=SYS_PTRACE \
--cap-add=CAP_SYS_ADMIN \
--security-opt seccomp=unconfined \
--security-opt apparmor=unconfined \
--env "HUGGING_FACE_HUB_TOKEN=$HF_TOKEN" \
--env "HIP_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" \
-p $SAFETY_PORT:$SAFETY_PORT \
-v ~/.cache/huggingface:/root/.cache/huggingface \
$VLLM_DIMG \
python -m vllm.entrypoints.openai.api_server \
--model $SAFETY_MODEL \
--port $SAFETY_PORT
```
### Setting up vLLM server on NVIDIA GPU
Please check the [vLLM Documentation](https://docs.vllm.ai/en/v0.5.5/serving/deploying_with_docker.html) to get a vLLM endpoint. Here is a sample script to start a vLLM server locally via Docker:
```bash

View file

@ -43,6 +43,7 @@ The following models are available by default:
- `Llama-3.2-11B-Vision-Instruct (aliases: meta-llama/Llama-3.2-11B-Vision-Instruct)`
- `Llama-3.2-90B-Vision-Instruct (aliases: meta-llama/Llama-3.2-90B-Vision-Instruct)`
- `Meta-Llama-Guard-3-8B (aliases: meta-llama/Llama-Guard-3-8B)`
- `Llama-4-Scout-17B-16E-Instruct (aliases: meta-llama/Llama-4-Scout-17B-16E-Instruct)`
### Prerequisite: API Keys

View file

@ -48,6 +48,8 @@ The following models are available by default:
- `meta-llama/Llama-Guard-3-11B-Vision-Turbo (aliases: meta-llama/Llama-Guard-3-11B-Vision)`
- `togethercomputer/m2-bert-80M-8k-retrieval `
- `togethercomputer/m2-bert-80M-32k-retrieval `
- `meta-llama/Llama-4-Scout-17B-16E-Instruct (aliases: meta-llama/Llama-4-Scout-17B-16E-Instruct, together/meta-llama/Llama-4-Scout-17B-16E-Instruct)`
- `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8 (aliases: meta-llama/Llama-4-Maverick-17B-128E-Instruct, together/meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8)`
### Prerequisite: API Keys

View file

@ -2,22 +2,22 @@
You can run a Llama Stack server in one of the following ways:
**As a Library**:
## As a Library:
This is the simplest way to get started. Using Llama Stack as a library means you do not need to start a server. This is especially useful when you are not running inference locally and relying on an external inference service (eg. fireworks, together, groq, etc.) See [Using Llama Stack as a Library](importing_as_library)
**Container**:
## Container:
Another simple way to start interacting with Llama Stack is to just spin up a container (via Docker or Podman) which is pre-built with all the providers you need. We provide a number of pre-built images so you can start a Llama Stack server instantly. You can also build your own custom container. Which distribution to choose depends on the hardware you have. See [Selection of a Distribution](selection) for more details.
**Conda**:
## Conda:
If you have a custom or an advanced setup or you are developing on Llama Stack you can also build a custom Llama Stack server. Using `llama stack build` and `llama stack run` you can build/run a custom Llama Stack server containing the exact combination of providers you wish. We have also provided various templates to make getting started easier. See [Building a Custom Distribution](building_distro) for more details.
**Kubernetes**:
## Kubernetes:
If you have built a container image and want to deploy it in a Kubernetes cluster instead of starting the Llama Stack server locally. See [Kubernetes Deployment Guide](kubernetes_deployment) for more details.

View file

@ -0,0 +1,541 @@
# Detailed Tutorial
In this guide, we'll walk through how you can use the Llama Stack (server and client SDK) to test a simple agent.
A Llama Stack agent is a simple integrated system that can perform tasks by combining a Llama model for reasoning with
tools (e.g., RAG, web search, code execution, etc.) for taking actions.
In Llama Stack, we provide a server exposing multiple APIs. These APIs are backed by implementations from different providers.
Llama Stack is a stateful service with REST APIs to support seamless transition of AI applications across different environments. The server can be run in a variety of ways, including as a standalone binary, Docker container, or hosted service. You can build and test using a local server first and deploy to a hosted endpoint for production.
In this guide, we'll walk through how to build a RAG agent locally using Llama Stack with [Ollama](https://ollama.com/)
as the inference [provider](../providers/index.md#inference) for a Llama Model.
## Step 1: Installation and Setup
Install Ollama by following the instructions on the [Ollama website](https://ollama.com/download), then
download Llama 3.2 3B model, and then start the Ollama service.
```bash
ollama pull llama3.2:3b
ollama run llama3.2:3b --keepalive 60m
```
Install [uv](https://docs.astral.sh/uv/) to setup your virtual environment
::::{tab-set}
:::{tab-item} macOS and Linux
Use `curl` to download the script and execute it with `sh`:
```console
curl -LsSf https://astral.sh/uv/install.sh | sh
```
:::
:::{tab-item} Windows
Use `irm` to download the script and execute it with `iex`:
```console
powershell -ExecutionPolicy ByPass -c "irm https://astral.sh/uv/install.ps1 | iex"
```
:::
::::
Setup your virtual environment.
```bash
uv venv --python 3.10
source .venv/bin/activate
```
## Step 2: Run Llama Stack
Llama Stack is a server that exposes multiple APIs, you connect with it using the Llama Stack client SDK.
::::{tab-set}
:::{tab-item} Using `venv`
You can use Python to build and run the Llama Stack server, which is useful for testing and development.
Llama Stack uses a [YAML configuration file](../distributions/configuration.md) to specify the stack setup,
which defines the providers and their settings.
Now let's build and run the Llama Stack config for Ollama.
```bash
INFERENCE_MODEL=llama3.2:3b llama stack build --template ollama --image-type venv --run
```
:::
:::{tab-item} Using `conda`
You can use Python to build and run the Llama Stack server, which is useful for testing and development.
Llama Stack uses a [YAML configuration file](../distributions/configuration.md) to specify the stack setup,
which defines the providers and their settings.
Now let's build and run the Llama Stack config for Ollama.
```bash
INFERENCE_MODEL=llama3.2:3b llama stack build --template ollama --image-type conda --image-name llama3-3b-conda --run
```
:::
:::{tab-item} Using a Container
You can use a container image to run the Llama Stack server. We provide several container images for the server
component that works with different inference providers out of the box. For this guide, we will use
`llamastack/distribution-ollama` as the container image. If you'd like to build your own image or customize the
configurations, please check out [this guide](../references/index.md).
First lets setup some environment variables and create a local directory to mount into the containers file system.
```bash
export INFERENCE_MODEL="llama3.2:3b"
export LLAMA_STACK_PORT=8321
mkdir -p ~/.llama
```
Then start the server using the container tool of your choice. For example, if you are running Docker you can use the
following command:
```bash
docker run -it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ~/.llama:/root/.llama \
llamastack/distribution-ollama \
--port $LLAMA_STACK_PORT \
--env INFERENCE_MODEL=$INFERENCE_MODEL \
--env OLLAMA_URL=http://host.docker.internal:11434
```
Note to start the container with Podman, you can do the same but replace `docker` at the start of the command with
`podman`. If you are using `podman` older than `4.7.0`, please also replace `host.docker.internal` in the `OLLAMA_URL`
with `host.containers.internal`.
The configuration YAML for the Ollama distribution is available at `distributions/ollama/run.yaml`.
```{tip}
Docker containers run in their own isolated network namespaces on Linux. To allow the container to communicate with services running on the host via `localhost`, you need `--network=host`. This makes the container use the hosts network directly so it can connect to Ollama running on `localhost:11434`.
Linux users having issues running the above command should instead try the following:
```bash
docker run -it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ~/.llama:/root/.llama \
--network=host \
llamastack/distribution-ollama \
--port $LLAMA_STACK_PORT \
--env INFERENCE_MODEL=$INFERENCE_MODEL \
--env OLLAMA_URL=http://localhost:11434
```
:::
::::
You will see output like below:
```
INFO: Application startup complete.
INFO: Uvicorn running on http://['::', '0.0.0.0']:8321 (Press CTRL+C to quit)
```
Now you can use the Llama Stack client to run inference and build agents!
You can reuse the server setup or use the [Llama Stack Client](https://github.com/meta-llama/llama-stack-client-python/).
Note that the client package is already included in the `llama-stack` package.
## Step 3: Run Client CLI
Open a new terminal and navigate to the same directory you started the server from. Then set up a new or activate your
existing server virtual environment.
::::{tab-set}
:::{tab-item} Reuse Server `venv`
```bash
# The client is included in the llama-stack package so we just activate the server venv
source .venv/bin/activate
```
:::
:::{tab-item} Install with `venv`
```bash
uv venv client --python 3.10
source client/bin/activate
pip install llama-stack-client
```
:::
:::{tab-item} Install with `conda`
```bash
yes | conda create -n stack-client python=3.10
conda activate stack-client
pip install llama-stack-client
```
:::
::::
Now let's use the `llama-stack-client` [CLI](../references/llama_stack_client_cli_reference.md) to check the
connectivity to the server.
```bash
llama-stack-client configure --endpoint http://localhost:8321 --api-key none
```
You will see the below:
```
Done! You can now use the Llama Stack Client CLI with endpoint http://localhost:8321
```
List the models
```bash
llama-stack-client models list
Available Models
┏━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓
┃ model_type ┃ identifier ┃ provider_resource_id ┃ metadata ┃ provider_id ┃
┡━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩
│ embedding │ all-MiniLM-L6-v2 │ all-minilm:latest │ {'embedding_dimension': 384.0} │ ollama │
├─────────────────┼─────────────────────────────────────┼─────────────────────────────────────┼───────────────────────────────────────────┼─────────────────┤
│ llm │ llama3.2:3b │ llama3.2:3b │ │ ollama │
└─────────────────┴─────────────────────────────────────┴─────────────────────────────────────┴───────────────────────────────────────────┴─────────────────┘
Total models: 2
```
You can test basic Llama inference completion using the CLI.
```bash
llama-stack-client inference chat-completion --message "tell me a joke"
```
Sample output:
```python
ChatCompletionResponse(
completion_message=CompletionMessage(
content="Here's one:\n\nWhat do you call a fake noodle?\n\nAn impasta!",
role="assistant",
stop_reason="end_of_turn",
tool_calls=[],
),
logprobs=None,
metrics=[
Metric(metric="prompt_tokens", value=14.0, unit=None),
Metric(metric="completion_tokens", value=27.0, unit=None),
Metric(metric="total_tokens", value=41.0, unit=None),
],
)
```
## Step 4: Run the Demos
Note that these demos show the [Python Client SDK](../references/python_sdk_reference/index.md).
Other SDKs are also available, please refer to the [Client SDK](../index.md#client-sdks) list for the complete options.
::::{tab-set}
:::{tab-item} Basic Inference
Now you can run inference using the Llama Stack client SDK.
### i. Create the Script
Create a file `inference.py` and add the following code:
```python
from llama_stack_client import LlamaStackClient
client = LlamaStackClient(base_url="http://localhost:8321")
# List available models
models = client.models.list()
# Select the first LLM
llm = next(m for m in models if m.model_type == "llm")
model_id = llm.identifier
print("Model:", model_id)
response = client.inference.chat_completion(
model_id=model_id,
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Write a haiku about coding"},
],
)
print(response.completion_message.content)
```
### ii. Run the Script
Let's run the script using `uv`
```bash
uv run python inference.py
```
Which will output:
```
Model: llama3.2:3b
Here is a haiku about coding:
Lines of code unfold
Logic flows through digital night
Beauty in the bits
```
:::
:::{tab-item} Build a Simple Agent
Next we can move beyond simple inference and build an agent that can perform tasks using the Llama Stack server.
### i. Create the Script
Create a file `agent.py` and add the following code:
```python
from llama_stack_client import LlamaStackClient
from llama_stack_client import Agent, AgentEventLogger
from rich.pretty import pprint
import uuid
client = LlamaStackClient(base_url=f"http://localhost:8321")
models = client.models.list()
llm = next(m for m in models if m.model_type == "llm")
model_id = llm.identifier
agent = Agent(client, model=model_id, instructions="You are a helpful assistant.")
s_id = agent.create_session(session_name=f"s{uuid.uuid4().hex}")
print("Non-streaming ...")
response = agent.create_turn(
messages=[{"role": "user", "content": "Who are you?"}],
session_id=s_id,
stream=False,
)
print("agent>", response.output_message.content)
print("Streaming ...")
stream = agent.create_turn(
messages=[{"role": "user", "content": "Who are you?"}], session_id=s_id, stream=True
)
for event in stream:
pprint(event)
print("Streaming with print helper...")
stream = agent.create_turn(
messages=[{"role": "user", "content": "Who are you?"}], session_id=s_id, stream=True
)
for event in AgentEventLogger().log(stream):
event.print()
```
### ii. Run the Script
Let's run the script using `uv`
```bash
uv run python agent.py
```
```{dropdown} 👋 Click here to see the sample output
Non-streaming ...
agent> I'm an artificial intelligence designed to assist and communicate with users like you. I don't have a personal identity, but I'm here to provide information, answer questions, and help with tasks to the best of my abilities.
I can be used for a wide range of purposes, such as:
* Providing definitions and explanations
* Offering suggestions and ideas
* Helping with language translation
* Assisting with writing and proofreading
* Generating text or responses to questions
* Playing simple games or chatting about topics of interest
I'm constantly learning and improving my abilities, so feel free to ask me anything, and I'll do my best to help!
Streaming ...
AgentTurnResponseStreamChunk(
│ event=TurnResponseEvent(
│ │ payload=AgentTurnResponseStepStartPayload(
│ │ │ event_type='step_start',
│ │ │ step_id='69831607-fa75-424a-949b-e2049e3129d1',
│ │ │ step_type='inference',
│ │ │ metadata={}
│ │ )
│ )
)
AgentTurnResponseStreamChunk(
│ event=TurnResponseEvent(
│ │ payload=AgentTurnResponseStepProgressPayload(
│ │ │ delta=TextDelta(text='As', type='text'),
│ │ │ event_type='step_progress',
│ │ │ step_id='69831607-fa75-424a-949b-e2049e3129d1',
│ │ │ step_type='inference'
│ │ )
│ )
)
AgentTurnResponseStreamChunk(
│ event=TurnResponseEvent(
│ │ payload=AgentTurnResponseStepProgressPayload(
│ │ │ delta=TextDelta(text=' a', type='text'),
│ │ │ event_type='step_progress',
│ │ │ step_id='69831607-fa75-424a-949b-e2049e3129d1',
│ │ │ step_type='inference'
│ │ )
│ )
)
...
AgentTurnResponseStreamChunk(
│ event=TurnResponseEvent(
│ │ payload=AgentTurnResponseStepCompletePayload(
│ │ │ event_type='step_complete',
│ │ │ step_details=InferenceStep(
│ │ │ │ api_model_response=CompletionMessage(
│ │ │ │ │ content='As a conversational AI, I don\'t have a personal identity in the classical sense. I exist as a program running on computer servers, designed to process and respond to text-based inputs.\n\nI\'m an instance of a type of artificial intelligence called a "language model," which is trained on vast amounts of text data to generate human-like responses. My primary function is to understand and respond to natural language inputs, like our conversation right now.\n\nThink of me as a virtual assistant, a chatbot, or a conversational interface I\'m here to provide information, answer questions, and engage in conversation to the best of my abilities. I don\'t have feelings, emotions, or consciousness like humans do, but I\'m designed to simulate human-like interactions to make our conversations feel more natural and helpful.\n\nSo, that\'s me in a nutshell! What can I help you with today?',
│ │ │ │ │ role='assistant',
│ │ │ │ │ stop_reason='end_of_turn',
│ │ │ │ │ tool_calls=[]
│ │ │ │ ),
│ │ │ │ step_id='69831607-fa75-424a-949b-e2049e3129d1',
│ │ │ │ step_type='inference',
│ │ │ │ turn_id='8b360202-f7cb-4786-baa9-166a1b46e2ca',
│ │ │ │ completed_at=datetime.datetime(2025, 4, 3, 1, 15, 21, 716174, tzinfo=TzInfo(UTC)),
│ │ │ │ started_at=datetime.datetime(2025, 4, 3, 1, 15, 14, 28823, tzinfo=TzInfo(UTC))
│ │ │ ),
│ │ │ step_id='69831607-fa75-424a-949b-e2049e3129d1',
│ │ │ step_type='inference'
│ │ )
│ )
)
AgentTurnResponseStreamChunk(
│ event=TurnResponseEvent(
│ │ payload=AgentTurnResponseTurnCompletePayload(
│ │ │ event_type='turn_complete',
│ │ │ turn=Turn(
│ │ │ │ input_messages=[UserMessage(content='Who are you?', role='user', context=None)],
│ │ │ │ output_message=CompletionMessage(
│ │ │ │ │ content='As a conversational AI, I don\'t have a personal identity in the classical sense. I exist as a program running on computer servers, designed to process and respond to text-based inputs.\n\nI\'m an instance of a type of artificial intelligence called a "language model," which is trained on vast amounts of text data to generate human-like responses. My primary function is to understand and respond to natural language inputs, like our conversation right now.\n\nThink of me as a virtual assistant, a chatbot, or a conversational interface I\'m here to provide information, answer questions, and engage in conversation to the best of my abilities. I don\'t have feelings, emotions, or consciousness like humans do, but I\'m designed to simulate human-like interactions to make our conversations feel more natural and helpful.\n\nSo, that\'s me in a nutshell! What can I help you with today?',
│ │ │ │ │ role='assistant',
│ │ │ │ │ stop_reason='end_of_turn',
│ │ │ │ │ tool_calls=[]
│ │ │ │ ),
│ │ │ │ session_id='abd4afea-4324-43f4-9513-cfe3970d92e8',
│ │ │ │ started_at=datetime.datetime(2025, 4, 3, 1, 15, 14, 28722, tzinfo=TzInfo(UTC)),
│ │ │ │ steps=[
│ │ │ │ │ InferenceStep(
│ │ │ │ │ │ api_model_response=CompletionMessage(
│ │ │ │ │ │ │ content='As a conversational AI, I don\'t have a personal identity in the classical sense. I exist as a program running on computer servers, designed to process and respond to text-based inputs.\n\nI\'m an instance of a type of artificial intelligence called a "language model," which is trained on vast amounts of text data to generate human-like responses. My primary function is to understand and respond to natural language inputs, like our conversation right now.\n\nThink of me as a virtual assistant, a chatbot, or a conversational interface I\'m here to provide information, answer questions, and engage in conversation to the best of my abilities. I don\'t have feelings, emotions, or consciousness like humans do, but I\'m designed to simulate human-like interactions to make our conversations feel more natural and helpful.\n\nSo, that\'s me in a nutshell! What can I help you with today?',
│ │ │ │ │ │ │ role='assistant',
│ │ │ │ │ │ │ stop_reason='end_of_turn',
│ │ │ │ │ │ │ tool_calls=[]
│ │ │ │ │ │ ),
│ │ │ │ │ │ step_id='69831607-fa75-424a-949b-e2049e3129d1',
│ │ │ │ │ │ step_type='inference',
│ │ │ │ │ │ turn_id='8b360202-f7cb-4786-baa9-166a1b46e2ca',
│ │ │ │ │ │ completed_at=datetime.datetime(2025, 4, 3, 1, 15, 21, 716174, tzinfo=TzInfo(UTC)),
│ │ │ │ │ │ started_at=datetime.datetime(2025, 4, 3, 1, 15, 14, 28823, tzinfo=TzInfo(UTC))
│ │ │ │ │ )
│ │ │ │ ],
│ │ │ │ turn_id='8b360202-f7cb-4786-baa9-166a1b46e2ca',
│ │ │ │ completed_at=datetime.datetime(2025, 4, 3, 1, 15, 21, 727364, tzinfo=TzInfo(UTC)),
│ │ │ │ output_attachments=[]
│ │ │ )
│ │ )
│ )
)
Streaming with print helper...
inference> Déjà vu!
As I mentioned earlier, I'm an artificial intelligence language model. I don't have a personal identity or consciousness like humans do. I exist solely to process and respond to text-based inputs, providing information and assistance on a wide range of topics.
I'm a computer program designed to simulate human-like conversations, using natural language processing (NLP) and machine learning algorithms to understand and generate responses. My purpose is to help users like you with their questions, provide information, and engage in conversation.
Think of me as a virtual companion, a helpful tool designed to make your interactions more efficient and enjoyable. I don't have personal opinions, emotions, or biases, but I'm here to provide accurate and informative responses to the best of my abilities.
So, who am I? I'm just a computer program designed to help you!
```
:::
:::{tab-item} Build a RAG Agent
For our last demo, we can build a RAG agent that can answer questions about the Torchtune project using the documents
in a vector database.
### i. Create the Script
Create a file `rag_agent.py` and add the following code:
```python
from llama_stack_client import LlamaStackClient
from llama_stack_client import Agent, AgentEventLogger
from llama_stack_client.types import Document
import uuid
from termcolor import cprint
client = LlamaStackClient(base_url="http://localhost:8321")
# Create a vector database instance
embed_lm = next(m for m in client.models.list() if m.model_type == "embedding")
embedding_model = embed_lm.identifier
vector_db_id = f"v{uuid.uuid4().hex}"
client.vector_dbs.register(
vector_db_id=vector_db_id,
embedding_model=embedding_model,
)
# Create Documents
urls = [
"memory_optimizations.rst",
"chat.rst",
"llama3.rst",
"datasets.rst",
"qat_finetune.rst",
"lora_finetune.rst",
]
documents = [
Document(
document_id=f"num-{i}",
content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
mime_type="text/plain",
metadata={},
)
for i, url in enumerate(urls)
]
# Insert documents
client.tool_runtime.rag_tool.insert(
documents=documents,
vector_db_id=vector_db_id,
chunk_size_in_tokens=512,
)
# Get the model being served
llm = next(m for m in client.models.list() if m.model_type == "llm")
model = llm.identifier
# Create the RAG agent
rag_agent = Agent(
client,
model=model,
instructions="You are a helpful assistant. Use the RAG tool to answer questions as needed.",
tools=[
{
"name": "builtin::rag/knowledge_search",
"args": {"vector_db_ids": [vector_db_id]},
}
],
)
session_id = rag_agent.create_session(session_name=f"s{uuid.uuid4().hex}")
turns = ["what is torchtune", "tell me about dora"]
for t in turns:
print("user>", t)
stream = rag_agent.create_turn(
messages=[{"role": "user", "content": t}], session_id=session_id, stream=True
)
for event in AgentEventLogger().log(stream):
event.print()
```
### ii. Run the Script
Let's run the script using `uv`
```bash
uv run python rag_agent.py
```
```{dropdown} 👋 Click here to see the sample output
user> what is torchtune
inference> [knowledge_search(query='TorchTune')]
tool_execution> Tool:knowledge_search Args:{'query': 'TorchTune'}
tool_execution> Tool:knowledge_search Response:[TextContentItem(text='knowledge_search tool found 5 chunks:\nBEGIN of knowledge_search tool results.\n', type='text'), TextContentItem(text='Result 1:\nDocument_id:num-1\nContent: conversational data, :func:`~torchtune.datasets.chat_dataset` seems to be a good fit. ..., type='text'), TextContentItem(text='END of knowledge_search tool results.\n', type='text')]
inference> Here is a high-level overview of the text:
**LoRA Finetuning with PyTorch Tune**
PyTorch Tune provides a recipe for LoRA (Low-Rank Adaptation) finetuning, which is a technique to adapt pre-trained models to new tasks. The recipe uses the `lora_finetune_distributed` command.
...
Overall, DORA is a powerful reinforcement learning algorithm that can learn complex tasks from human demonstrations. However, it requires careful consideration of the challenges and limitations to achieve optimal results.
```
:::
::::
**You're Ready to Build Your Own Apps!**
Congrats! 🥳 Now you're ready to [build your own Llama Stack applications](../building_applications/index)! 🚀

View file

@ -1,304 +1,121 @@
# Quick Start
# Quickstart
In this guide, we'll walk through how you can use the Llama Stack (server and client SDK) to build a simple [RAG (Retrieval Augmented Generation)](../building_applications/rag.md) agent.
Get started with Llama Stack in minutes!
A Llama Stack agent is a simple integrated system that can perform tasks by combining a Llama model for reasoning with tools (e.g., RAG, web search, code execution, etc.) for taking actions.
Llama Stack is a stateful service with REST APIs to support the seamless transition of AI applications across different
environments. You can build and test using a local server first and deploy to a hosted endpoint for production.
In Llama Stack, we provide a server exposing multiple APIs. These APIs are backed by implementations from different providers. For this guide, we will use [Ollama](https://ollama.com/) as the inference provider.
Ollama is an LLM runtime that allows you to run Llama models locally.
### 1. Start Ollama
In this guide, we'll walk through how to build a RAG application locally using Llama Stack with [Ollama](https://ollama.com/)
as the inference [provider](../providers/index.md#inference) for a Llama Model.
#### Step 1: Install and setup
1. Install [uv](https://docs.astral.sh/uv/)
2. Run inference on a Llama model with [Ollama](https://ollama.com/download)
```bash
ollama run llama3.2:3b-instruct-fp16 --keepalive 60m
ollama run llama3.2:3b --keepalive 60m
```
By default, Ollama keeps the model loaded in memory for 5 minutes which can be too short. We set the `--keepalive` flag to 60 minutes to ensure the model remains loaded for sometime.
```{admonition} Note
:class: tip
If you do not have ollama, you can install it from [here](https://ollama.com/download).
```
### 2. Pick a client environment
Llama Stack has a service-oriented architecture, so every interaction with the Stack happens through a REST interface. You can interact with the Stack in two ways:
* Install the `llama-stack-client` PyPI package and point `LlamaStackClient` to a local or remote Llama Stack server.
* Or, install the `llama-stack` PyPI package and use the Stack as a library using `LlamaStackAsLibraryClient`.
```{admonition} Note
:class: tip
The API is **exactly identical** for both clients.
```
:::{dropdown} Starting up the Llama Stack server
The Llama Stack server can be configured flexibly so you can mix-and-match various providers for its individual API components -- beyond Inference, these include Vector IO, Agents, Telemetry, Evals, Post Training, etc.
To get started quickly, we provide various container images for the server component that work with different inference providers out of the box. For this guide, we will use `llamastack/distribution-ollama` as the container image. If you'd like to build your own image or customize the configurations, please check out [this guide](../references/index.md).
Lets setup some environment variables that we will use in the rest of the guide.
#### Step 2: Run the Llama Stack server
We will use `uv` to run the Llama Stack server.
```bash
export INFERENCE_MODEL="meta-llama/Llama-3.2-3B-Instruct"
export LLAMA_STACK_PORT=8321
INFERENCE_MODEL=llama3.2:3b uv run --with llama-stack llama stack build --template ollama --image-type venv --run
```
#### Step 3: Run the demo
Now open up a new terminal and copy the following script into a file named `demo_script.py`.
Next you can create a local directory to mount into the containers file system.
```bash
mkdir -p ~/.llama
```
Then you can start the server using the container tool of your choice. For example, if you are running Docker you can use the following command:
```bash
docker run -it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ~/.llama:/root/.llama \
llamastack/distribution-ollama \
--port $LLAMA_STACK_PORT \
--env INFERENCE_MODEL=$INFERENCE_MODEL \
--env OLLAMA_URL=http://host.docker.internal:11434
```
As another example, to start the container with Podman, you can do the same but replace `docker` at the start of the command with `podman`. If you are using `podman` older than `4.7.0`, please also replace `host.docker.internal` in the `OLLAMA_URL` with `host.containers.internal`.
Configuration for this is available at `distributions/ollama/run.yaml`.
```{admonition} Note
:class: note
Docker containers run in their own isolated network namespaces on Linux. To allow the container to communicate with services running on the host via `localhost`, you need `--network=host`. This makes the container use the hosts network directly so it can connect to Ollama running on `localhost:11434`.
Linux users having issues running the above command should instead try the following:
```bash
docker run -it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ~/.llama:/root/.llama \
--network=host \
llamastack/distribution-ollama \
--port $LLAMA_STACK_PORT \
--env INFERENCE_MODEL=$INFERENCE_MODEL \
--env OLLAMA_URL=http://localhost:11434
```
:::
:::{dropdown} Installing the Llama Stack client CLI and SDK
You can interact with the Llama Stack server using various client SDKs. Note that you must be using Python 3.10 or newer. We will use the Python SDK which you can install via `conda` or `virtualenv`.
For `conda`:
```bash
yes | conda create -n stack-client python=3.10
conda activate stack-client
pip install llama-stack-client
```
For `virtualenv`:
```bash
python -m venv stack-client
source stack-client/bin/activate
pip install llama-stack-client
```
Let's use the `llama-stack-client` CLI to check the connectivity to the server.
```bash
$ llama-stack-client configure --endpoint http://localhost:$LLAMA_STACK_PORT
> Enter the API key (leave empty if no key is needed):
Done! You can now use the Llama Stack Client CLI with endpoint http://localhost:8321
$ llama-stack-client models list
Available Models
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━┓
┃ model_type ┃ identifier ┃ provider_resource_id ┃ metadata ┃ provider_id ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━┩
│ llm │ meta-llama/Llama-3.2-3B-Instruct │ llama3.2:3b-instruct-fp16 │ │ ollama │
└──────────────┴──────────────────────────────────────┴──────────────────────────────┴───────────┴─────────────┘
Total models: 1
```
You can test basic Llama inference completion using the CLI too.
```bash
llama-stack-client \
inference chat-completion \
--message "hello, what model are you?"
```
:::
&nbsp;
### 3. Run inference with Python SDK
Here is a simple example to perform chat completions using the SDK.
```python
import os
import sys
from llama_stack_client import Agent, AgentEventLogger, RAGDocument, LlamaStackClient
vector_db_id = "my_demo_vector_db"
client = LlamaStackClient(base_url="http://localhost:8321")
def create_http_client():
from llama_stack_client import LlamaStackClient
return LlamaStackClient(
base_url=f"http://localhost:{os.environ['LLAMA_STACK_PORT']}"
)
def create_library_client(template="ollama"):
from llama_stack import LlamaStackAsLibraryClient
client = LlamaStackAsLibraryClient(template)
if not client.initialize():
print("llama stack not built properly")
sys.exit(1)
return client
client = (
create_library_client()
) # or create_http_client() depending on the environment you picked
# List available models
models = client.models.list()
print("--- Available models: ---")
for m in models:
print(f"- {m.identifier}")
print()
response = client.inference.chat_completion(
model_id=os.environ["INFERENCE_MODEL"],
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Write a haiku about coding"},
],
# Select the first LLM and first embedding models
model_id = next(m for m in models if m.model_type == "llm").identifier
embedding_model_id = (
em := next(m for m in models if m.model_type == "embedding")
).identifier
embedding_dimension = em.metadata["embedding_dimension"]
_ = client.vector_dbs.register(
vector_db_id=vector_db_id,
embedding_model=embedding_model_id,
embedding_dimension=embedding_dimension,
provider_id="faiss",
)
print(response.completion_message.content)
```
To run the above example, put the code in a file called `inference.py`, ensure your `conda` or `virtualenv` environment is active, and run the following:
```bash
pip install llama_stack
llama stack build --template ollama --image-type <conda|venv>
python inference.py
```
### 4. Your first RAG agent
Here is an example of a simple RAG (Retrieval Augmented Generation) chatbot agent which can answer questions about TorchTune documentation.
```python
import os
import uuid
from termcolor import cprint
from llama_stack_client import Agent, AgentEventLogger, RAGDocument
def create_http_client():
from llama_stack_client import LlamaStackClient
return LlamaStackClient(
base_url=f"http://localhost:{os.environ['LLAMA_STACK_PORT']}"
)
def create_library_client(template="ollama"):
from llama_stack import LlamaStackAsLibraryClient
client = LlamaStackAsLibraryClient(template)
client.initialize()
return client
client = (
create_library_client()
) # or create_http_client() depending on the environment you picked
# Documents to be used for RAG
urls = ["chat.rst", "llama3.rst", "memory_optimizations.rst", "lora_finetune.rst"]
documents = [
RAGDocument(
document_id=f"num-{i}",
content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
mime_type="text/plain",
source = "https://www.paulgraham.com/greatwork.html"
print("rag_tool> Ingesting document:", source)
document = RAGDocument(
document_id="document_1",
content=source,
mime_type="text/html",
metadata={},
)
for i, url in enumerate(urls)
]
vector_providers = [
provider for provider in client.providers.list() if provider.api == "vector_io"
]
provider_id = vector_providers[0].provider_id # Use the first available vector provider
# Register a vector database
vector_db_id = f"test-vector-db-{uuid.uuid4().hex}"
client.vector_dbs.register(
vector_db_id=vector_db_id,
provider_id=provider_id,
embedding_model="all-MiniLM-L6-v2",
embedding_dimension=384,
)
# Insert the documents into the vector database
client.tool_runtime.rag_tool.insert(
documents=documents,
documents=[document],
vector_db_id=vector_db_id,
chunk_size_in_tokens=512,
chunk_size_in_tokens=50,
)
rag_agent = Agent(
agent = Agent(
client,
model=os.environ["INFERENCE_MODEL"],
# Define instructions for the agent ( aka system prompt)
model=model_id,
instructions="You are a helpful assistant",
enable_session_persistence=False,
# Define tools available to the agent
tools=[
{
"name": "builtin::rag/knowledge_search",
"args": {
"vector_db_ids": [vector_db_id],
},
"args": {"vector_db_ids": [vector_db_id]},
}
],
)
session_id = rag_agent.create_session("test-session")
user_prompts = [
"How to optimize memory usage in torchtune? use the knowledge_search tool to get information.",
]
prompt = "How do you do great work?"
print("prompt>", prompt)
# Run the agent loop by calling the `create_turn` method
for prompt in user_prompts:
cprint(f"User> {prompt}", "green")
response = rag_agent.create_turn(
response = agent.create_turn(
messages=[{"role": "user", "content": prompt}],
session_id=session_id,
)
for log in AgentEventLogger().log(response):
session_id=agent.create_session("rag_session"),
stream=True,
)
for log in AgentEventLogger().log(response):
log.print()
```
To run the above example, put the code in a file called `rag.py`, ensure your `conda` or `virtualenv` environment is active, and run the following:
```bash
pip install llama_stack
llama stack build --template ollama --image-type <conda|venv>
python rag.py
We will use `uv` to run the script
```
uv run --with llama-stack-client demo_script.py
```
And you should see output like below.
```
rag_tool> Ingesting document: https://www.paulgraham.com/greatwork.html
prompt> How do you do great work?
inference> [knowledge_search(query="What is the key to doing great work")]
tool_execution> Tool:knowledge_search Args:{'query': 'What is the key to doing great work'}
tool_execution> Tool:knowledge_search Response:[TextContentItem(text='knowledge_search tool found 5 chunks:\nBEGIN of knowledge_search tool results.\n', type='text'), TextContentItem(text="Result 1:\nDocument_id:docum\nContent: work. Doing great work means doing something important\nso well that you expand people's ideas of what's possible. But\nthere's no threshold for importance. It's a matter of degree, and\noften hard to judge at the time anyway.\n", type='text'), TextContentItem(text="Result 2:\nDocument_id:docum\nContent: work. Doing great work means doing something important\nso well that you expand people's ideas of what's possible. But\nthere's no threshold for importance. It's a matter of degree, and\noften hard to judge at the time anyway.\n", type='text'), TextContentItem(text="Result 3:\nDocument_id:docum\nContent: work. Doing great work means doing something important\nso well that you expand people's ideas of what's possible. But\nthere's no threshold for importance. It's a matter of degree, and\noften hard to judge at the time anyway.\n", type='text'), TextContentItem(text="Result 4:\nDocument_id:docum\nContent: work. Doing great work means doing something important\nso well that you expand people's ideas of what's possible. But\nthere's no threshold for importance. It's a matter of degree, and\noften hard to judge at the time anyway.\n", type='text'), TextContentItem(text="Result 5:\nDocument_id:docum\nContent: work. Doing great work means doing something important\nso well that you expand people's ideas of what's possible. But\nthere's no threshold for importance. It's a matter of degree, and\noften hard to judge at the time anyway.\n", type='text'), TextContentItem(text='END of knowledge_search tool results.\n', type='text')]
inference> Based on the search results, it seems that doing great work means doing something important so well that you expand people's ideas of what's possible. However, there is no clear threshold for importance, and it can be difficult to judge at the time.
To further clarify, I would suggest that doing great work involves:
* Completing tasks with high quality and attention to detail
* Expanding on existing knowledge or ideas
* Making a positive impact on others through your work
* Striving for excellence and continuous improvement
Ultimately, great work is about making a meaningful contribution and leaving a lasting impression.
```
Congratulations! You've successfully built your first RAG application using Llama Stack! 🎉🥳
## Next Steps
- Learn more about Llama Stack [Concepts](../concepts/index.md)
- Learn how to [Build Llama Stacks](../distributions/index.md)
- See [References](../references/index.md) for more details about the llama CLI and Python SDK
- For example applications and more detailed tutorials, visit our [llama-stack-apps](https://github.com/meta-llama/llama-stack-apps/tree/main/examples) repository.
Now you're ready to dive deeper into Llama Stack!
- Explore the [Detailed Tutorial](./detailed_tutorial.md).
- Try the [Getting Started Notebook](https://github.com/meta-llama/llama-stack/blob/main/docs/getting_started.ipynb).
- Browse more [Notebooks on GitHub](https://github.com/meta-llama/llama-stack/tree/main/docs/notebooks).
- Learn about Llama Stack [Concepts](../concepts/index.md).
- Discover how to [Build Llama Stacks](../distributions/index.md).
- Refer to our [References](../references/index.md) for details on the Llama CLI and Python SDK.
- Check out the [llama-stack-apps](https://github.com/meta-llama/llama-stack-apps/tree/main/examples) repository for example applications and tutorials.

View file

@ -1,10 +1,16 @@
# Llama Stack
Welcome to Llama Stack, the open-source framework for building generative AI applications.
```{admonition} Llama 4 is here!
:class: tip
Check out [Getting Started with Llama 4](https://colab.research.google.com/github/meta-llama/llama-stack/blob/main/docs/getting_started_llama4.ipynb)
```
```{admonition} News
:class: tip
Llama Stack {{ llama_stack_version }} is now available! See the {{ llama_stack_version_link }} for more details.
```
# Llama Stack
## What is Llama Stack?
@ -24,19 +30,17 @@ Llama Stack defines and standardizes the core building blocks needed to bring ge
Our goal is to provide pre-packaged implementations (aka "distributions") which can be run in a variety of deployment environments. LlamaStack can assist you in your entire app development lifecycle - start iterating on local, mobile or desktop and seamlessly transition to on-prem or public cloud deployments. At every point in this transition, the same set of APIs and the same developer experience is available.
## How does Llama Stack work?
Llama Stack consists of a [server](./distributions/index.md) (with multiple pluggable API [providers](./providers/index.md)) and [client SDKs](#available-sdks) meant to
Llama Stack consists of a [server](./distributions/index.md) (with multiple pluggable API [providers](./providers/index.md)) and Client SDKs (see below) meant to
be used in your applications. The server can be run in a variety of environments, including local (inline)
development, on-premises, and cloud. The client SDKs are available for Python, Swift, Node, and
Kotlin.
## Quick Links
- New to Llama Stack? Start with the [Introduction](introduction/index) to understand our motivation and vision.
- Ready to build? Check out the [Quick Start](getting_started/index) to get started.
- Need specific providers? Browse [Distributions](distributions/selection) to see all the options available.
- Want to contribute? See the [Contributing](contributing/index) guide.
## Available SDKs
## Client SDKs
We have a number of client-side SDKs available for different languages.
@ -95,8 +99,9 @@ A number of "adapters" are available for some popular Inference and Vector Store
:maxdepth: 3
self
introduction/index
getting_started/index
getting_started/detailed_tutorial
introduction/index
concepts/index
providers/index
distributions/index

View file

@ -103,7 +103,5 @@ llama stack run together
2. Start Streamlit UI
```bash
cd llama_stack/distribution/ui
pip install -r requirements.txt
streamlit run app.py
uv run --with ".[ui]" streamlit run llama_stack/distribution/ui/app.py
```

View file

@ -0,0 +1,234 @@
# External Providers
Llama Stack supports external providers that live outside of the main codebase. This allows you to:
- Create and maintain your own providers independently
- Share providers with others without contributing to the main codebase
- Keep provider-specific code separate from the core Llama Stack code
## Configuration
To enable external providers, you need to configure the `external_providers_dir` in your Llama Stack configuration. This directory should contain your external provider specifications:
```yaml
external_providers_dir: /etc/llama-stack/providers.d/
```
## Directory Structure
The external providers directory should follow this structure:
```
providers.d/
remote/
inference/
custom_ollama.yaml
vllm.yaml
vector_io/
qdrant.yaml
safety/
llama-guard.yaml
inline/
inference/
custom_ollama.yaml
vllm.yaml
vector_io/
qdrant.yaml
safety/
llama-guard.yaml
```
Each YAML file in these directories defines a provider specification for that particular API.
## Provider Types
Llama Stack supports two types of external providers:
1. **Remote Providers**: Providers that communicate with external services (e.g., cloud APIs)
2. **Inline Providers**: Providers that run locally within the Llama Stack process
## Known External Providers
Here's a list of known external providers that you can use with Llama Stack:
| Type | Name | Description | Repository |
|------|------|-------------|------------|
| Remote | KubeFlow Training | Train models with KubeFlow | [llama-stack-provider-kft](https://github.com/opendatahub-io/llama-stack-provider-kft) |
### Remote Provider Specification
Remote providers are used when you need to communicate with external services. Here's an example for a custom Ollama provider:
```yaml
adapter:
adapter_type: custom_ollama
pip_packages:
- ollama
- aiohttp
config_class: llama_stack_ollama_provider.config.OllamaImplConfig
module: llama_stack_ollama_provider
api_dependencies: []
optional_api_dependencies: []
```
#### Adapter Configuration
The `adapter` section defines how to load and configure the provider:
- `adapter_type`: A unique identifier for this adapter
- `pip_packages`: List of Python packages required by the provider
- `config_class`: The full path to the configuration class
- `module`: The Python module containing the provider implementation
### Inline Provider Specification
Inline providers run locally within the Llama Stack process. Here's an example for a custom vector store provider:
```yaml
module: llama_stack_vector_provider
config_class: llama_stack_vector_provider.config.VectorStoreConfig
pip_packages:
- faiss-cpu
- numpy
api_dependencies:
- inference
optional_api_dependencies:
- vector_io
provider_data_validator: llama_stack_vector_provider.validator.VectorStoreValidator
container_image: custom-vector-store:latest # optional
```
#### Inline Provider Fields
- `module`: The Python module containing the provider implementation
- `config_class`: The full path to the configuration class
- `pip_packages`: List of Python packages required by the provider
- `api_dependencies`: List of Llama Stack APIs that this provider depends on
- `optional_api_dependencies`: List of optional Llama Stack APIs that this provider can use
- `provider_data_validator`: Optional validator for provider data
- `container_image`: Optional container image to use instead of pip packages
## Required Implementation
### Remote Providers
Remote providers must expose a `get_adapter_impl()` function in their module that takes two arguments:
1. `config`: An instance of the provider's config class
2. `deps`: A dictionary of API dependencies
This function must return an instance of the provider's adapter class that implements the required protocol for the API.
Example:
```python
async def get_adapter_impl(
config: OllamaImplConfig, deps: Dict[Api, Any]
) -> OllamaInferenceAdapter:
return OllamaInferenceAdapter(config)
```
### Inline Providers
Inline providers must expose a `get_provider_impl()` function in their module that takes two arguments:
1. `config`: An instance of the provider's config class
2. `deps`: A dictionary of API dependencies
Example:
```python
async def get_provider_impl(
config: VectorStoreConfig, deps: Dict[Api, Any]
) -> VectorStoreImpl:
impl = VectorStoreImpl(config, deps[Api.inference])
await impl.initialize()
return impl
```
## Dependencies
The provider package must be installed on the system. For example:
```bash
$ uv pip show llama-stack-ollama-provider
Name: llama-stack-ollama-provider
Version: 0.1.0
Location: /path/to/venv/lib/python3.10/site-packages
```
## Example: Custom Ollama Provider
Here's a complete example of creating and using a custom Ollama provider:
1. First, create the provider package:
```bash
mkdir -p llama-stack-provider-ollama
cd llama-stack-provider-ollama
git init
uv init
```
2. Edit `pyproject.toml`:
```toml
[project]
name = "llama-stack-provider-ollama"
version = "0.1.0"
description = "Ollama provider for Llama Stack"
requires-python = ">=3.10"
dependencies = ["llama-stack", "pydantic", "ollama", "aiohttp"]
```
3. Create the provider specification:
```yaml
# /etc/llama-stack/providers.d/remote/inference/custom_ollama.yaml
adapter:
adapter_type: custom_ollama
pip_packages: ["ollama", "aiohttp"]
config_class: llama_stack_provider_ollama.config.OllamaImplConfig
module: llama_stack_provider_ollama
api_dependencies: []
optional_api_dependencies: []
```
4. Install the provider:
```bash
uv pip install -e .
```
5. Configure Llama Stack to use external providers:
```yaml
external_providers_dir: /etc/llama-stack/providers.d/
```
The provider will now be available in Llama Stack with the type `remote::custom_ollama`.
## Best Practices
1. **Package Naming**: Use the prefix `llama-stack-provider-` for your provider packages to make them easily identifiable.
2. **Version Management**: Keep your provider package versioned and compatible with the Llama Stack version you're using.
3. **Dependencies**: Only include the minimum required dependencies in your provider package.
4. **Documentation**: Include clear documentation in your provider package about:
- Installation requirements
- Configuration options
- Usage examples
- Any limitations or known issues
5. **Testing**: Include tests in your provider package to ensure it works correctly with Llama Stack.
You can refer to the [integration tests
guide](https://github.com/meta-llama/llama-stack/blob/main/tests/integration/README.md) for more
information. Execute the test for the Provider type you are developing.
## Troubleshooting
If your external provider isn't being loaded:
1. Check that the `external_providers_dir` path is correct and accessible.
2. Verify that the YAML files are properly formatted.
3. Ensure all required Python packages are installed.
4. Check the Llama Stack server logs for any error messages - turn on debug logging to get more
information using `LLAMA_STACK_LOGGING=all=debug`.
5. Verify that the provider package is installed in your Python environment.

View file

@ -1,8 +1,8 @@
# Providers Overview
The goal of Llama Stack is to build an ecosystem where users can easily swap out different implementations for the same API. Examples for these include:
- LLM inference providers (e.g., Fireworks, Together, AWS Bedrock, Groq, Cerebras, SambaNova, vLLM, etc.),
- Vector databases (e.g., ChromaDB, Weaviate, Qdrant, Milvus, FAISS, PGVector, etc.),
- LLM inference providers (e.g., Ollama, Fireworks, Together, AWS Bedrock, Groq, Cerebras, SambaNova, vLLM, etc.),
- Vector databases (e.g., ChromaDB, Weaviate, Qdrant, Milvus, FAISS, PGVector, SQLite-Vec, etc.),
- Safety providers (e.g., Meta's Llama Guard, AWS Bedrock Guardrails, etc.)
Providers come in two flavors:
@ -11,6 +11,10 @@ Providers come in two flavors:
Importantly, Llama Stack always strives to provide at least one fully inline provider for each API so you can iterate on a fully featured environment locally.
## External Providers
Llama Stack supports external providers that live outside of the main codebase. This allows you to create and maintain your own providers independently. See the [External Providers Guide](external) for details.
## Agents
Run multi-step agentic workflows with LLMs with tool usage, memory (RAG), etc.
@ -50,6 +54,7 @@ The following providers (i.e., databases) are available for Vector IO:
```{toctree}
:maxdepth: 1
external
vector_io/faiss
vector_io/sqlite-vec
vector_io/chromadb

View file

@ -6,11 +6,8 @@
from typing import List, Optional, Protocol, runtime_checkable
from pydantic import BaseModel
from llama_stack.apis.common.job_types import Job
from llama_stack.apis.inference import (
ChatCompletionResponse,
CompletionResponse,
InterleavedContent,
LogProbConfig,
Message,
@ -20,41 +17,39 @@ from llama_stack.apis.inference import (
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.schema_utils import json_schema_type, webmethod
@json_schema_type
class BatchCompletionResponse(BaseModel):
batch: List[CompletionResponse]
@json_schema_type
class BatchChatCompletionResponse(BaseModel):
batch: List[ChatCompletionResponse]
from llama_stack.schema_utils import webmethod
@runtime_checkable
class BatchInference(Protocol):
"""Batch inference API for generating completions and chat completions.
This is an asynchronous API. If the request is successful, the response will be a job which can be polled for completion.
NOTE: This API is not yet implemented and is subject to change in concert with other asynchronous APIs
including (post-training, evals, etc).
"""
@webmethod(route="/batch-inference/completion", method="POST")
async def batch_completion(
async def completion(
self,
model: str,
content_batch: List[InterleavedContent],
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
) -> BatchCompletionResponse: ...
) -> Job: ...
@webmethod(route="/batch-inference/chat-completion", method="POST")
async def batch_chat_completion(
async def chat_completion(
self,
model: str,
messages_batch: List[List[Message]],
sampling_params: Optional[SamplingParams] = None,
# zero-shot tool definitions as input to the model
tools: Optional[List[ToolDefinition]] = list,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
) -> BatchChatCompletionResponse: ...
) -> Job: ...

View file

@ -25,15 +25,64 @@ from llama_stack.apis.models import Model
from llama_stack.apis.telemetry.telemetry import MetricResponseMixin
from llama_stack.models.llama.datatypes import (
BuiltinTool,
SamplingParams,
StopReason,
ToolCall,
ToolDefinition,
ToolParamDefinition,
ToolPromptFormat,
)
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
register_schema(ToolCall)
register_schema(ToolParamDefinition)
register_schema(ToolDefinition)
@json_schema_type
class GreedySamplingStrategy(BaseModel):
type: Literal["greedy"] = "greedy"
@json_schema_type
class TopPSamplingStrategy(BaseModel):
type: Literal["top_p"] = "top_p"
temperature: Optional[float] = Field(..., gt=0.0)
top_p: Optional[float] = 0.95
@json_schema_type
class TopKSamplingStrategy(BaseModel):
type: Literal["top_k"] = "top_k"
top_k: int = Field(..., ge=1)
SamplingStrategy = Annotated[
Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy],
Field(discriminator="type"),
]
register_schema(SamplingStrategy, name="SamplingStrategy")
@json_schema_type
class SamplingParams(BaseModel):
"""Sampling parameters.
:param strategy: The sampling strategy.
:param max_tokens: The maximum number of tokens that can be generated in the completion. The token count of
your prompt plus max_tokens cannot exceed the model's context length.
:param repetition_penalty: Number between -2.0 and 2.0. Positive values penalize new tokens
based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.
:param stop: Up to 4 sequences where the API will stop generating further tokens.
The returned text will not contain the stop sequence.
"""
strategy: SamplingStrategy = Field(default_factory=GreedySamplingStrategy)
max_tokens: Optional[int] = 0
repetition_penalty: Optional[float] = 1.0
stop: Optional[List[str]] = None
class LogProbConfig(BaseModel):
"""
@ -48,18 +97,18 @@ class QuantizationType(Enum):
"""Type of model quantization to run inference with.
:cvar bf16: BFloat16 typically this means _no_ quantization
:cvar fp8: 8-bit floating point quantization
:cvar int4: 4-bit integer quantization
:cvar fp8_mixed: 8-bit floating point quantization with mixed precision
:cvar int4_mixed: 4-bit integer quantization with mixed precision
"""
bf16 = "bf16"
fp8 = "fp8"
int4 = "int4"
fp8_mixed = "fp8_mixed"
int4_mixed = "int4_mixed"
@json_schema_type
class Fp8QuantizationConfig(BaseModel):
type: Literal["fp8"] = "fp8"
type: Literal["fp8_mixed"] = "fp8_mixed"
@json_schema_type
@ -75,7 +124,7 @@ class Int4QuantizationConfig(BaseModel):
:param scheme: Quantization scheme to use. Defaults to "int4_weight_int8_dynamic_activation"
"""
type: Literal["int4"] = "int4"
type: Literal["int4_mixed"] = "int4_mixed"
scheme: Optional[str] = "int4_weight_int8_dynamic_activation"
@ -393,6 +442,217 @@ class EmbeddingsResponse(BaseModel):
embeddings: List[List[float]]
@json_schema_type
class OpenAIUserMessageParam(BaseModel):
"""A message from the user in an OpenAI-compatible chat completion request.
:param role: Must be "user" to identify this as a user message
:param content: The content of the message, which can include text and other media
:param name: (Optional) The name of the user message participant.
"""
role: Literal["user"] = "user"
content: InterleavedContent
name: Optional[str] = None
@json_schema_type
class OpenAISystemMessageParam(BaseModel):
"""A system message providing instructions or context to the model.
:param role: Must be "system" to identify this as a system message
:param content: The content of the "system prompt". If multiple system messages are provided, they are concatenated. The underlying Llama Stack code may also add other system messages (for example, for formatting tool definitions).
:param name: (Optional) The name of the system message participant.
"""
role: Literal["system"] = "system"
content: InterleavedContent
name: Optional[str] = None
@json_schema_type
class OpenAIAssistantMessageParam(BaseModel):
"""A message containing the model's (assistant) response in an OpenAI-compatible chat completion request.
:param role: Must be "assistant" to identify this as the model's response
:param content: The content of the model's response
:param name: (Optional) The name of the assistant message participant.
:param tool_calls: List of tool calls. Each tool call is a ToolCall object.
"""
role: Literal["assistant"] = "assistant"
content: InterleavedContent
name: Optional[str] = None
tool_calls: Optional[List[ToolCall]] = Field(default_factory=list)
@json_schema_type
class OpenAIToolMessageParam(BaseModel):
"""A message representing the result of a tool invocation in an OpenAI-compatible chat completion request.
:param role: Must be "tool" to identify this as a tool response
:param tool_call_id: Unique identifier for the tool call this response is for
:param content: The response content from the tool
"""
role: Literal["tool"] = "tool"
tool_call_id: str
content: InterleavedContent
@json_schema_type
class OpenAIDeveloperMessageParam(BaseModel):
"""A message from the developer in an OpenAI-compatible chat completion request.
:param role: Must be "developer" to identify this as a developer message
:param content: The content of the developer message
:param name: (Optional) The name of the developer message participant.
"""
role: Literal["developer"] = "developer"
content: InterleavedContent
name: Optional[str] = None
OpenAIMessageParam = Annotated[
Union[
OpenAIUserMessageParam,
OpenAISystemMessageParam,
OpenAIAssistantMessageParam,
OpenAIToolMessageParam,
OpenAIDeveloperMessageParam,
],
Field(discriminator="role"),
]
register_schema(OpenAIMessageParam, name="OpenAIMessageParam")
@json_schema_type
class OpenAITopLogProb(BaseModel):
"""The top log probability for a token from an OpenAI-compatible chat completion response.
:token: The token
:bytes: (Optional) The bytes for the token
:logprob: The log probability of the token
"""
token: str
bytes: Optional[List[int]] = None
logprob: float
@json_schema_type
class OpenAITokenLogProb(BaseModel):
"""The log probability for a token from an OpenAI-compatible chat completion response.
:token: The token
:bytes: (Optional) The bytes for the token
:logprob: The log probability of the token
:top_logprobs: The top log probabilities for the token
"""
token: str
bytes: Optional[List[int]] = None
logprob: float
top_logprobs: List[OpenAITopLogProb]
@json_schema_type
class OpenAIChoiceLogprobs(BaseModel):
"""The log probabilities for the tokens in the message from an OpenAI-compatible chat completion response.
:content: (Optional) The log probabilities for the tokens in the message
:refusal: (Optional) The log probabilities for the tokens in the message
"""
content: Optional[List[OpenAITokenLogProb]] = None
refusal: Optional[List[OpenAITokenLogProb]] = None
@json_schema_type
class OpenAIChoice(BaseModel):
"""A choice from an OpenAI-compatible chat completion response.
:param message: The message from the model
:param finish_reason: The reason the model stopped generating
:index: The index of the choice
:logprobs: (Optional) The log probabilities for the tokens in the message
"""
message: OpenAIMessageParam
finish_reason: str
index: int
logprobs: Optional[OpenAIChoiceLogprobs] = None
@json_schema_type
class OpenAIChatCompletion(BaseModel):
"""Response from an OpenAI-compatible chat completion request.
:param id: The ID of the chat completion
:param choices: List of choices
:param object: The object type, which will be "chat.completion"
:param created: The Unix timestamp in seconds when the chat completion was created
:param model: The model that was used to generate the chat completion
"""
id: str
choices: List[OpenAIChoice]
object: Literal["chat.completion"] = "chat.completion"
created: int
model: str
@json_schema_type
class OpenAICompletionLogprobs(BaseModel):
"""The log probabilities for the tokens in the message from an OpenAI-compatible completion response.
:text_offset: (Optional) The offset of the token in the text
:token_logprobs: (Optional) The log probabilities for the tokens
:tokens: (Optional) The tokens
:top_logprobs: (Optional) The top log probabilities for the tokens
"""
text_offset: Optional[List[int]] = None
token_logprobs: Optional[List[float]] = None
tokens: Optional[List[str]] = None
top_logprobs: Optional[List[Dict[str, float]]] = None
@json_schema_type
class OpenAICompletionChoice(BaseModel):
"""A choice from an OpenAI-compatible completion response.
:finish_reason: The reason the model stopped generating
:text: The text of the choice
:index: The index of the choice
:logprobs: (Optional) The log probabilities for the tokens in the choice
"""
finish_reason: str
text: str
index: int
logprobs: Optional[OpenAIChoiceLogprobs] = None
@json_schema_type
class OpenAICompletion(BaseModel):
"""Response from an OpenAI-compatible completion request.
:id: The ID of the completion
:choices: List of choices
:created: The Unix timestamp in seconds when the completion was created
:model: The model that was used to generate the completion
:object: The object type, which will be "text_completion"
"""
id: str
choices: List[OpenAICompletionChoice]
created: int
model: str
object: Literal["text_completion"] = "text_completion"
class ModelStore(Protocol):
async def get_model(self, identifier: str) -> Model: ...
@ -421,6 +681,16 @@ class EmbeddingTaskType(Enum):
document = "document"
@json_schema_type
class BatchCompletionResponse(BaseModel):
batch: List[CompletionResponse]
@json_schema_type
class BatchChatCompletionResponse(BaseModel):
batch: List[ChatCompletionResponse]
@runtime_checkable
@trace_protocol
class Inference(Protocol):
@ -456,6 +726,17 @@ class Inference(Protocol):
"""
...
@webmethod(route="/inference/batch-completion", method="POST", experimental=True)
async def batch_completion(
self,
model_id: str,
content_batch: List[InterleavedContent],
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
) -> BatchCompletionResponse:
raise NotImplementedError("Batch completion is not implemented")
@webmethod(route="/inference/chat-completion", method="POST")
async def chat_completion(
self,
@ -496,6 +777,19 @@ class Inference(Protocol):
"""
...
@webmethod(route="/inference/batch-chat-completion", method="POST", experimental=True)
async def batch_chat_completion(
self,
model_id: str,
messages_batch: List[List[Message]],
sampling_params: Optional[SamplingParams] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_config: Optional[ToolConfig] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
) -> BatchChatCompletionResponse:
raise NotImplementedError("Batch chat completion is not implemented")
@webmethod(route="/inference/embeddings", method="POST")
async def embeddings(
self,
@ -515,3 +809,105 @@ class Inference(Protocol):
:returns: An array of embeddings, one for each content. Each embedding is a list of floats. The dimensionality of the embedding is model-specific; you can check model metadata using /models/{model_id}
"""
...
@webmethod(route="/openai/v1/completions", method="POST")
async def openai_completion(
self,
# Standard OpenAI completion parameters
model: str,
prompt: Union[str, List[str], List[int], List[List[int]]],
best_of: Optional[int] = None,
echo: Optional[bool] = None,
frequency_penalty: Optional[float] = None,
logit_bias: Optional[Dict[str, float]] = None,
logprobs: Optional[bool] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
presence_penalty: Optional[float] = None,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None,
stream_options: Optional[Dict[str, Any]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
# vLLM-specific parameters
guided_choice: Optional[List[str]] = None,
prompt_logprobs: Optional[int] = None,
) -> OpenAICompletion:
"""Generate an OpenAI-compatible completion for the given prompt using the specified model.
:param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
:param prompt: The prompt to generate a completion for
:param best_of: (Optional) The number of completions to generate
:param echo: (Optional) Whether to echo the prompt
:param frequency_penalty: (Optional) The penalty for repeated tokens
:param logit_bias: (Optional) The logit bias to use
:param logprobs: (Optional) The log probabilities to use
:param max_tokens: (Optional) The maximum number of tokens to generate
:param n: (Optional) The number of completions to generate
:param presence_penalty: (Optional) The penalty for repeated tokens
:param seed: (Optional) The seed to use
:param stop: (Optional) The stop tokens to use
:param stream: (Optional) Whether to stream the response
:param stream_options: (Optional) The stream options to use
:param temperature: (Optional) The temperature to use
:param top_p: (Optional) The top p to use
:param user: (Optional) The user to use
"""
...
@webmethod(route="/openai/v1/chat/completions", method="POST")
async def openai_chat_completion(
self,
model: str,
messages: List[OpenAIMessageParam],
frequency_penalty: Optional[float] = None,
function_call: Optional[Union[str, Dict[str, Any]]] = None,
functions: Optional[List[Dict[str, Any]]] = None,
logit_bias: Optional[Dict[str, float]] = None,
logprobs: Optional[bool] = None,
max_completion_tokens: Optional[int] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
parallel_tool_calls: Optional[bool] = None,
presence_penalty: Optional[float] = None,
response_format: Optional[Dict[str, str]] = None,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None,
stream_options: Optional[Dict[str, Any]] = None,
temperature: Optional[float] = None,
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
tools: Optional[List[Dict[str, Any]]] = None,
top_logprobs: Optional[int] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
) -> OpenAIChatCompletion:
"""Generate an OpenAI-compatible chat completion for the given messages using the specified model.
:param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
:param messages: List of messages in the conversation
:param frequency_penalty: (Optional) The penalty for repeated tokens
:param function_call: (Optional) The function call to use
:param functions: (Optional) List of functions to use
:param logit_bias: (Optional) The logit bias to use
:param logprobs: (Optional) The log probabilities to use
:param max_completion_tokens: (Optional) The maximum number of tokens to generate
:param max_tokens: (Optional) The maximum number of tokens to generate
:param n: (Optional) The number of completions to generate
:param parallel_tool_calls: (Optional) Whether to parallelize tool calls
:param presence_penalty: (Optional) The penalty for repeated tokens
:param response_format: (Optional) The response format to use
:param seed: (Optional) The seed to use
:param stop: (Optional) The stop tokens to use
:param stream: (Optional) Whether to stream the response
:param stream_options: (Optional) The stream options to use
:param temperature: (Optional) The temperature to use
:param tool_choice: (Optional) The tool choice to use
:param tools: (Optional) The tools to use
:param top_logprobs: (Optional) The top log probabilities to use
:param top_p: (Optional) The top p to use
:param user: (Optional) The user to use
"""
...

View file

@ -56,12 +56,35 @@ class ListModelsResponse(BaseModel):
data: List[Model]
@json_schema_type
class OpenAIModel(BaseModel):
"""A model from OpenAI.
:id: The ID of the model
:object: The object type, which will be "model"
:created: The Unix timestamp in seconds when the model was created
:owned_by: The owner of the model
"""
id: str
object: Literal["model"] = "model"
created: int
owned_by: str
class OpenAIListModelsResponse(BaseModel):
data: List[OpenAIModel]
@runtime_checkable
@trace_protocol
class Models(Protocol):
@webmethod(route="/models", method="GET")
async def list_models(self) -> ListModelsResponse: ...
@webmethod(route="/openai/v1/models", method="GET")
async def openai_list_models(self) -> OpenAIListModelsResponse: ...
@webmethod(route="/models/{model_id:path}", method="GET")
async def get_model(
self,

View file

@ -60,11 +60,11 @@ class EfficiencyConfig(BaseModel):
@json_schema_type
class TrainingConfig(BaseModel):
n_epochs: int
max_steps_per_epoch: int
gradient_accumulation_steps: int
max_validation_steps: int
data_config: DataConfig
optimizer_config: OptimizerConfig
max_steps_per_epoch: int = 1
gradient_accumulation_steps: int = 1
max_validation_steps: Optional[int] = 1
data_config: Optional[DataConfig] = None
optimizer_config: Optional[OptimizerConfig] = None
efficiency_config: Optional[EfficiencyConfig] = None
dtype: Optional[str] = "bf16"
@ -177,9 +177,9 @@ class PostTraining(Protocol):
training_config: TrainingConfig,
hyperparam_search_config: Dict[str, Any],
logger_config: Dict[str, Any],
model: str = Field(
default="Llama3.2-3B-Instruct",
description="Model descriptor from `llama model list`",
model: Optional[str] = Field(
default=None,
description="Model descriptor for training if not in provider config`",
),
checkpoint_dir: Optional[str] = None,
algorithm_config: Optional[AlgorithmConfig] = None,

View file

@ -29,8 +29,8 @@ from rich.progress import (
from termcolor import cprint
from llama_stack.cli.subcommand import Subcommand
from llama_stack.models.llama.datatypes import Model
from llama_stack.models.llama.sku_list import LlamaDownloadInfo
from llama_stack.models.llama.sku_types import Model
class Download(Subcommand):
@ -162,6 +162,10 @@ class ParallelDownloader:
raise last_exception
async def get_file_info(self, client: httpx.AsyncClient, task: DownloadTask) -> None:
if task.total_size > 0:
self.progress.update(task.task_id, total=task.total_size)
return
async def _get_info():
response = await client.head(task.url, headers={"Accept-Encoding": "identity"}, **self.client_options)
response.raise_for_status()
@ -282,7 +286,7 @@ class ParallelDownloader:
if not tasks:
raise ValueError("No download tasks provided")
if not self.has_disk_space(tasks):
if not os.environ.get("LLAMA_DOWNLOAD_NO_SPACE_CHECK") and not self.has_disk_space(tasks):
raise DownloadError("Insufficient disk space for downloads")
failed_tasks = []

View file

@ -63,17 +63,6 @@ class ModelDescribe(Subcommand):
("Model params.json", json.dumps(model.arch_args, indent=4)),
]
if model.recommended_sampling_params is not None:
sampling_params = model.recommended_sampling_params.model_dump()
for k in ("max_tokens", "repetition_penalty"):
del sampling_params[k]
rows.append(
(
"Recommended sampling params",
json.dumps(sampling_params, indent=4),
)
)
print_table(
rows,
headers,

View file

@ -11,7 +11,7 @@ from pathlib import Path
from llama_stack.cli.subcommand import Subcommand
from llama_stack.cli.table import print_table
from llama_stack.models.llama.datatypes import CoreModelId, ModelFamily, is_multimodal, model_family
from llama_stack.models.llama.sku_types import CoreModelId, ModelFamily, is_multimodal, model_family
ROOT_DIR = Path(__file__).parent.parent.parent

View file

@ -4,12 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, Optional
from typing import Any, Dict
from pydantic import BaseModel, ConfigDict, Field
from llama_stack.models.llama.datatypes import CheckpointQuantizationFormat, SamplingParams
from llama_stack.models.llama.sku_list import LlamaDownloadInfo
from llama_stack.models.llama.sku_types import CheckpointQuantizationFormat
class PromptGuardModel(BaseModel):
@ -23,7 +23,6 @@ class PromptGuardModel(BaseModel):
is_instruct_model: bool = False
quantization_format: CheckpointQuantizationFormat = CheckpointQuantizationFormat.bf16
arch_args: Dict[str, Any] = Field(default_factory=dict)
recommended_sampling_params: Optional[SamplingParams] = None
def descriptor(self) -> str:
return self.model_id

View file

@ -57,7 +57,7 @@ class StackBuild(Subcommand):
type=str,
help=textwrap.dedent(
f"""[for image-type={"|".join(e.value for e in ImageType)}] Name of the conda or virtual environment to use for
the build. If not specified, currently active Conda environment will be used if found.
the build. If not specified, currently active environment will be used if found.
"""
),
default=None,

View file

@ -45,7 +45,7 @@ class StackRun(Subcommand):
"--image-name",
type=str,
default=os.environ.get("CONDA_DEFAULT_ENV"),
help="Name of the image to run. Defaults to the current conda environment",
help="Name of the image to run. Defaults to the current environment",
)
self.parser.add_argument(
"--disable-ipv6",

View file

@ -312,6 +312,11 @@ a default SQLite store will be used.""",
description="Configuration for the HTTP(S) server",
)
external_providers_dir: Optional[str] = Field(
default=None,
description="Path to directory containing external provider implementations. The providers code and dependencies must be installed on the system.",
)
class BuildConfig(BaseModel):
version: str = LLAMA_STACK_BUILD_CONFIG_VERSION

View file

@ -4,12 +4,25 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import glob
import importlib
from typing import Dict, List
import os
from typing import Any, Dict, List
import yaml
from pydantic import BaseModel
from llama_stack.providers.datatypes import Api, ProviderSpec
from llama_stack.distribution.datatypes import StackRunConfig
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import (
AdapterSpec,
Api,
InlineProviderSpec,
ProviderSpec,
remote_provider_spec,
)
logger = get_logger(name=__name__, category="core")
def stack_apis() -> List[Api]:
@ -59,11 +72,116 @@ def providable_apis() -> List[Api]:
return [api for api in Api if api not in routing_table_apis and api != Api.inspect and api != Api.providers]
def get_provider_registry() -> Dict[Api, Dict[str, ProviderSpec]]:
ret = {}
def _load_remote_provider_spec(spec_data: Dict[str, Any], api: Api) -> ProviderSpec:
adapter = AdapterSpec(**spec_data["adapter"])
spec = remote_provider_spec(
api=api,
adapter=adapter,
api_dependencies=[Api(dep) for dep in spec_data.get("api_dependencies", [])],
)
return spec
def _load_inline_provider_spec(spec_data: Dict[str, Any], api: Api, provider_name: str) -> ProviderSpec:
spec = InlineProviderSpec(
api=api,
provider_type=f"inline::{provider_name}",
pip_packages=spec_data.get("pip_packages", []),
module=spec_data["module"],
config_class=spec_data["config_class"],
api_dependencies=[Api(dep) for dep in spec_data.get("api_dependencies", [])],
optional_api_dependencies=[Api(dep) for dep in spec_data.get("optional_api_dependencies", [])],
provider_data_validator=spec_data.get("provider_data_validator"),
container_image=spec_data.get("container_image"),
)
return spec
def get_provider_registry(config: StackRunConfig | None = None) -> Dict[Api, Dict[str, ProviderSpec]]:
"""Get the provider registry, optionally including external providers.
This function loads both built-in providers and external providers from YAML files.
External providers are loaded from a directory structure like:
providers.d/
remote/
inference/
custom_ollama.yaml
vllm.yaml
vector_io/
qdrant.yaml
safety/
llama-guard.yaml
inline/
inference/
custom_ollama.yaml
vllm.yaml
vector_io/
qdrant.yaml
safety/
llama-guard.yaml
Args:
config: Optional StackRunConfig containing the external providers directory path
Returns:
A dictionary mapping APIs to their available providers
Raises:
FileNotFoundError: If the external providers directory doesn't exist
ValueError: If any provider spec is invalid
"""
ret: Dict[Api, Dict[str, ProviderSpec]] = {}
for api in providable_apis():
name = api.name.lower()
logger.debug(f"Importing module {name}")
try:
module = importlib.import_module(f"llama_stack.providers.registry.{name}")
ret[api] = {a.provider_type: a for a in module.available_providers()}
except ImportError as e:
logger.warning(f"Failed to import module {name}: {e}")
if config and config.external_providers_dir:
external_providers_dir = os.path.abspath(config.external_providers_dir)
if not os.path.exists(external_providers_dir):
raise FileNotFoundError(f"External providers directory not found: {external_providers_dir}")
logger.info(f"Loading external providers from {external_providers_dir}")
for api in providable_apis():
api_name = api.name.lower()
# Process both remote and inline providers
for provider_type in ["remote", "inline"]:
api_dir = os.path.join(external_providers_dir, provider_type, api_name)
if not os.path.exists(api_dir):
logger.debug(f"No {provider_type} provider directory found for {api_name}")
continue
# Look for provider spec files in the API directory
for spec_path in glob.glob(os.path.join(api_dir, "*.yaml")):
provider_name = os.path.splitext(os.path.basename(spec_path))[0]
logger.info(f"Loading {provider_type} provider spec from {spec_path}")
try:
with open(spec_path) as f:
spec_data = yaml.safe_load(f)
if provider_type == "remote":
spec = _load_remote_provider_spec(spec_data, api)
provider_type_key = f"remote::{provider_name}"
else:
spec = _load_inline_provider_spec(spec_data, api, provider_name)
provider_type_key = f"inline::{provider_name}"
logger.info(f"Loaded {provider_type} provider spec for {provider_type_key} from {spec_path}")
if provider_type_key in ret[api]:
logger.warning(f"Overriding already registered provider {provider_type_key} for {api.name}")
ret[api][provider_type_key] = spec
except yaml.YAMLError as yaml_err:
logger.error(f"Failed to parse YAML file {spec_path}: {yaml_err}")
raise yaml_err
except Exception as e:
logger.error(f"Failed to load provider spec from {spec_path}: {e}")
raise e
return ret

View file

@ -273,7 +273,6 @@ def sort_providers_by_deps(
logger.debug(f"Resolved {len(sorted_providers)} providers")
for api_str, provider in sorted_providers:
logger.debug(f" {api_str} => {provider.provider_id}")
logger.debug("")
return sorted_providers
@ -351,6 +350,7 @@ async def instantiate_provider(
if not hasattr(provider_spec, "module"):
raise AttributeError(f"ProviderSpec of type {type(provider_spec)} does not have a 'module' attribute")
logger.debug(f"Instantiating provider {provider.provider_id} from {provider_spec.module}")
module = importlib.import_module(provider_spec.module)
args = []
if isinstance(provider_spec, RemoteProviderSpec):
@ -399,6 +399,8 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
mro = type(obj).__mro__
for name, value in inspect.getmembers(protocol):
if inspect.isfunction(value) and hasattr(value, "__webmethod__"):
if value.__webmethod__.experimental:
continue
if not hasattr(obj, name):
missing_methods.append((name, "missing"))
elif not callable(getattr(obj, name)):

View file

@ -17,6 +17,8 @@ from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import DatasetPurpose, DataSource
from llama_stack.apis.eval import BenchmarkConfig, Eval, EvaluateResponse, Job
from llama_stack.apis.inference import (
BatchChatCompletionResponse,
BatchCompletionResponse,
ChatCompletionResponse,
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
@ -35,6 +37,7 @@ from llama_stack.apis.inference import (
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam
from llama_stack.apis.models import Model, ModelType
from llama_stack.apis.safety import RunShieldResponse, Safety
from llama_stack.apis.scoring import (
@ -333,6 +336,30 @@ class InferenceRouter(Inference):
response.metrics = metrics if response.metrics is None else response.metrics + metrics
return response
async def batch_chat_completion(
self,
model_id: str,
messages_batch: List[List[Message]],
tools: Optional[List[ToolDefinition]] = None,
tool_config: Optional[ToolConfig] = None,
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
) -> BatchChatCompletionResponse:
logger.debug(
f"InferenceRouter.batch_chat_completion: {model_id=}, {len(messages_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}",
)
provider = self.routing_table.get_provider_impl(model_id)
return await provider.batch_chat_completion(
model_id=model_id,
messages_batch=messages_batch,
tools=tools,
tool_config=tool_config,
sampling_params=sampling_params,
response_format=response_format,
logprobs=logprobs,
)
async def completion(
self,
model_id: str,
@ -397,6 +424,20 @@ class InferenceRouter(Inference):
response.metrics = metrics if response.metrics is None else response.metrics + metrics
return response
async def batch_completion(
self,
model_id: str,
content_batch: List[InterleavedContent],
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
) -> BatchCompletionResponse:
logger.debug(
f"InferenceRouter.batch_completion: {model_id=}, {len(content_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}",
)
provider = self.routing_table.get_provider_impl(model_id)
return await provider.batch_completion(model_id, content_batch, sampling_params, response_format, logprobs)
async def embeddings(
self,
model_id: str,
@ -419,6 +460,126 @@ class InferenceRouter(Inference):
task_type=task_type,
)
async def openai_completion(
self,
model: str,
prompt: Union[str, List[str], List[int], List[List[int]]],
best_of: Optional[int] = None,
echo: Optional[bool] = None,
frequency_penalty: Optional[float] = None,
logit_bias: Optional[Dict[str, float]] = None,
logprobs: Optional[bool] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
presence_penalty: Optional[float] = None,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None,
stream_options: Optional[Dict[str, Any]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
guided_choice: Optional[List[str]] = None,
prompt_logprobs: Optional[int] = None,
) -> OpenAICompletion:
logger.debug(
f"InferenceRouter.openai_completion: {model=}, {stream=}, {prompt=}",
)
model_obj = await self.routing_table.get_model(model)
if model_obj is None:
raise ValueError(f"Model '{model}' not found")
if model_obj.model_type == ModelType.embedding:
raise ValueError(f"Model '{model}' is an embedding model and does not support completions")
params = dict(
model=model_obj.identifier,
prompt=prompt,
best_of=best_of,
echo=echo,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
top_p=top_p,
user=user,
guided_choice=guided_choice,
prompt_logprobs=prompt_logprobs,
)
provider = self.routing_table.get_provider_impl(model_obj.identifier)
return await provider.openai_completion(**params)
async def openai_chat_completion(
self,
model: str,
messages: List[OpenAIMessageParam],
frequency_penalty: Optional[float] = None,
function_call: Optional[Union[str, Dict[str, Any]]] = None,
functions: Optional[List[Dict[str, Any]]] = None,
logit_bias: Optional[Dict[str, float]] = None,
logprobs: Optional[bool] = None,
max_completion_tokens: Optional[int] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
parallel_tool_calls: Optional[bool] = None,
presence_penalty: Optional[float] = None,
response_format: Optional[Dict[str, str]] = None,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None,
stream_options: Optional[Dict[str, Any]] = None,
temperature: Optional[float] = None,
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
tools: Optional[List[Dict[str, Any]]] = None,
top_logprobs: Optional[int] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
) -> OpenAIChatCompletion:
logger.debug(
f"InferenceRouter.openai_chat_completion: {model=}, {stream=}, {messages=}",
)
model_obj = await self.routing_table.get_model(model)
if model_obj is None:
raise ValueError(f"Model '{model}' not found")
if model_obj.model_type == ModelType.embedding:
raise ValueError(f"Model '{model}' is an embedding model and does not support chat completions")
params = dict(
model=model_obj.identifier,
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
provider = self.routing_table.get_provider_impl(model_obj.identifier)
return await provider.openai_chat_completion(**params)
class SafetyRouter(Safety):
def __init__(

View file

@ -5,6 +5,7 @@
# the root directory of this source tree.
import logging
import time
import uuid
from typing import Any, Dict, List, Optional
@ -23,7 +24,7 @@ from llama_stack.apis.datasets import (
RowsDataSource,
URIDataSource,
)
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.scoring_functions import (
ListScoringFunctionsResponse,
@ -254,6 +255,19 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
async def list_models(self) -> ListModelsResponse:
return ListModelsResponse(data=await self.get_all_with_type("model"))
async def openai_list_models(self) -> OpenAIListModelsResponse:
models = await self.get_all_with_type("model")
openai_models = [
OpenAIModel(
id=model.identifier,
object="model",
created=int(time.time()),
owned_by="llama_stack",
)
for model in models
]
return OpenAIListModelsResponse(data=openai_models)
async def get_model(self, model_id: str) -> Model:
model = await self.get_object_by_identifier("model", model_id)
if model is None:
@ -608,8 +622,8 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
tool_group = await self.get_tool_group(toolgroup_id)
if tool_group is None:
raise ValueError(f"Tool group {toolgroup_id} not found")
tools = (await self.list_tools(toolgroup_id)).data
for tool in tools:
tools = await self.list_tools(toolgroup_id)
for tool in getattr(tools, "data", []):
await self.unregister_object(tool)
await self.unregister_object(tool_group)

View file

@ -96,7 +96,10 @@ async def register_resources(run_config: StackRunConfig, impls: Dict[Api, Any]):
method = getattr(impls[api], register_method)
for obj in objects:
await method(**obj.model_dump())
# we want to maintain the type information in arguments to method.
# instead of method(**obj.model_dump()), which may convert a typed attr to a dict,
# we use model_dump() to find all the attrs and then getattr to get the still typed value.
await method(**{k: getattr(obj, k) for k in obj.model_dump().keys()})
method = getattr(impls[api], list_method)
response = await method()
@ -218,7 +221,7 @@ async def construct_stack(
run_config: StackRunConfig, provider_registry: Optional[ProviderRegistry] = None
) -> Dict[Api, Any]:
dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name)
impls = await resolve_impls(run_config, provider_registry or get_provider_registry(), dist_registry)
impls = await resolve_impls(run_config, provider_registry or get_provider_registry(run_config), dist_registry)
await register_resources(run_config, impls)
return impls

View file

@ -18,6 +18,7 @@ VIRTUAL_ENV=${VIRTUAL_ENV:-}
set -euo pipefail
RED='\033[0;31m'
GREEN='\033[0;32m'
NC='\033[0m' # No Color
error_handler() {
@ -73,7 +74,7 @@ done
PYTHON_BINARY="python"
case "$env_type" in
"venv")
if [ -n "$VIRTUAL_ENV" && "$VIRTUAL_ENV" == "$env_path_or_name" ]; then
if [ -n "$VIRTUAL_ENV" ] && [ "$VIRTUAL_ENV" == "$env_path_or_name" ]; then
echo -e "${GREEN}Virtual environment already activated${NC}" >&2
else
# Activate virtual environment

View file

@ -1,7 +1,7 @@
# More info on playground configuration can be found here:
# https://llama-stack.readthedocs.io/en/latest/playground
FROM python:3.9-slim
FROM python:3.12-slim
WORKDIR /app
COPY . /app/
RUN /usr/local/bin/python -m pip install --upgrade pip && \

View file

@ -36,9 +36,7 @@ llama-stack-client benchmarks register \
3. Start Streamlit UI
```bash
cd llama_stack/distribution/ui
pip install -r requirements.txt
streamlit run app.py
uv run --with ".[ui]" streamlit run llama_stack/distribution/ui/app.py
```
## Environment Variables

View file

@ -24,6 +24,7 @@ def main():
# Playground pages
chat_page = st.Page("page/playground/chat.py", title="Chat", icon="💬", default=True)
rag_page = st.Page("page/playground/rag.py", title="RAG", icon="💬", default=False)
tool_page = st.Page("page/playground/tools.py", title="Tools", icon="🛠", default=False)
# Distribution pages
resources_page = st.Page("page/distribution/resources.py", title="Resources", icon="🔍", default=False)
@ -39,6 +40,7 @@ def main():
"Playground": [
chat_page,
rag_page,
tool_page,
application_evaluation_page,
native_evaluation_page,
],

View file

@ -19,6 +19,7 @@ class LlamaStackApi:
"together_api_key": os.environ.get("TOGETHER_API_KEY", ""),
"sambanova_api_key": os.environ.get("SAMBANOVA_API_KEY", ""),
"openai_api_key": os.environ.get("OPENAI_API_KEY", ""),
"tavily_search_api_key": os.environ.get("TAVILY_SEARCH_API_KEY", ""),
},
)

View file

@ -4,9 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import uuid
import streamlit as st
from llama_stack_client import Agent, AgentEventLogger, RAGDocument
from llama_stack.apis.common.content_types import ToolCallDelta
from llama_stack.distribution.ui.modules.api import llama_stack_api
from llama_stack.distribution.ui.modules.utils import data_url_from_file
@ -14,9 +17,16 @@ from llama_stack.distribution.ui.modules.utils import data_url_from_file
def rag_chat_page():
st.title("🦙 RAG")
def reset_agent_and_chat():
st.session_state.clear()
st.cache_resource.clear()
def should_disable_input():
return "displayed_messages" in st.session_state and len(st.session_state.displayed_messages) > 0
with st.sidebar:
# File/Directory Upload Section
st.subheader("Upload Documents")
st.subheader("Upload Documents", divider=True)
uploaded_files = st.file_uploader(
"Upload file(s) or directory",
accept_multiple_files=True,
@ -27,11 +37,11 @@ def rag_chat_page():
st.success(f"Successfully uploaded {len(uploaded_files)} files")
# Add memory bank name input field
vector_db_name = st.text_input(
"Vector Database Name",
"Document Collection Name",
value="rag_vector_db",
help="Enter a unique identifier for this vector database",
help="Enter a unique identifier for this document collection",
)
if st.button("Create Vector Database"):
if st.button("Create Document Collection"):
documents = [
RAGDocument(
document_id=uploaded_file.name,
@ -62,26 +72,45 @@ def rag_chat_page():
)
st.success("Vector database created successfully!")
st.subheader("Configure Agent")
st.subheader("RAG Parameters", divider=True)
rag_mode = st.radio(
"RAG mode",
["Direct", "Agent-based"],
captions=[
"RAG is performed by directly retrieving the information and augmenting the user query",
"RAG is performed by an agent activating a dedicated knowledge search tool.",
],
on_change=reset_agent_and_chat,
disabled=should_disable_input(),
)
# select memory banks
vector_dbs = llama_stack_api.client.vector_dbs.list()
vector_dbs = [vector_db.identifier for vector_db in vector_dbs]
selected_vector_dbs = st.multiselect(
"Select Vector Databases",
vector_dbs,
label="Select Document Collections to use in RAG queries",
options=vector_dbs,
on_change=reset_agent_and_chat,
disabled=should_disable_input(),
)
st.subheader("Inference Parameters", divider=True)
available_models = llama_stack_api.client.models.list()
available_models = [model.identifier for model in available_models if model.model_type == "llm"]
selected_model = st.selectbox(
"Choose a model",
available_models,
label="Choose a model",
options=available_models,
index=0,
on_change=reset_agent_and_chat,
disabled=should_disable_input(),
)
system_prompt = st.text_area(
"System Prompt",
value="You are a helpful assistant. ",
help="Initial instructions given to the AI to set its behavior and context",
on_change=reset_agent_and_chat,
disabled=should_disable_input(),
)
temperature = st.slider(
"Temperature",
@ -90,6 +119,8 @@ def rag_chat_page():
value=0.0,
step=0.1,
help="Controls the randomness of the response. Higher values make the output more creative and unexpected, lower values make it more conservative and predictable",
on_change=reset_agent_and_chat,
disabled=should_disable_input(),
)
top_p = st.slider(
@ -98,19 +129,23 @@ def rag_chat_page():
max_value=1.0,
value=0.95,
step=0.1,
on_change=reset_agent_and_chat,
disabled=should_disable_input(),
)
# Add clear chat button to sidebar
if st.button("Clear Chat", use_container_width=True):
st.session_state.messages = []
reset_agent_and_chat()
st.rerun()
# Chat Interface
if "messages" not in st.session_state:
st.session_state.messages = []
if "displayed_messages" not in st.session_state:
st.session_state.displayed_messages = []
# Display chat history
for message in st.session_state.messages:
for message in st.session_state.displayed_messages:
with st.chat_message(message["role"]):
st.markdown(message["content"])
@ -123,7 +158,9 @@ def rag_chat_page():
else:
strategy = {"type": "greedy"}
agent = Agent(
@st.cache_resource
def create_agent():
return Agent(
llama_stack_api.client,
model=selected_model,
instructions=system_prompt,
@ -139,17 +176,19 @@ def rag_chat_page():
)
],
)
session_id = agent.create_session("rag-session")
# Chat input
if prompt := st.chat_input("Ask a question about your documents"):
if rag_mode == "Agent-based":
agent = create_agent()
if "agent_session_id" not in st.session_state:
st.session_state["agent_session_id"] = agent.create_session(session_name=f"rag_demo_{uuid.uuid4()}")
session_id = st.session_state["agent_session_id"]
def agent_process_prompt(prompt):
# Add user message to chat history
st.session_state.messages.append({"role": "user", "content": prompt})
# Display user message
with st.chat_message("user"):
st.markdown(prompt)
# Send the prompt to the agent
response = agent.create_turn(
messages=[
{
@ -177,6 +216,79 @@ def rag_chat_page():
message_placeholder.markdown(full_response)
st.session_state.messages.append({"role": "assistant", "content": full_response})
st.session_state.displayed_messages.append({"role": "assistant", "content": full_response})
def direct_process_prompt(prompt):
# Add the system prompt in the beginning of the conversation
if len(st.session_state.messages) == 0:
st.session_state.messages.append({"role": "system", "content": system_prompt})
# Query the vector DB
rag_response = llama_stack_api.client.tool_runtime.rag_tool.query(
content=prompt, vector_db_ids=list(selected_vector_dbs)
)
prompt_context = rag_response.content
with st.chat_message("assistant"):
retrieval_message_placeholder = st.empty()
message_placeholder = st.empty()
full_response = ""
retrieval_response = ""
# Display the retrieved content
retrieval_response += str(prompt_context)
retrieval_message_placeholder.info(retrieval_response)
# Construct the extended prompt
extended_prompt = f"Please answer the following query using the context below.\n\nCONTEXT:\n{prompt_context}\n\nQUERY:\n{prompt}"
# Run inference directly
st.session_state.messages.append({"role": "user", "content": extended_prompt})
response = llama_stack_api.client.inference.chat_completion(
messages=st.session_state.messages,
model_id=selected_model,
sampling_params={
"strategy": strategy,
},
stream=True,
)
# Display assistant response
for chunk in response:
response_delta = chunk.event.delta
if isinstance(response_delta, ToolCallDelta):
retrieval_response += response_delta.tool_call.replace("====", "").strip()
retrieval_message_placeholder.info(retrieval_response)
else:
full_response += chunk.event.delta.text
message_placeholder.markdown(full_response + "")
message_placeholder.markdown(full_response)
response_dict = {"role": "assistant", "content": full_response, "stop_reason": "end_of_message"}
st.session_state.messages.append(response_dict)
st.session_state.displayed_messages.append(response_dict)
# Chat input
if prompt := st.chat_input("Ask a question about your documents"):
# Add user message to chat history
st.session_state.displayed_messages.append({"role": "user", "content": prompt})
# Display user message
with st.chat_message("user"):
st.markdown(prompt)
# store the prompt to process it after page refresh
st.session_state.prompt = prompt
# force page refresh to disable the settings widgets
st.rerun()
if "prompt" in st.session_state and st.session_state.prompt is not None:
if rag_mode == "Agent-based":
agent_process_prompt(st.session_state.prompt)
else: # rag_mode == "Direct"
direct_process_prompt(st.session_state.prompt)
st.session_state.prompt = None
rag_chat_page()

View file

@ -0,0 +1,116 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import uuid
import streamlit as st
from llama_stack_client import Agent
from llama_stack.distribution.ui.modules.api import llama_stack_api
def tool_chat_page():
st.title("🛠 Tools")
client = llama_stack_api.client
models = client.models.list()
model_list = [model.identifier for model in models if model.api_model_type == "llm"]
tool_groups = client.toolgroups.list()
tool_groups_list = [tool_group.identifier for tool_group in tool_groups]
mcp_tools_list = [tool for tool in tool_groups_list if tool.startswith("mcp::")]
builtin_tools_list = [tool for tool in tool_groups_list if not tool.startswith("mcp::")]
def reset_agent():
st.session_state.clear()
st.cache_resource.clear()
with st.sidebar:
st.subheader("Model")
model = st.selectbox(label="models", options=model_list, on_change=reset_agent)
st.subheader("Builtin Tools")
toolgroup_selection = st.pills(
label="Available ToolGroups", options=builtin_tools_list, selection_mode="multi", on_change=reset_agent
)
st.subheader("MCP Servers")
mcp_selection = st.pills(
label="Available MCP Servers", options=mcp_tools_list, selection_mode="multi", on_change=reset_agent
)
toolgroup_selection.extend(mcp_selection)
active_tool_list = []
for toolgroup_id in toolgroup_selection:
active_tool_list.extend(
[
f"{''.join(toolgroup_id.split('::')[1:])}:{t.identifier}"
for t in client.tools.list(toolgroup_id=toolgroup_id)
]
)
st.subheader(f"Active Tools: 🛠 {len(active_tool_list)}")
st.json(active_tool_list)
@st.cache_resource
def create_agent():
return Agent(
client,
model=model,
instructions="You are a helpful assistant. When you use a tool always respond with a summary of the result.",
tools=toolgroup_selection,
sampling_params={
"strategy": {"type": "greedy"},
},
)
agent = create_agent()
if "agent_session_id" not in st.session_state:
st.session_state["agent_session_id"] = agent.create_session(session_name=f"tool_demo_{uuid.uuid4()}")
session_id = st.session_state["agent_session_id"]
if "messages" not in st.session_state:
st.session_state["messages"] = [{"role": "assistant", "content": "How can I help you?"}]
for msg in st.session_state.messages:
with st.chat_message(msg["role"]):
st.markdown(msg["content"])
if prompt := st.chat_input(placeholder=""):
with st.chat_message("user"):
st.markdown(prompt)
st.session_state.messages.append({"role": "user", "content": prompt})
turn_response = agent.create_turn(
session_id=session_id,
messages=[{"role": "user", "content": prompt}],
stream=True,
)
def response_generator(turn_response):
for response in turn_response:
if hasattr(response.event, "payload"):
print(response.event.payload)
if response.event.payload.event_type == "step_progress":
if hasattr(response.event.payload.delta, "text"):
yield response.event.payload.delta.text
if response.event.payload.event_type == "step_complete":
if response.event.payload.step_details.step_type == "tool_execution":
yield " 🛠 "
else:
yield f"Error occurred in the Llama Stack Cluster: {response}"
with st.chat_message("assistant"):
response = st.write_stream(response_generator(turn_response))
st.session_state.messages.append({"role": "assistant", "content": response})
tool_chat_page()

View file

@ -1,4 +1,5 @@
streamlit
pandas
llama-stack-client>=0.0.55
llama-stack-client>=0.2.1
streamlit-option-menu
llama-stack>=0.2.1

View file

@ -29,6 +29,11 @@ def preserve_contexts_async_generator(
context_var.set(initial_context_values[context_var.name])
item = await gen.__anext__()
# Update our tracked values with any changes made during this iteration
for context_var in context_vars:
initial_context_values[context_var.name] = context_var.get()
yield item
except StopAsyncIteration:

View file

@ -0,0 +1,164 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import concurrent.futures
import re
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import numpy as np
import torch
from fairscale.nn.model_parallel.initialize import get_model_parallel_rank, get_model_parallel_world_size
def map_mp_rank(old_mp_size: int, new_mp_size: int, new_mp_rank: int) -> List[int]:
"""Map a new MP rank to a list of old MP ranks given a change in MP size."""
if new_mp_size % old_mp_size == 0:
# Read old MP shard and split it into smaller ones
return [new_mp_rank * old_mp_size // new_mp_size]
elif old_mp_size % new_mp_size == 0:
# Merge old MP shards into a single one
mp_factor = old_mp_size // new_mp_size
return list(range(new_mp_rank * mp_factor, (new_mp_rank + 1) * mp_factor))
else:
raise ValueError(
f"Either old MP size or new MP size should be a multiple of the other: "
f"{old_mp_size} % {new_mp_size} != 0 and {new_mp_size} % {old_mp_size} != 0"
)
def maybe_reshard_state_dict(
ckpt_paths: List[Path],
n_kv_heads: int,
moe_num_experts: Optional[int] = None,
map_location: Union[str, torch.device] = "cpu",
mmap: bool = True,
) -> Dict[str, torch.Tensor]:
if str(map_location) == "cpu":
torch.set_default_tensor_type(torch.BFloat16Tensor)
else:
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
ckpt_paths = np.array(sorted(ckpt_paths))
new_mp_size, new_mp_rank = get_model_parallel_world_size(), get_model_parallel_rank()
old_mp_size = len(ckpt_paths)
old_mp_ranks = map_mp_rank(old_mp_size, new_mp_size, new_mp_rank)
print(f"Loading checkpoint shards:\n{str(ckpt_paths[old_mp_ranks])}") # type: ignore
paths = ckpt_paths[old_mp_ranks] # type: ignore
state_dicts = [torch.load(str(p), map_location=map_location, mmap=mmap) for p in paths]
if new_mp_size == old_mp_size:
return state_dicts[0] # type: ignore
if moe_num_experts is not None:
state_dicts = [convert_moe_weights(d, moe_num_experts) for d in state_dicts]
print(f"Resharding {len(state_dicts)} state dicts from MP size {old_mp_size} to MP size {new_mp_size}")
return reshard_mp(
state_dicts,
size=max(new_mp_size // old_mp_size, 1),
rank=new_mp_rank % max(new_mp_size // old_mp_size, 1),
repeat_qk_qv=max(new_mp_size // n_kv_heads, 1),
)
_WEIGHT_ROW_KEY = {
"feed_forward.w2",
"feed_forward.mlp.fc2",
"attention.wo",
"feed_forward.mlp.fc2_weight",
"feed_forward.w_out_shared_DF.weight",
"attn.wo.weight",
"mlp.c_proj.weight",
}
_MOE_WEIGHT_ROW_KEY = {"feed_forward.experts.(moe_w_in_eD_F|moe_w_swiglu_eD_F)"}
_WEIGHT_COLUMN_KEY = {
"output",
"feed_forward.(w1|w3)",
"feed_forward.mlp.(fc1|fc3)",
"feed_forward.mlp.fc1_weight",
"attention.(wk|wq|wv|wqkv).weight",
"feed_forward.(w_in_shared_FD|w_swiglu_FD)",
"attn.(wk|wq|wv).weight",
"attn.(wk|wq|wv).bias",
"mlp.c_fc.weight",
"mlp.c_fc.bias",
"conv1._linear.weight",
"tok_embeddings.weight",
"vision_projection.weight",
}
_MOE_WEIGHT_COLUMN_KEY = {"feed_forward.experts.moe_w_out_eF_D"}
def reshard_mp(
state_dicts: List[Dict[str, torch.Tensor]],
size: int,
rank: int,
repeat_qk_qv: int = 1,
) -> Dict[str, torch.Tensor]:
"""
Reshard a list of state dicts into a single state dict given a change in MP size.
If the list has more than one state dict, we concatenate the values of the same
key across all state dicts. Otherwise, we just slice it for the current MP rank.
"""
def concat_or_chunk(tensors: List[torch.Tensor], dim: int) -> torch.Tensor:
if len(tensors) > 1:
return torch.cat(tensors, dim=dim)
return tensors[0].chunk(size, dim=dim)[rank].clone()
def process_key(key: str) -> torch.Tensor:
if row_regex.search(key):
return concat_or_chunk([s[key] for s in state_dicts], dim=-1)
elif column_regex.search(key):
if "w13" in key or "fc1_weight" in key:
dims = state_dicts[0][key].size()
values = [s[key].view(2, dims[0] // 2, *dims[1:]) for s in state_dicts]
return concat_or_chunk(values, dim=1).flatten(0, 1)
elif "qkv" in key:
q_dim = state_dicts[0][key.replace("qkv", "o")].size(1)
kv_dim = (state_dicts[0][key].size(0) - q_dim) // 2
values = [s[key].split((q_dim, kv_dim, kv_dim)) for s in state_dicts]
return torch.cat([concat_or_chunk(x, dim=0) for x in zip(*values, strict=False)]) # type: ignore
elif "wk.weight" in key or "wv.weight" in key:
# Support MP > #kv_head
return concat_or_chunk([s[key].repeat(repeat_qk_qv, 1) for s in state_dicts], dim=0)
elif key == "output.bias" or key == "fc.weight":
return concat_or_chunk([s[key] for s in state_dicts], dim=0)
elif "w_" in key:
return concat_or_chunk([s[key] for s in state_dicts], dim=-2)
else:
return concat_or_chunk([s[key] for s in state_dicts], dim=0)
else:
return state_dicts[0][key].clone()
row_keys = _WEIGHT_ROW_KEY | _MOE_WEIGHT_ROW_KEY
column_keys = _WEIGHT_COLUMN_KEY | _MOE_WEIGHT_COLUMN_KEY
column_regex = re.compile("|".join(column_keys))
row_regex = re.compile("|".join(row_keys))
output: Dict[str, torch.Tensor] = {}
with concurrent.futures.ThreadPoolExecutor() as executor:
# Note: only processes keys in the first state dict.
# Assumes keys are the same across all state dicts.
mappings = {executor.submit(process_key, key): key for key in state_dicts[0]}
for future in concurrent.futures.as_completed(mappings):
output[mappings[future]] = future.result()
return output
def convert_moe_weights(state_dict: Dict[str, Any], num_experts: int) -> Dict[str, Any]:
routed_keys = _MOE_WEIGHT_ROW_KEY | _MOE_WEIGHT_COLUMN_KEY
routed_regex = re.compile("|".join(routed_keys))
keys = list(state_dict.keys())
for key in keys:
if routed_regex.search(key):
state_dict[key] = state_dict.pop(key).unflatten(0, (num_experts, -1)).squeeze(dim=0)
return state_dict

View file

@ -4,13 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
import base64
from enum import Enum
from io import BytesIO
@ -19,8 +12,6 @@ from typing import Any, Dict, List, Literal, Optional, Union
from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator
from typing_extensions import Annotated
from llama_stack.schema_utils import json_schema_type, register_schema
# The goal is that these set of types are relevant for all Llama models.
# That isn't the current state yet -- e.g., BuiltinTool is somewhat specific to
# the llama3 series of models.
@ -98,6 +89,29 @@ class StopReason(Enum):
out_of_tokens = "out_of_tokens"
class ToolParamDefinition(BaseModel):
param_type: str
description: Optional[str] = None
required: Optional[bool] = True
default: Optional[Any] = None
class ToolDefinition(BaseModel):
tool_name: Union[BuiltinTool, str]
description: Optional[str] = None
parameters: Optional[Dict[str, ToolParamDefinition]] = None
@field_validator("tool_name", mode="before")
@classmethod
def validate_field(cls, v):
if isinstance(v, str):
try:
return BuiltinTool(v)
except ValueError:
return v
return v
class RawMediaItem(BaseModel):
type: Literal["image"] = "image"
data: bytes | BytesIO
@ -140,267 +154,25 @@ class RawMessage(BaseModel):
tool_calls: List[ToolCall] = Field(default_factory=list)
register_schema(ToolCall)
class GenerationResult(BaseModel):
token: int
text: str
logprobs: Optional[List[float]] = None
source: Literal["input"] | Literal["output"]
# index within the batch
batch_idx: int
# whether generation for this item is already finished. note that tokens can
# get returned even afterwards since other items in the batch can still be generating tokens
finished: bool
# because a batch is parallel processed, useful decoding for one item can correspond to processing
# pad tokens or tokens beyond EOS for other items. we could have decided to return None for this case
# but it's more convenient to return a list of GenerationResult and filter out the ignored tokens
ignore_token: bool
@json_schema_type
class ToolParamDefinition(BaseModel):
param_type: str
description: Optional[str] = None
required: Optional[bool] = True
default: Optional[Any] = None
@json_schema_type
class ToolDefinition(BaseModel):
tool_name: Union[BuiltinTool, str]
description: Optional[str] = None
parameters: Optional[Dict[str, ToolParamDefinition]] = None
@field_validator("tool_name", mode="before")
@classmethod
def validate_field(cls, v):
if isinstance(v, str):
try:
return BuiltinTool(v)
except ValueError:
return v
return v
@json_schema_type
class GreedySamplingStrategy(BaseModel):
type: Literal["greedy"] = "greedy"
@json_schema_type
class TopPSamplingStrategy(BaseModel):
type: Literal["top_p"] = "top_p"
temperature: Optional[float] = Field(..., gt=0.0)
top_p: Optional[float] = 0.95
@json_schema_type
class TopKSamplingStrategy(BaseModel):
type: Literal["top_k"] = "top_k"
top_k: int = Field(..., ge=1)
SamplingStrategy = Annotated[
Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy],
Field(discriminator="type"),
]
register_schema(SamplingStrategy, name="SamplingStrategy")
@json_schema_type
class SamplingParams(BaseModel):
"""Sampling parameters.
:param strategy: The sampling strategy.
:param max_tokens: The maximum number of tokens that can be generated in the completion. The token count of
your prompt plus max_tokens cannot exceed the model's context length.
:param repetition_penalty: Number between -2.0 and 2.0. Positive values penalize new tokens
based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.
:param stop: Up to 4 sequences where the API will stop generating further tokens.
The returned text will not contain the stop sequence.
"""
strategy: SamplingStrategy = Field(default_factory=GreedySamplingStrategy)
max_tokens: Optional[int] = 0
repetition_penalty: Optional[float] = 1.0
stop: Optional[List[str]] = None
class CheckpointQuantizationFormat(Enum):
# default format
bf16 = "bf16"
# used for enabling fp8_rowwise inference, some weights are bf16
fp8_mixed = "fp8-mixed"
int8 = "int8"
int4 = "int4"
class ModelFamily(Enum):
llama2 = "llama2"
llama3 = "llama3"
llama3_1 = "llama3_1"
llama3_2 = "llama3_2"
llama3_3 = "llama3_3"
safety = "safety"
class CoreModelId(Enum):
"""Each of these models is a unique "SKU". These root models can be served in various garbs (especially by quantizing them)"""
# Llama 2 family
llama2_7b = "Llama-2-7b"
llama2_13b = "Llama-2-13b"
llama2_70b = "Llama-2-70b"
llama2_7b_chat = "Llama-2-7b-chat"
llama2_13b_chat = "Llama-2-13b-chat"
llama2_70b_chat = "Llama-2-70b-chat"
# Llama 3 family
llama3_8b = "Llama-3-8B"
llama3_70b = "Llama-3-70B"
llama3_8b_instruct = "Llama-3-8B-Instruct"
llama3_70b_instruct = "Llama-3-70B-Instruct"
# Llama 3.1 family
llama3_1_8b = "Llama3.1-8B"
llama3_1_70b = "Llama3.1-70B"
llama3_1_405b = "Llama3.1-405B"
llama3_1_8b_instruct = "Llama3.1-8B-Instruct"
llama3_1_70b_instruct = "Llama3.1-70B-Instruct"
llama3_1_405b_instruct = "Llama3.1-405B-Instruct"
# Llama 3.2 family
llama3_2_1b = "Llama3.2-1B"
llama3_2_3b = "Llama3.2-3B"
llama3_2_1b_instruct = "Llama3.2-1B-Instruct"
llama3_2_3b_instruct = "Llama3.2-3B-Instruct"
llama3_2_11b_vision = "Llama3.2-11B-Vision"
llama3_2_90b_vision = "Llama3.2-90B-Vision"
llama3_2_11b_vision_instruct = "Llama3.2-11B-Vision-Instruct"
llama3_2_90b_vision_instruct = "Llama3.2-90B-Vision-Instruct"
# Llama 3.3 family
llama3_3_70b_instruct = "Llama3.3-70B-Instruct"
# Safety models
llama_guard_3_8b = "Llama-Guard-3-8B"
llama_guard_2_8b = "Llama-Guard-2-8B"
llama_guard_3_11b_vision = "Llama-Guard-3-11B-Vision"
llama_guard_3_1b = "Llama-Guard-3-1B"
def is_multimodal(model_id) -> bool:
if model_id in [
CoreModelId.llama3_2_11b_vision,
CoreModelId.llama3_2_90b_vision,
CoreModelId.llama3_2_11b_vision_instruct,
CoreModelId.llama3_2_90b_vision_instruct,
]:
return True
else:
return False
def model_family(model_id) -> ModelFamily:
if model_id in [
CoreModelId.llama2_7b,
CoreModelId.llama2_13b,
CoreModelId.llama2_70b,
CoreModelId.llama2_7b_chat,
CoreModelId.llama2_13b_chat,
CoreModelId.llama2_70b_chat,
]:
return ModelFamily.llama2
elif model_id in [
CoreModelId.llama3_8b,
CoreModelId.llama3_70b,
CoreModelId.llama3_8b_instruct,
CoreModelId.llama3_70b_instruct,
]:
return ModelFamily.llama3
elif model_id in [
CoreModelId.llama3_1_8b,
CoreModelId.llama3_1_70b,
CoreModelId.llama3_1_405b,
CoreModelId.llama3_1_8b_instruct,
CoreModelId.llama3_1_70b_instruct,
CoreModelId.llama3_1_405b_instruct,
]:
return ModelFamily.llama3_1
elif model_id in [
CoreModelId.llama3_2_1b,
CoreModelId.llama3_2_3b,
CoreModelId.llama3_2_1b_instruct,
CoreModelId.llama3_2_3b_instruct,
CoreModelId.llama3_2_11b_vision,
CoreModelId.llama3_2_90b_vision,
CoreModelId.llama3_2_11b_vision_instruct,
CoreModelId.llama3_2_90b_vision_instruct,
]:
return ModelFamily.llama3_2
elif model_id in [
CoreModelId.llama3_3_70b_instruct,
]:
return ModelFamily.llama3_3
elif model_id in [
CoreModelId.llama_guard_3_8b,
CoreModelId.llama_guard_2_8b,
CoreModelId.llama_guard_3_11b_vision,
CoreModelId.llama_guard_3_1b,
]:
return ModelFamily.safety
else:
raise ValueError(f"Unknown model family for {model_id}")
class Model(BaseModel):
core_model_id: CoreModelId
description: str
huggingface_repo: Optional[str] = None
recommended_sampling_params: Optional[SamplingParams] = None
arch_args: Dict[str, Any]
variant: str = ""
quantization_format: CheckpointQuantizationFormat = CheckpointQuantizationFormat.bf16
pth_file_count: int
metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)
# silence pydantic until we remove the `model_` fields
model_config = ConfigDict(protected_namespaces=())
@property
def model_family(self) -> ModelFamily:
return model_family(self.core_model_id)
# The SKU is uniquely identified by (model_id, variant) combo
def descriptor(self, shorten_default_variant: bool = True) -> str:
if not self.variant:
return self.core_model_id.value
return f"{self.core_model_id.value}:{self.variant}"
@property
def is_instruct_model(self) -> bool:
return "instruct" in self.id.name
# Featured models are shown in the non-exhaustive model list
@property
def is_featured(self) -> bool:
return self.model_family in [
ModelFamily.llama3_1,
ModelFamily.llama3_2,
ModelFamily.llama3_3,
ModelFamily.safety,
]
@property
def max_seq_length(self) -> int:
if self.model_family == ModelFamily.llama2:
return 4096
elif self.core_model_id == CoreModelId.llama_guard_2_8b:
return 4096
elif self.model_family == ModelFamily.llama3:
return 8192
elif self.model_family in [ModelFamily.llama3_1, ModelFamily.llama3_3]:
return 131072
elif self.model_family == ModelFamily.llama3_2:
if self.quantization_format == CheckpointQuantizationFormat.int4:
return 8192
return 131072
elif self.core_model_id in [
CoreModelId.llama_guard_3_8b,
CoreModelId.llama_guard_3_11b_vision,
CoreModelId.llama_guard_3_1b,
]:
return 131072
else:
raise ValueError(f"Unknown max_seq_len for {self.core_model_id}")
class QuantizationMode(str, Enum):
none = "none"
fp8_mixed = "fp8_mixed"
int4_mixed = "int4_mixed"

View file

@ -4,13 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
from dataclasses import dataclass
from enum import Enum
from typing import Optional

View file

@ -4,13 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
import io
import json
import uuid
@ -19,7 +12,7 @@ from typing import Dict, List, Optional, Tuple
from PIL import Image as PIL_Image
from llama_stack.models.llama.datatypes import (
from ..datatypes import (
BuiltinTool,
RawContent,
RawMediaItem,
@ -30,7 +23,6 @@ from llama_stack.models.llama.datatypes import (
ToolCall,
ToolPromptFormat,
)
from .tokenizer import Tokenizer
from .tool_utils import ToolUtils
@ -234,7 +226,6 @@ class ChatFormat:
arguments_json=json.dumps(tool_arguments),
)
)
content = ""
return RawMessage(
role="assistant",

View file

@ -0,0 +1,371 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
import json
import os
import sys
import time
from pathlib import Path
from typing import Callable, Generator, List, Optional
import torch
import torch.nn.functional as F
from fairscale.nn.model_parallel.initialize import (
initialize_model_parallel,
model_parallel_is_initialized,
)
from termcolor import cprint
from ..checkpoint import maybe_reshard_state_dict
from ..datatypes import GenerationResult, QuantizationMode, RawContent, RawMessage, ToolPromptFormat
from .args import ModelArgs
from .chat_format import ChatFormat, LLMInput
from .model import Transformer
from .multimodal.model import CrossAttentionTransformer
from .tokenizer import Tokenizer
class Llama3:
@staticmethod
def build(
ckpt_dir: str,
max_seq_len: int,
max_batch_size: int,
world_size: Optional[int] = None,
quantization_mode: Optional[QuantizationMode] = None,
seed: int = 1,
device: str = "cuda",
):
device = torch.device(device)
if (
device.type == "cuda"
and not torch.cuda.is_available()
or device.type == "xpu"
and not torch.xpu.is_available()
):
raise RuntimeError(f"PyTorch backend for {device.type} device type is not available")
if not torch.distributed.is_initialized():
if device.type == "cuda":
torch.distributed.init_process_group("nccl")
else:
torch.distributed.init_process_group("gloo")
if not model_parallel_is_initialized():
if world_size is None:
world_size = int(os.environ.get("WORLD_SIZE", 1))
initialize_model_parallel(world_size)
local_rank = int(os.environ.get("LOCAL_RANK", 0))
if device.type == "cuda":
torch.cuda.set_device(local_rank)
elif device.type == "xpu":
torch.xpu.set_device(local_rank)
torch.manual_seed(seed)
if local_rank > 0:
sys.stdout = open(os.devnull, "w")
start_time = time.time()
ckpt_paths = sorted(Path(ckpt_dir).glob("*.pth"))
assert len(ckpt_paths) > 0, f"no checkpoint files found in {ckpt_dir}"
print(f"Loading a checkpoint (shards={len(ckpt_paths)}, current-mp-size={world_size})")
with open(Path(ckpt_dir) / "params.json", "r") as f:
params = json.loads(f.read())
model_args: ModelArgs = ModelArgs(
max_seq_len=max_seq_len,
max_batch_size=max_batch_size,
**params,
)
tokenizer = Tokenizer.get_instance()
state_dict = maybe_reshard_state_dict(
ckpt_paths,
n_kv_heads=model_args.n_kv_heads if model_args.n_kv_heads else model_args.n_heads,
)
assert model_args.vocab_size == tokenizer.n_words
def build_model():
if model_args.vision_chunk_size > 0:
model = CrossAttentionTransformer(model_args)
model.setup_cache(model_args.max_batch_size, device=device, dtype=torch.get_default_dtype())
else:
model = Transformer(model_args)
return model
if quantization_mode == QuantizationMode.fp8_mixed or quantization_mode == QuantizationMode.int4_mixed:
from .quantization.loader import convert_to_quantized_model
torch.set_default_tensor_type(torch.BFloat16Tensor)
model = build_model()
print("Loading state dict...")
model.load_state_dict(state_dict, strict=False)
print("Done...")
model = convert_to_quantized_model(model, ckpt_dir, quantization_mode, device=device)
torch.set_default_device(device)
else:
print(f"Setting default device to {device}")
if device.type == "cuda":
if torch.cuda.is_bf16_supported():
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
else:
torch.set_default_tensor_type(torch.cuda.Float16Tensor)
elif device.type == "xpu":
if torch.xpu.is_bf16_supported():
torch.set_default_tensor_type(torch.xpu.BFloat16Tensor)
else:
torch.set_default_tensor_type(torch.xpu.Float16Tensor)
model = build_model()
print("Loading state dict...")
model.load_state_dict(state_dict, strict=True)
model.to(device)
print("Done...")
print(f"Loaded in {time.time() - start_time:.2f} seconds")
return Llama3(model, tokenizer, model_args)
def __init__(
self,
model: Transformer | CrossAttentionTransformer,
tokenizer: Tokenizer,
args: ModelArgs,
):
self.args = args
self.model = model
self.tokenizer = tokenizer
self.formatter = ChatFormat(tokenizer)
@torch.inference_mode()
def generate(
self,
llm_inputs: List[LLMInput],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
echo: bool = False,
print_model_input: bool = False,
logits_processor: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
) -> Generator[List[GenerationResult], None, None]:
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
max_gen_len = self.args.max_seq_len - 1
params = self.model.params
print_model_input = print_model_input or os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1"
if print_model_input:
for inp in llm_inputs:
tokens_to_print = [self.formatter.vision_token if t == 128256 else t for t in inp.tokens]
cprint(
"Input to model:\n" + self.tokenizer.decode(tokens_to_print) + "\n",
"red",
)
prompt_tokens = [inp.tokens for inp in llm_inputs]
bsz = len(llm_inputs)
assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
min_prompt_len = min(len(t) for t in prompt_tokens)
max_prompt_len = max(len(t) for t in prompt_tokens)
if max_prompt_len >= params.max_seq_len:
cprint(f"Out of token budget {max_prompt_len} vs {params.max_seq_len}", "red")
return
total_len = min(max_gen_len + max_prompt_len, params.max_seq_len)
pad_id = self.tokenizer.pad_id
tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long)
for k, t in enumerate(prompt_tokens):
tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long)
if logprobs:
token_logprobs = torch.zeros_like(tokens, dtype=torch.float)
is_vision = not isinstance(self.model, Transformer)
if is_vision:
images = [inp.vision.images if inp.vision is not None else [] for inp in llm_inputs]
mask = [inp.vision.mask if inp.vision is not None else [] for inp in llm_inputs]
xattn_caches, cross_attention_masks, full_text_row_masked_out_mask = self.model.compute_vision_tokens_masks(
batch_images=images,
batch_masks=mask,
total_len=total_len,
device=tokens.device,
)
eos_reached = torch.tensor([False] * bsz)
input_text_mask = tokens != pad_id
if echo:
for i in range(max_prompt_len):
results = []
for j, t in enumerate(tokens[:, i]):
results.append(
GenerationResult(
token=t.item(),
text=self.tokenizer.decode([t.item()]),
source="input",
logprobs=(token_logprobs[j, i : i + 1].tolist() if logprobs else None),
batch_idx=j,
finished=False,
ignore_token=t.item() == pad_id,
)
)
yield results
stop_tokens = torch.tensor(self.tokenizer.stop_tokens)
prev_pos = 0
for cur_pos in range(min_prompt_len, total_len):
if is_vision:
position_ids = torch.arange(prev_pos, cur_pos, dtype=torch.long)
text_only_inference = all(inp.vision is None for inp in llm_inputs)
logits = self.model.forward(
position_ids,
tokens,
cross_attention_masks,
full_text_row_masked_out_mask,
xattn_caches,
text_only_inference,
)
else:
logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
if logits_processor is not None:
logits = logits_processor(tokens[:, :cur_pos], logits)
if temperature > 0:
probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
next_token = sample_top_p(probs, top_p)
else:
next_token = torch.argmax(logits[:, -1], dim=-1)
next_token = next_token.reshape(-1)
# only replace token if prompt has already been generated
next_token = torch.where(input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token)
tokens[:, cur_pos] = next_token
target = tokens[:, prev_pos + 1 : cur_pos + 1]
if is_vision:
# the logits space (num_classes) is designed to never contain a media_token
# however our input token stream does contain them. we need to nuke them here
# or else the CUDA kernels will crash with an illegal memory access
vision_tokens = [self.tokenizer.special_tokens["<|image|>"], 128256]
masks = [target.eq(t) for t in vision_tokens]
if len(masks) > 1:
mask = torch.logical_or(*masks)
else:
mask = masks[0]
target[mask] = 0
if logprobs:
token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
input=logits.transpose(1, 2),
target=target,
reduction="none",
ignore_index=pad_id,
)
eos_reached |= (~input_text_mask[:, cur_pos]) & (torch.isin(next_token, stop_tokens))
results = []
for idx, t in enumerate(next_token):
results.append(
GenerationResult(
token=t.item(),
text=self.tokenizer.decode([t.item()]),
source="output",
logprobs=(token_logprobs[idx, cur_pos : cur_pos + 1].tolist() if logprobs else None),
batch_idx=idx,
finished=eos_reached[idx].item(),
ignore_token=cur_pos < len(prompt_tokens[idx]),
)
)
yield results
prev_pos = cur_pos
if all(eos_reached):
break
def completion(
self,
contents: List[RawContent],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
echo: bool = False,
) -> Generator[List[GenerationResult], None, None]:
model_inputs = [self.formatter.encode_content(c) for c in contents]
for result in self.generate(
model_inputs=model_inputs,
temperature=temperature,
top_p=top_p,
max_gen_len=max_gen_len,
logprobs=logprobs,
echo=echo,
):
yield result
if all(r.finished for r in result):
break
def chat_completion(
self,
messages_batch: List[List[RawMessage]],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
echo: bool = False,
) -> Generator[List[GenerationResult], None, None]:
model_inputs = [self.formatter.encode_dialog_prompt(messages) for messages in messages_batch]
for result in self.generate(
model_inputs=model_inputs,
temperature=temperature,
top_p=top_p,
max_gen_len=max_gen_len,
logprobs=logprobs,
echo=echo,
):
yield result
if all(r.finished for r in result):
break
def sample_top_p(probs, p):
"""
Perform top-p (nucleus) sampling on a probability distribution.
Args:
probs (torch.Tensor): Probability distribution tensor.
p (float): Probability threshold for top-p sampling.
Returns:
torch.Tensor: Sampled token indices.
Note:
Top-p sampling selects the smallest set of tokens whose cumulative probability mass
exceeds the threshold p. The distribution is renormalized based on the selected tokens.
"""
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
probs_sum = torch.cumsum(probs_sort, dim=-1)
mask = probs_sum - probs_sort > p
probs_sort[mask] = 0.0
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
next_token = torch.multinomial(probs_sort, num_samples=1)
next_token = torch.gather(probs_idx, -1, next_token)
return next_token

View file

@ -16,7 +16,7 @@ from typing import List, Optional
from termcolor import colored
from llama_stack.models.llama.datatypes import (
from ..datatypes import (
BuiltinTool,
RawMessage,
StopReason,
@ -24,7 +24,6 @@ from llama_stack.models.llama.datatypes import (
ToolDefinition,
ToolPromptFormat,
)
from . import template_data
from .chat_format import ChatFormat
from .prompt_templates import (

View file

@ -4,16 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import math
from typing import Optional, Tuple
@ -29,6 +19,10 @@ from torch import nn
from .args import ModelArgs
# **NOTE**: This code is not runnable without installing `torch` and `fairscale`
# dependencies. These dependencies are not part of the default dependencies
# (requirements.txt) of the `llama-models` package.
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
@ -111,9 +105,9 @@ class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
model_parallel_size = fs_init.get_model_parallel_world_size()
self.n_local_heads = args.n_heads // model_parallel_size
self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
world_size = fs_init.get_model_parallel_world_size()
self.n_local_heads = args.n_heads // world_size
self.n_local_kv_heads = self.n_kv_heads // world_size
self.n_rep = self.n_local_heads // self.n_local_kv_heads
self.head_dim = args.dim // args.n_heads

View file

@ -4,16 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
import logging
import math
from functools import partial
@ -180,14 +170,14 @@ class ImageAttention(nn.Module):
n_heads,
):
super().__init__()
model_parallel_size = fs_init.get_model_parallel_world_size()
world_size = fs_init.get_model_parallel_world_size()
qkvo_replication = 1
if model_parallel_size > 16:
qkvo_replication = model_parallel_size // 8
if world_size > 16:
qkvo_replication = world_size // 8
self.n_kv_heads = n_heads
self.n_local_heads = n_heads * qkvo_replication // model_parallel_size
self.n_local_kv_heads = self.n_kv_heads * qkvo_replication // model_parallel_size
self.n_local_heads = n_heads * qkvo_replication // world_size
self.n_local_kv_heads = self.n_kv_heads * qkvo_replication // world_size
self.n_rep = self.n_local_heads // self.n_local_kv_heads
self.head_dim = dim // n_heads
@ -536,16 +526,16 @@ class Attention(nn.Module):
cache_v (torch.Tensor): Cached values for attention.
"""
super().__init__()
model_parallel_size = fs_init.get_model_parallel_world_size()
world_size = fs_init.get_model_parallel_world_size()
replication_factor = 1
if model_parallel_size > 8:
replication_factor = model_parallel_size // MP_SCALE
if world_size > 8:
replication_factor = world_size // MP_SCALE
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
self.n_kv_heads *= replication_factor
self.n_local_heads = args.n_heads // model_parallel_size
self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
self.n_local_heads = args.n_heads // world_size
self.n_local_kv_heads = self.n_kv_heads // world_size
self.n_rep = self.n_local_heads // self.n_local_kv_heads
self.head_dim = args.dim // args.n_heads
self.max_seq_len = args.max_seq_len
@ -587,13 +577,11 @@ class Attention(nn.Module):
self.n_local_kv_heads,
self.head_dim,
)
device = next(self.parameters()).device
self.register_buffer(
"key_cache",
torch.zeros(
cache_shape,
dtype=dtype,
device=device,
),
persistent=False,
)
@ -602,7 +590,6 @@ class Attention(nn.Module):
torch.zeros(
cache_shape,
dtype=dtype,
device=device,
),
persistent=False,
)
@ -614,6 +601,9 @@ class Attention(nn.Module):
freqs_cis: torch.Tensor,
position_ids: torch.LongTensor,
):
self.key_cache = self.key_cache.to(x.device)
self.value_cache = self.value_cache.to(x.device)
xq, xk, xv = [F.linear(x, w) for w in [self.wq.weight, self.wk.weight, self.wv.weight]]
bs, slen, _ = xq.shape
@ -832,10 +822,10 @@ class CrossAttention(torch.nn.Module):
norm_eps: float,
):
super().__init__()
self.model_parallel_size = fs_init.get_model_parallel_world_size()
self.world_size = fs_init.get_model_parallel_world_size()
replication_factor = 1
if self.model_parallel_size > 8:
replication_factor = self.model_parallel_size // MP_SCALE
if self.world_size > 8:
replication_factor = self.world_size // MP_SCALE
n_kv_heads *= replication_factor
assert n_heads % n_kv_heads == 0
@ -889,10 +879,10 @@ class CrossAttention(torch.nn.Module):
# trunk LLM (i.e., group query attention) -- @dubeya
# local heads
assert self.n_heads % self.n_kv_heads == 0
assert self.n_heads % self.model_parallel_size == 0
assert self.n_kv_heads % self.model_parallel_size == 0
self.n_local_heads = self.n_heads // self.model_parallel_size
self.n_local_kv_heads = self.n_kv_heads // self.model_parallel_size
assert self.n_heads % self.world_size == 0
assert self.n_kv_heads % self.world_size == 0
self.n_local_heads = self.n_heads // self.world_size
self.n_local_kv_heads = self.n_kv_heads // self.world_size
self.n_rep = self.n_local_heads // self.n_local_kv_heads
def _compute_xattn_kv_cache(self, xattn_tokens: torch.Tensor) -> torch.Tensor:
@ -1041,7 +1031,7 @@ class CrossAttentionTransformerVision(torch.nn.Module):
self.image_res = args.vision_chunk_size
self.max_num_chunks = args.vision_max_num_chunks
if return_intermediate is not None:
return_intermediate = [int(level) for level in return_intermediate.split(",")]
return_intermediate = [int(layer) for layer in return_intermediate.split(",")]
self.vision_input_dim = (len(return_intermediate) + 1) * self.vision_input_dim
self.patch_size = 14
self.vision_encoder = VisionEncoder(
@ -1076,15 +1066,15 @@ class CrossAttentionTransformerText(torch.nn.Module):
def __init__(self, args: ModelArgs) -> None:
super().__init__()
self.model_parallel_size = fs_init.get_model_parallel_world_size()
self.world_size = fs_init.get_model_parallel_world_size()
assert args.vocab_size > 0
self.vocab_size = args.vocab_size
self.n_layers = args.n_layers
self.dim = args.dim
self.head_dim = args.dim // args.n_heads
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
self.n_local_kv_heads = self.n_kv_heads // self.model_parallel_size
assert self.vocab_size % self.model_parallel_size == 0
self.n_local_kv_heads = self.n_kv_heads // self.world_size
assert self.vocab_size % self.world_size == 0
self.tok_embeddings = VocabParallelEmbedding(args.vocab_size, args.dim, init_method=lambda x: x)
self.pos_embeddings = None
# final norm layer (not necessary for post-norm)
@ -1184,6 +1174,8 @@ class CrossAttentionTransformerText(torch.nn.Module):
text_only_inference: bool = False,
):
assert self.cache_is_setup, "Please set up cache before calling forward"
self.mask_cache = self.mask_cache.to(h.device)
self.freqs_cis = self.freqs_cis.to(h.device)
mask = self.mask_cache.index_select(2, position_ids)
freqs_cis = self.freqs_cis.index_select(0, position_ids)
@ -1212,9 +1204,8 @@ class CrossAttentionTransformerText(torch.nn.Module):
output = gather_from_tensor_model_parallel_region(output)
return output.float()
def setup_cache(self, max_batch_size: int, dtype=torch.bfloat16):
def setup_cache(self, max_batch_size: int, device: torch.device, dtype=torch.bfloat16):
# Set up the text kv caches
device = next(self.parameters()).device
ones = torch.ones(
(self.max_seq_len, self.max_seq_len),
dtype=torch.bool,
@ -1265,7 +1256,7 @@ class CrossAttentionTransformerText(torch.nn.Module):
return (
cross_attention_masks.to(device=text_device, dtype=text_dtype),
full_text_row_masked_out_mask,
full_text_row_masked_out_mask.to(device=text_device),
)
@ -1284,14 +1275,15 @@ class CrossAttentionTransformer(torch.nn.Module):
max_num_chunks=args.vision_max_num_chunks,
)
def setup_cache(self, max_batch_size: int, dtype: torch.dtype):
self.text_model.setup_cache(max_batch_size, dtype)
def setup_cache(self, max_batch_size: int, device: torch.device, dtype: torch.dtype):
self.text_model.setup_cache(max_batch_size, device, dtype)
def compute_vision_tokens_masks(
self,
batch_images: List[List[PIL_Image.Image]],
batch_masks: List[List[List[int]]],
total_len: int,
device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
skip_vision_encoder = False
@ -1318,6 +1310,7 @@ class CrossAttentionTransformer(torch.nn.Module):
image_res=self.params.vision_chunk_size,
max_num_images=max_num_images,
)
stacked_images = stacked_images.to(device=device)
if skip_vision_encoder:
vision_tokens = torch.zeros(
@ -1330,7 +1323,7 @@ class CrossAttentionTransformer(torch.nn.Module):
),
)
else:
vision_tokens = self.vision_model(stacked_images, aspect_ratios)
vision_tokens = self.vision_model(stacked_images, aspect_ratios).to(device=device)
bsz, nimg, nchunk, ntok, image_token_dim = tuple(vision_tokens.shape)
xattn_caches = torch.stack(

View file

@ -15,7 +15,7 @@ import textwrap
from datetime import datetime
from typing import Any, List, Optional
from llama_stack.models.llama.datatypes import (
from llama_stack.apis.inference import (
BuiltinTool,
ToolDefinition,
ToolParamDefinition,
@ -229,6 +229,11 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
You are an expert in composing functions. You are given a question and a set of possible functions.
Based on the question, you may or may not need to make one function/tool call to achieve the purpose.
If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
If you decide to invoke a function, you SHOULD NOT include any other text in the response. besides the function call in the above format.
For a boolean parameter, be sure to use `True` or `False` (capitalized) for the value.
{{ function_description }}
""".strip("\n")
)
@ -243,10 +248,6 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
def _gen_function_description(self, custom_tools: List[ToolDefinition]) -> PromptTemplate:
template_str = textwrap.dedent(
"""
If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
For a boolean parameter, be sure to use `True` or `False` (capitalized) for the value.
You SHOULD NOT include any other text in the response.
Here is a list of functions in JSON format that you can invoke.
[
@ -279,6 +280,10 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
{% endif -%}
{%- endfor %}
]
You can answer general questions or invoke tools when necessary.
In addition to tool calls, you should also augment your responses by using the tool outputs.
"""
)
return PromptTemplate(

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -4,12 +4,9 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import logging
# type: ignore
import os
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, cast
import torch
from fairscale.nn.model_parallel.initialize import get_model_parallel_rank
@ -18,52 +15,53 @@ from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_regi
from torch import Tensor, nn
from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
from llama_stack.apis.inference import QuantizationType
from llama_stack.models.llama.datatypes import CheckpointQuantizationFormat
from llama_stack.models.llama.sku_list import resolve_model
from ...llama3.args import ModelArgs
from ...llama3.model import Transformer, TransformerBlock
from ..config import MetaReferenceQuantizedInferenceConfig
log = logging.getLogger(__name__)
from ...datatypes import QuantizationMode
from ...quantize_impls import (
Fp8ScaledWeights,
ffn_swiglu,
load_fp8,
quantize_fp8,
)
from ..model import Transformer, TransformerBlock
from ..multimodal.model import CrossAttentionTransformer
def swiglu_wrapper(
self,
x: Tensor,
):
from .fp8_impls import ffn_swiglu
out = ffn_swiglu(x, self.w1.weight, self.w3.weight, self.w2.weight)
return reduce_from_model_parallel_region(out)
def convert_to_quantized_model(
model: Transformer | CrossAttentionTransformer,
checkpoint_dir: str,
quantization_mode: Optional[str] = None,
fp8_activation_scale_ub: Optional[float] = 1200.0,
device: Optional[torch.device] = None,
) -> Transformer | CrossAttentionTransformer:
if quantization_mode == QuantizationMode.fp8_mixed:
return convert_to_fp8_quantized_model(model, checkpoint_dir, fp8_activation_scale_ub, device)
elif quantization_mode == QuantizationMode.int4_mixed:
return convert_to_int4_quantized_model(model, checkpoint_dir, device)
else:
raise ValueError(f"Unsupported quantization mode: {quantization_mode}")
def convert_to_fp8_quantized_model(
model: Transformer,
config: MetaReferenceQuantizedInferenceConfig,
checkpoint_dir: str,
fp8_activation_scale_ub: Optional[float] = 1200.0,
device: Optional[torch.device] = None,
) -> Transformer:
if config.quantization.type == QuantizationType.bf16.value:
return model
elif config.quantization.type != QuantizationType.fp8.value:
raise ValueError("Only FP8 quantization is supported")
from .fp8_impls import Fp8ScaledWeights, load_fp8, quantize_fp8
llama_model = resolve_model(config.model)
assert llama_model is not None, f"Model {config.model} not found"
# Move weights to GPU with quantization
if llama_model.quantization_format == CheckpointQuantizationFormat.fp8_mixed.value:
log.info("Loading fp8 scales...")
fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt")
assert os.path.isfile(fp8_scales_path), f"fp8_scales_path not found for rank {get_model_parallel_rank()}"
if os.path.isfile(fp8_scales_path):
print("Loading fp8 scales...")
fp8_scales = torch.load(fp8_scales_path, weights_only=True)
for block in model.layers:
for _, block in model.named_modules():
if isinstance(block, TransformerBlock):
if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
continue
@ -77,23 +75,23 @@ def convert_to_fp8_quantized_model(
fp8_activation_scale_ub,
)
else:
log.info("Quantizing fp8 weights from bf16...")
for block in model.layers:
print("Quantizing fp8 weights from bf16...")
for _, block in model.named_modules():
if isinstance(block, TransformerBlock):
if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
continue
block.feed_forward.forward = swiglu_wrapper.__get__(block.feed_forward)
block.feed_forward.forward = swiglu_wrapper.__get__(block.feed_forward) # type: ignore
for key in ("w1", "w3", "w2"):
param = getattr(block.feed_forward, key)
param.weight = quantize_fp8(
param.weight,
fp8_activation_scale_ub,
output_device=torch.device("cuda"),
output_device=device,
)
for _, parameter in model.named_parameters():
if not isinstance(parameter, Fp8ScaledWeights):
parameter.data = parameter.to(device="cuda")
parameter.data = parameter.to(device=device)
return model
@ -136,6 +134,8 @@ class Int8DynActInt4WeightLinearLoRA(Int8DynActInt4WeightLinear):
precision=precision,
scales_precision=scales_precision,
)
self.lora_scale: Optional[float] = None
self.adaptor: Optional[nn.Sequential] = None
if lora_rank is not None:
assert lora_scale is not None, "Please specify lora scale for LoRA."
# Low-rank adaptation. See paper for more details: https://arxiv.org/abs/2106.09685
@ -143,9 +143,6 @@ class Int8DynActInt4WeightLinearLoRA(Int8DynActInt4WeightLinear):
self.adaptor.add_module("A", nn.Linear(in_features, lora_rank, bias=False))
self.adaptor.add_module("B", nn.Linear(lora_rank, out_features, bias=False))
self.lora_scale = lora_scale
else:
self.adaptor = None
self.lora_scale = None
self._register_load_state_dict_pre_hook(self.load_hook)
def load_hook(
@ -287,16 +284,16 @@ def _prepare_model_int4_weight_int8_dynamic_activation(
def convert_to_int4_quantized_model(
model: Transformer,
model_args: ModelArgs,
config: MetaReferenceQuantizedInferenceConfig,
) -> Transformer:
model: Transformer | CrossAttentionTransformer,
checkpoint_dir: str,
device: Optional[torch.device] = None,
) -> Transformer | CrossAttentionTransformer:
"""Convert the model to int4 quantized model."""
if model_args.quantization_args is None:
raise ValueError("'quantization_args' cannot be None. Please specify it.")
model_args = model.params
assert model_args.quantization_args is not None, "Quantization args must be specified."
quantization_args = model_args.quantization_args
if quantization_args.scheme is None:
raise ValueError("Quantization scheme must be specified in 'quantization_args'.")
if quantization_args.scheme.value != "int4_weight_int8_dynamic_activation":
raise NotImplementedError(
@ -316,5 +313,4 @@ def convert_to_int4_quantized_model(
lora_scale = model_args.lora_args.scale
_prepare_model_int4_weight_int8_dynamic_activation(model, group_size, lora_rank, lora_scale)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
return model.to(device)
return cast(Transformer | CrossAttentionTransformer, model.to(device=device))

View file

@ -12,8 +12,7 @@
# the top-level of this source tree.
from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall
from ..datatypes import BuiltinTool, StopReason, ToolCall
from .prompt_templates import (
BuiltinToolGenerator,
JsonCustomToolGenerator,

View file

@ -4,16 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import os
from logging import getLogger
from pathlib import Path

View file

@ -4,19 +4,13 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
import ast
import json
import re
from typing import Optional, Tuple
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat
from ..datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat
logger = get_logger(name=__name__, category="inference")
@ -34,80 +28,141 @@ def is_json(s):
return True
def is_valid_python_list(input_string):
"""Check if the input string is a valid Python list of function calls"""
try:
# Try to parse the string
tree = ast.parse(input_string)
# Check if it's a single expression
if len(tree.body) != 1 or not isinstance(tree.body[0], ast.Expr):
return False
# Check if the expression is a list
expr = tree.body[0].value
if not isinstance(expr, ast.List):
return False
# Check if the list is empty
if len(expr.elts) == 0:
return False
# Check if all elements in the list are function calls
for element in expr.elts:
if not isinstance(element, ast.Call):
return False
# Check if the function call has a valid name
if not isinstance(element.func, ast.Name):
return False
# Check if all arguments are keyword arguments
if element.args or not all(isinstance(arg, ast.keyword) for arg in element.keywords):
return False
return True
except SyntaxError:
# If parsing fails, it's not a valid Python expression
return False
def parse_python_list_for_function_calls(input_string):
def parse_llama_tool_call_format(input_string):
"""
Parse a Python list of function calls and
return a list of tuples containing the function name and arguments
"""
# Parse the string into an AST
tree = ast.parse(input_string)
Parse tool calls in the format:
[func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
# Ensure the input is a list
if not isinstance(tree.body[0], ast.Expr) or not isinstance(tree.body[0].value, ast.List):
raise ValueError("Input must be a list of function calls")
Returns a list of (function_name, arguments_dict) tuples or None if parsing fails.
"""
# Strip outer brackets and whitespace
input_string = input_string.strip()
if not (input_string.startswith("[") and input_string.endswith("]")):
return None
content = input_string[1:-1].strip()
if not content:
return None
result = []
# Iterate through each function call in the list
for node in tree.body[0].value.elts:
if isinstance(node, ast.Call):
function_name = node.func.id
function_args = {}
# State variables for parsing
pos = 0
length = len(content)
# Extract keyword arguments
for keyword in node.keywords:
while pos < length:
# Find function name
name_end = content.find("(", pos)
if name_end == -1:
break
func_name = content[pos:name_end].strip()
# Find closing parenthesis for this function call
paren_level = 1
args_start = name_end + 1
args_end = args_start
while args_end < length and paren_level > 0:
if content[args_end] == "(":
paren_level += 1
elif content[args_end] == ")":
paren_level -= 1
args_end += 1
if paren_level != 0:
# Unmatched parentheses
return None
# Parse arguments
args_str = content[args_start : args_end - 1].strip()
args_dict = {}
if args_str:
# Split by commas, but respect nested structures
parts = []
part_start = 0
in_quotes = False
quote_char = None
nested_level = 0
for i, char in enumerate(args_str):
if char in ('"', "'") and (i == 0 or args_str[i - 1] != "\\"):
if not in_quotes:
in_quotes = True
quote_char = char
elif char == quote_char:
in_quotes = False
quote_char = None
elif not in_quotes:
if char in ("{", "["):
nested_level += 1
elif char in ("}", "]"):
nested_level -= 1
elif char == "," and nested_level == 0:
parts.append(args_str[part_start:i].strip())
part_start = i + 1
parts.append(args_str[part_start:].strip())
# Process each key=value pair
for part in parts:
if "=" in part:
key, value = part.split("=", 1)
key = key.strip()
value = value.strip()
# Try to convert value to appropriate Python type
if (value.startswith('"') and value.endswith('"')) or (
value.startswith("'") and value.endswith("'")
):
# String
value = value[1:-1]
elif value.lower() == "true":
value = True
elif value.lower() == "false":
value = False
elif value.lower() == "none":
value = None
elif value.startswith("{") and value.endswith("}"):
# This is a nested dictionary
try:
function_args[keyword.arg] = ast.literal_eval(keyword.value)
except ValueError as e:
logger.error(
f"Error parsing tool call argument '{keyword.arg}': {e}, full input string: '{input_string}'"
)
raise ValueError(
f"Error parsing tool call argument '{keyword.arg}', full input string: '{input_string}'"
) from e
# Try to parse as JSON
value = json.loads(value.replace("'", '"'))
except json.JSONDecodeError:
# Keep as string if parsing fails
pass
elif value.startswith("[") and value.endswith("]"):
# This is a nested list
try:
# Try to parse as JSON
value = json.loads(value.replace("'", '"'))
except json.JSONDecodeError:
# Keep as string if parsing fails
pass
else:
# Try to convert to number
try:
if "." in value:
value = float(value)
else:
value = int(value)
except ValueError:
# Keep as string if not a valid number
pass
result.append((function_name, function_args))
args_dict[key] = value
return result
result.append((func_name, args_dict))
# Move to the next function call
pos = args_end
# Skip the comma between function calls if present
if pos < length and content[pos] == ",":
pos += 1
return result if result else None
class ToolUtils:
@ -155,11 +210,11 @@ class ToolUtils:
return function_name, args
else:
return None
elif is_valid_python_list(message_body):
res = parse_python_list_for_function_calls(message_body)
elif function_calls := parse_llama_tool_call_format(message_body):
# FIXME: Enable multiple tool calls
return res[0]
return function_calls[0]
else:
logger.debug(f"Did not parse tool call from message body: {message_body}")
return None
@staticmethod

View file

@ -21,8 +21,7 @@ from llama_stack.models.llama.datatypes import (
ToolCall,
ToolPromptFormat,
)
from ..prompt_format import (
from llama_stack.models.llama.prompt_format import (
# llama3_1_e2e_tool_call_dialog,
TextCompletionContent,
UseCase,

View file

@ -3,10 +3,3 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.

View file

@ -4,12 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
import json
import textwrap

View file

@ -4,13 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
import textwrap
from pathlib import Path

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,108 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Optional
from pydantic import BaseModel, model_validator
class QuantizationScheme(Enum):
int4_weight_int8_dynamic_activation = "int4_weight_int8_dynamic_activation"
class QuantizationArgs(BaseModel):
scheme: Optional[QuantizationScheme] = None
group_size: Optional[int] = None
spinquant: bool = False
class LoRAArgs(BaseModel):
rank: int
scale: float
class MoEArgs(BaseModel):
num_experts: int = -1
capacity_factor: float = 1.0 # capacity factor determines how many tokens each expert can choose
auto_scale_F: bool = ( # noqa: N815
True # if true, rescales hidden_dim such that number of activated params is same as equivalent dense layer
)
top_k: int = 1
interleave_moe_layer_step: int = 1
class Size(BaseModel):
height: int
width: int
class VisionArgs(BaseModel):
image_size: Size
patch_size: Size
# parameters for the encoder transformer
dim: int
n_layers: int
n_heads: int
mlp_ratio: float
output_dim: int
pixel_shuffle_ratio: float
class ModelArgs(BaseModel):
dim: int = -1
n_layers: int = -1
n_heads: int = -1
n_kv_heads: Optional[int] = None
head_dim: Optional[int] = None
vocab_size: int = -1
multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
ffn_dim_multiplier: Optional[float] = None
ffn_exp: Optional[float] = None
norm_eps: float = 1e-5
attention_chunk_size: Optional[int] = None
rope_theta: float = 500000
use_scaled_rope: bool = False
rope_scaling_factor: Optional[float] = None
rope_high_freq_factor: Optional[float] = None
nope_layer_interval: Optional[int] = None # No position encoding in every n layers
use_qk_norm: bool = False
# Set to True to enable inference-time temperature tuning (useful for very long context)
attn_temperature_tuning: bool = False
floor_scale: float = 8192.0
attn_scale: float = 0.1
vision_args: Optional[VisionArgs] = None
moe_args: Optional[MoEArgs] = None
quantization_args: Optional[QuantizationArgs] = None
lora_args: Optional[LoRAArgs] = None
max_batch_size: int = 32
max_seq_len: int = 2048
@model_validator(mode="after")
def validate(self) -> "ModelArgs":
assert self.n_kv_heads <= self.n_heads, f"n_kv_heads ({self.n_kv_heads}) must be <= n_heads ({self.n_heads})"
assert self.n_heads % self.n_kv_heads == 0, (
f"n_heads ({self.n_heads}) must be divisible by n_kv_heads ({self.n_kv_heads})"
)
assert self.dim % self.n_heads == 0, f"dim ({self.dim}) must be divisible by n_heads ({self.n_heads})"
if self.use_scaled_rope:
# NOTE: ideally these values should have come from params.json. However, we have
# shipped the models everywhere. Only Llama-4-Scout uses scaled rope and needs these
# specific values.
if self.rope_scaling_factor is None:
self.rope_scaling_factor = 16
if self.rope_high_freq_factor is None:
self.rope_high_freq_factor = 1
return self

View file

@ -0,0 +1,316 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import io
import uuid
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
import torch
from PIL import Image as PIL_Image
# TODO: either fork these or move them to the common package
from ..datatypes import (
BuiltinTool,
RawContent,
RawMediaItem,
RawMessage,
RawTextItem,
Role,
StopReason,
ToolCall,
ToolPromptFormat,
)
from ..llama3.tool_utils import ToolUtils
from .args import VisionArgs
from .datatypes import LLMInput
from .preprocess import ResizeNormalizeImageTransform, VariableSizeImageTransform
from .tokenizer import Tokenizer
def role_str(role: Role) -> str:
role_strs = {
Role.user: "user",
Role.system: "system",
Role.tool: "ipython", # special
Role.assistant: "assistant",
}
return role_strs[role]
@dataclass
class TransformedImage:
image_tiles: torch.Tensor
# is the aspect ratio needed anywhere?
aspect_ratio: Tuple[int, int]
def convert_image_to_rgb(image: PIL_Image.Image, bg: Tuple[int, int, int] = (255, 255, 255)) -> PIL_Image.Image:
if image.mode == "RGBA":
image.load() # for png.split()
new_img = PIL_Image.new("RGB", image.size, bg)
new_img.paste(image, mask=image.split()[3]) # 3 is the alpha channel
return new_img
return image.convert("RGB")
class ChatFormat:
possible_headers: Dict[Role, str]
def __init__(
self,
tokenizer: Tokenizer,
vision_args: Optional[VisionArgs] = None,
max_num_chunks: int = 16,
):
self.tokenizer = tokenizer
self.vision_args = vision_args
self.max_num_chunks = max_num_chunks
self.possible_headers = {role: f"<|header_start|>{role_str(role)}<|header_end|>\n\n" for role in Role}
self.image_transform = None
self.dynamic_image_transform = None
if vision_args:
self.dynamic_image_transform = VariableSizeImageTransform(vision_args.image_size.width)
self.image_transform = ResizeNormalizeImageTransform(
vision_args.image_size.width, vision_args.image_size.height
)
def _encode_header(self, role: str) -> List[int]:
tokens = []
tokens.append(self.tokenizer.special_tokens["<|header_start|>"])
# TODO: need to check if this is correct
tokens.extend(self.tokenizer.encode("ipython" if role == "tool" else role, bos=False, eos=False))
tokens.append(self.tokenizer.special_tokens["<|header_end|>"])
tokens.extend(self.tokenizer.encode("\n\n", bos=False, eos=False))
return tokens
def encode_content(self, content: RawContent) -> LLMInput:
tokens, images = self._encode_content(content, bos=True)
return self._model_input_from_tokens_images(tokens, images)
def _encode_image(
self,
transformed_image: TransformedImage,
) -> List[int]:
assert self.vision_args is not None, "The model is not vision-enabled"
image_tensor = transformed_image.image_tiles
image_channels = image_tensor.shape[-3]
image_height = image_tensor.shape[-2]
image_width = image_tensor.shape[-1]
image_chunks = image_tensor.view(-1, image_channels, image_height, image_width).shape[0]
patch_height = self.vision_args.patch_size.height
patch_width = self.vision_args.patch_size.width
if image_height % patch_height != 0:
raise ValueError(f"{image_height=} not divisible by {patch_height=}")
if image_width % patch_width != 0:
raise ValueError(f"{image_width=} not divisible by {patch_width=}")
ds_ratio = int(round(1.0 / (self.vision_args.pixel_shuffle_ratio**2)))
n_patches_per_chunk = int((image_height // patch_height) * (image_width // patch_width) // ds_ratio)
image_ar = transformed_image.aspect_ratio
tokens = [self.tokenizer.special_tokens["<|image_start|>"]]
if image_chunks == 1:
tokens += [self.tokenizer.special_tokens["<|image|>"]]
tokens += [self.tokenizer.special_tokens["<|patch|>"]] * n_patches_per_chunk
tokens += [self.tokenizer.special_tokens["<|image_end|>"]]
else:
ratio_h, ratio_w = image_ar
for _ in range(ratio_h):
for xx in range(ratio_w):
tokens += [self.tokenizer.special_tokens["<|patch|>"]] * n_patches_per_chunk
if xx < ratio_w - 1:
tokens.append(self.tokenizer.special_tokens["<|tile_x_separator|>"])
tokens.append(self.tokenizer.special_tokens["<|tile_y_separator|>"])
tokens += [self.tokenizer.special_tokens["<|image|>"]]
tokens += [self.tokenizer.special_tokens["<|patch|>"]] * n_patches_per_chunk
tokens += [self.tokenizer.special_tokens["<|image_end|>"]]
return tokens
def _encode_content(self, content: RawContent, bos: bool = False) -> Tuple[List[int], List[TransformedImage]]:
tokens = []
tranformed_images = []
added_bos = False
def _process(c):
nonlocal added_bos, bos
if isinstance(c, str) or isinstance(c, RawTextItem):
if isinstance(c, RawTextItem):
c = c.text
tokens.extend(self.tokenizer.encode(c, bos=False if added_bos else bos, eos=False))
added_bos = True
elif isinstance(c, RawMediaItem):
if not self.vision_args:
raise ValueError("The model is not vision-enabled, but a media item was found")
bos = False if added_bos else bos
if bos:
tokens.append(self.tokenizer.special_tokens["<|begin_of_text|>"])
added_bos = True
bytes_io = io.BytesIO(c.data) if isinstance(c.data, bytes) else c.data
image = PIL_Image.open(bytes_io)
image = convert_image_to_rgb(image)
image_tiles, ar = self.dynamic_image_transform(image, max_num_chunks=self.max_num_chunks)
if image_tiles.shape[0] > 1:
image_global = self.image_transform(image)
image_global = image_global.unsqueeze(0)
image_combine = torch.cat((image_tiles, image_global), dim=0)
image_tiles = image_combine
transformed_image = TransformedImage(image_tiles=image_tiles, aspect_ratio=ar)
tokens.extend(self._encode_image(transformed_image))
tranformed_images.append(transformed_image)
if isinstance(content, list):
for c in content:
_process(c)
else:
_process(content)
return tokens, tranformed_images
def encode_message(
self, message: RawMessage, tool_prompt_format: ToolPromptFormat
) -> Tuple[List[int], List[TransformedImage]]:
tokens = self._encode_header(message.role)
images = []
def _process_content(c):
toks, imgs = self._encode_content(c)
tokens.extend(toks)
images.extend(imgs)
_process_content(message.content)
if message.role == "user" and message.context is not None:
# This is RAG context; why is it here in the chat format? I don't think
# this is needed and can be moved upwards
_process_content("\n\n")
_process_content(message.context)
if message.role == "assistant":
for t in message.tool_calls:
content = ToolUtils.encode_tool_call(t, tool_prompt_format)
_process_content(content)
# Tool calls and Tool Response messages should be eom
eom = False
if message.role == "assistant":
eom = message.stop_reason == StopReason.end_of_message or message.tool_calls
elif message.role == "tool":
eom = True
tokens.append(self.tokenizer.special_tokens["<|eom|>" if eom else "<|eot|>"])
return tokens, images
def encode_dialog_prompt(
self,
messages: List[RawMessage],
tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
) -> LLMInput:
tokens = []
images = []
tokens.append(self.tokenizer.special_tokens["<|begin_of_text|>"])
for message in messages:
toks, imgs = self.encode_message(message, tool_prompt_format)
tokens.extend(toks)
images.extend(imgs)
# Add the start of an assistant message for the model to complete.
tokens.extend(self._encode_header("assistant"))
return self._model_input_from_tokens_images(tokens, images)
# TODO(this should be generic, not only for assistant messages)
def decode_assistant_message(self, tokens: List[int], stop_reason: StopReason) -> RawMessage:
content = self.tokenizer.decode(tokens)
return self.decode_assistant_message_from_content(content, stop_reason)
def decode_assistant_message_from_content(self, content: str, stop_reason: StopReason) -> RawMessage:
content = content.strip(" ")
header_str = self.possible_headers[Role.assistant]
if content.startswith(header_str):
content = content[len(header_str) :]
ipython = content.startswith("<|python_start|>")
if ipython:
content = content[len("<|python_start|>") :]
content = content.replace("<|python_end|>", "")
if content.endswith("<|eot|>"):
content = content[: -len("<|eot|>")]
stop_reason = StopReason.end_of_turn
elif content.endswith("<|eom|>"):
content = content[: -len("<|eom|>")]
stop_reason = StopReason.end_of_message
tool_name = None
tool_arguments = {}
custom_tool_info = ToolUtils.maybe_extract_custom_tool_call(content)
if custom_tool_info is not None:
tool_name, tool_arguments = custom_tool_info
# Sometimes when agent has custom tools alongside builin tools
# Agent responds for builtin tool calls in the format of the custom tools
# This code tries to handle that case
if tool_name in BuiltinTool.__members__:
tool_name = BuiltinTool[tool_name]
tool_arguments = {
"query": list(tool_arguments.values())[0],
}
else:
builtin_tool_info = ToolUtils.maybe_extract_builtin_tool_call(content)
if builtin_tool_info is not None:
tool_name, query = builtin_tool_info
tool_arguments = {
"query": query,
}
if tool_name in BuiltinTool.__members__:
tool_name = BuiltinTool[tool_name]
elif ipython:
tool_name = BuiltinTool.code_interpreter
tool_arguments = {
"code": content,
}
tool_calls = []
if tool_name is not None and tool_arguments is not None:
call_id = str(uuid.uuid4())
tool_calls.append(
ToolCall(
call_id=call_id,
tool_name=tool_name,
arguments=tool_arguments,
)
)
return RawMessage(
role="assistant",
content=content,
stop_reason=stop_reason,
tool_calls=tool_calls,
)
def _model_input_from_tokens_images(self, tokens: List[int], images: List[TransformedImage]) -> LLMInput:
return LLMInput(
tokens=tokens,
images=[x.image_tiles for x in images] if len(images) > 0 else None,
)

View file

@ -0,0 +1,57 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from dataclasses import dataclass
from typing import List, Optional, Union
import torch
@dataclass
class MaskedEmbedding:
embedding: torch.Tensor
mask: torch.Tensor
@dataclass
class LLMInput:
"""
This is the input to the LLM from the "user" -- the user in this case views the
Llama4 model holistically and does not care or know about its inner workings (e.g.,
whether it has an encoder or if it is early fusion or not.)
This is distinct from the "TransformerInput" class which is really the Llama4
backbone operating on early fused modalities and producing text output
"""
tokens: torch.Tensor
# images are already pre-processed (resized, tiled, etc.)
images: Optional[List[torch.Tensor]] = None
@dataclass
class TransformerInput:
"""
This is the "core" backbone transformer of the Llama4 model. Inputs for other modalities
are expected to be "embedded" via encoders sitting before this layer in the model.
"""
tokens: torch.Tensor
# tokens_position defines the position of the tokens in each batch,
# - when it is a tensor ([batch_size,]), it is the start position of the tokens in each batch
# - when it is an int, the start position are the same for all batches
tokens_position: Union[torch.Tensor, int]
image_embedding: Optional[MaskedEmbedding] = None
@dataclass
class LLMOutput:
logits: torch.Tensor
TransformerOutput = LLMOutput

View file

@ -0,0 +1,58 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
from typing import Any, Dict, List
from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear
from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
from torch import nn
from torch.nn import functional as F
class FeedForward(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
do_reduce: bool = True,
):
super().__init__()
self.do_reduce = do_reduce
self.w1 = ColumnParallelLinear(dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x)
self.w2 = RowParallelLinear(hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x)
self.w3 = ColumnParallelLinear(dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x)
self._register_load_state_dict_pre_hook(self.load_hook)
def load_hook(
self,
state_dict: Dict[str, Any],
prefix: str,
local_metadata: Dict[str, Any],
strict: bool,
missing_keys: List[str],
unexpected_keys: List[str],
error_msgs: List[str],
) -> None:
if prefix + "mlp.fc1_weight" in state_dict:
w1, w3 = state_dict.pop(prefix + "mlp.fc1_weight").chunk(2, dim=0)
state_dict[prefix + "w1.weight"] = w1
state_dict[prefix + "w3.weight"] = w3
state_dict[prefix + "w2.weight"] = state_dict.pop(prefix + "mlp.fc2_weight")
def forward(self, x):
x = F.silu(F.linear(x, self.w1.weight)) * F.linear(x, self.w3.weight)
out = F.linear(x, self.w2.weight)
if self.do_reduce:
return reduce_from_model_parallel_region(out)
return out

View file

@ -0,0 +1,313 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import codecs
import io
import json
import os
import sys
import time
from pathlib import Path
from typing import Callable, Generator, List, Optional
import torch
import torch.nn.functional as F
from fairscale.nn.model_parallel.initialize import (
initialize_model_parallel,
model_parallel_is_initialized,
)
from termcolor import cprint
from ..checkpoint import maybe_reshard_state_dict
from ..datatypes import GenerationResult, QuantizationMode
from .args import ModelArgs
from .chat_format import ChatFormat, RawContent, RawMessage
from .datatypes import LLMInput, MaskedEmbedding, TransformerInput
from .model import Transformer
from .tokenizer import Tokenizer
torch.serialization.add_safe_globals([io.BytesIO, codecs.encode])
class Llama4:
@staticmethod
def build(
ckpt_dir: str,
max_seq_len: int,
max_batch_size: int,
world_size: Optional[int] = None,
quantization_mode: Optional[QuantizationMode] = None,
seed: int = 1,
):
if not torch.distributed.is_initialized():
torch.distributed.init_process_group("nccl")
if not model_parallel_is_initialized():
if world_size is None:
world_size = int(os.environ.get("WORLD_SIZE", 1))
initialize_model_parallel(world_size)
local_rank = int(os.environ.get("LOCAL_RANK", 0))
torch.cuda.set_device(local_rank)
torch.manual_seed(seed)
if local_rank > 0:
sys.stdout = open(os.devnull, "w")
start_time = time.time()
ckpt_paths = sorted(Path(ckpt_dir).glob("*.pth"))
assert len(ckpt_paths) > 0, f"no checkpoint files found in {ckpt_dir}"
print(f"Loading a checkpoint (shards={len(ckpt_paths)}, current-mp-size={world_size})")
with open(Path(ckpt_dir) / "params.json", "r") as f:
params = json.loads(f.read())
model_args: ModelArgs = ModelArgs(
**params,
max_seq_len=max_seq_len,
max_batch_size=max_batch_size,
)
tokenizer = Tokenizer.get_instance()
# TODO: params.json should always have correct vocab_size
if model_args.vocab_size == -1:
model_args.vocab_size = tokenizer.n_words
assert model_args.vocab_size == tokenizer.n_words, f"{model_args.vocab_size=} vs. {tokenizer.n_words=} mismatch"
print("Model args:\n", model_args.model_dump_json(indent=2))
state_dict = maybe_reshard_state_dict(
ckpt_paths,
n_kv_heads=model_args.n_kv_heads if model_args.n_kv_heads else model_args.n_heads,
moe_num_experts=model_args.moe_args.num_experts,
)
print("Loaded checkpoint")
if quantization_mode == QuantizationMode.fp8_mixed or quantization_mode == QuantizationMode.int4_mixed:
from .quantization.loader import convert_to_quantized_model
torch.set_default_tensor_type(torch.BFloat16Tensor)
model = Transformer(model_args)
print("Loading state dict...")
model.load_state_dict(state_dict, strict=False)
print("Done...")
model = convert_to_quantized_model(model, ckpt_dir, quantization_mode)
else:
if torch.cuda.is_bf16_supported():
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
else:
torch.set_default_tensor_type(torch.cuda.HalfTensor)
model = Transformer(model_args)
print("Loading state dict...")
model.load_state_dict(state_dict, strict=False)
print("Done...")
print(f"Loaded in {time.time() - start_time:.2f} seconds")
return Llama4(model, tokenizer, model_args)
def __init__(self, model: Transformer, tokenizer: Tokenizer, args: ModelArgs):
self.args = args
self.model = model
self.tokenizer = tokenizer
self.formatter = ChatFormat(tokenizer, vision_args=args.vision_args)
@torch.inference_mode()
def generate(
self,
llm_inputs: List[LLMInput],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
echo: bool = False,
print_model_input: bool = False,
logits_processor: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
) -> Generator[List[GenerationResult], None, None]:
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.args.max_seq_len:
max_gen_len = self.model.args.max_seq_len - 1
params = self.model.args
print_model_input = print_model_input or os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1"
if print_model_input:
cprint("Input to model:\n", "yellow")
for inp in llm_inputs:
cprint(self.tokenizer.decode(inp.tokens), "grey")
prompt_tokens = [inp.tokens for inp in llm_inputs]
bsz = len(llm_inputs)
assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
min_prompt_len = min(len(t) for t in prompt_tokens)
max_prompt_len = max(len(t) for t in prompt_tokens)
if max_prompt_len >= params.max_seq_len:
cprint(f"Out of token budget {max_prompt_len} vs {params.max_seq_len}", "red")
return
total_len = min(max_gen_len + max_prompt_len, params.max_seq_len)
pad_id = self.tokenizer.pad_id
tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda")
for k, t in enumerate(prompt_tokens):
tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
if logprobs:
token_logprobs = torch.zeros_like(tokens, dtype=torch.float)
eos_reached = torch.tensor([False] * bsz, device="cuda")
input_text_mask = tokens != pad_id
if echo:
for i in range(max_prompt_len):
results = []
for j, t in enumerate(tokens[:, i]):
results.append(
GenerationResult(
token=t.item(),
text=self.tokenizer.decode([t.item()]),
source="input",
logprobs=(token_logprobs[j, i : i + 1].tolist() if logprobs else None),
batch_idx=j,
finished=False,
ignore_token=t.item() == pad_id,
)
)
yield results
stop_tokens = torch.tensor(self.tokenizer.stop_tokens, device="cuda")
prev_pos = 0
for cur_pos in range(min_prompt_len, total_len):
image_embedding = None
if prev_pos == 0 and any(inp.images is not None and len(inp.images) > 0 for inp in llm_inputs):
image_mask = tokens[:, prev_pos:cur_pos] == self.tokenizer.special_tokens["<|patch|>"]
image_mask = image_mask.unsqueeze(-1)
h = self.model.tok_embeddings(tokens[:, prev_pos:cur_pos])
image_batch = [inp.images if inp.images is not None else [] for inp in llm_inputs]
image_embedding = MaskedEmbedding(
embedding=self.model.vision_embeddings(image_batch, image_mask, h),
mask=image_mask,
)
xformer_input = TransformerInput(
tokens=tokens[:, prev_pos:cur_pos],
tokens_position=prev_pos,
image_embedding=image_embedding,
)
xformer_output = self.model.forward(xformer_input)
logits = xformer_output.logits
if logits_processor is not None:
logits = logits_processor(tokens[:, :cur_pos], logits)
if temperature > 0:
probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
next_token = sample_top_p(probs, top_p)
else:
next_token = torch.argmax(logits[:, -1], dim=-1)
next_token = next_token.reshape(-1)
# only replace token if prompt has already been generated
next_token = torch.where(input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token)
tokens[:, cur_pos] = next_token
target = tokens[:, prev_pos + 1 : cur_pos + 1]
if logprobs:
token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
input=logits.transpose(1, 2),
target=target,
reduction="none",
ignore_index=pad_id,
)
eos_reached |= (~input_text_mask[:, cur_pos]) & (torch.isin(next_token, stop_tokens))
results = []
for idx, t in enumerate(next_token):
results.append(
GenerationResult(
token=t.item(),
text=self.tokenizer.decode([t.item()]),
source="output",
logprobs=(token_logprobs[idx, cur_pos : cur_pos + 1].tolist() if logprobs else None),
batch_idx=idx,
finished=eos_reached[idx].item(),
ignore_token=cur_pos < len(prompt_tokens[idx]),
)
)
yield results
prev_pos = cur_pos
if all(eos_reached):
break
def completion(
self,
contents: List[RawContent],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
echo: bool = False,
) -> Generator[List[GenerationResult], None, None]:
llm_inputs = [self.formatter.encode_content(c) for c in contents]
for result in self.generate(
llm_inputs=llm_inputs,
temperature=temperature,
top_p=top_p,
max_gen_len=max_gen_len,
logprobs=logprobs,
echo=echo,
):
yield result
if all(r.finished for r in result):
break
def chat_completion(
self,
messages_batch: List[List[RawMessage]],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
echo: bool = False,
) -> Generator[List[GenerationResult], None, None]:
llm_inputs = [self.formatter.encode_dialog_prompt(messages) for messages in messages_batch]
for result in self.generate(
llm_inputs=llm_inputs,
temperature=temperature,
top_p=top_p,
max_gen_len=max_gen_len,
logprobs=logprobs,
echo=echo,
):
yield result
if all(r.finished for r in result):
break
def sample_top_p(probs, p):
"""
Perform top-p (nucleus) sampling on a probability distribution.
Args:
probs (torch.Tensor): Probability distribution tensor.
p (float): Probability threshold for top-p sampling.
Returns:
torch.Tensor: Sampled token indices.
Note:
Top-p sampling selects the smallest set of tokens whose cumulative probability mass
exceeds the threshold p. The distribution is renormalized based on the selected tokens.
"""
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
probs_sum = torch.cumsum(probs_sort, dim=-1)
mask = probs_sum - probs_sort > p
probs_sort[mask] = 0.0
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
next_token = torch.multinomial(probs_sort, num_samples=1)
next_token = torch.gather(probs_idx, -1, next_token)
return next_token

View file

@ -0,0 +1,437 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import math
from typing import Any, Dict, List, Optional, Tuple
import fairscale.nn.model_parallel.initialize as fs_init
import torch
import torch.nn.functional as F
from fairscale.nn.model_parallel.layers import (
ColumnParallelLinear,
RowParallelLinear,
VocabParallelEmbedding,
)
from torch import nn
from .args import ModelArgs
from .datatypes import TransformerInput, TransformerOutput
from .ffn import FeedForward
from .moe import MoE
def rmsnorm(x, eps):
def _norm(y):
return y * torch.rsqrt(y.pow(2).mean(-1, keepdim=True) + eps)
return _norm(x.float()).type_as(x)
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def forward(self, x):
return rmsnorm(x, self.eps) * self.weight
def apply_scaling(freqs: torch.Tensor, scale_factor: float, high_freq_factor: float):
low_freq_factor = 1
old_context_len = 8192 # original llama3 length
low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor
new_freqs = []
for freq in freqs:
wavelen = 2 * math.pi / freq
if wavelen < high_freq_wavelen:
new_freqs.append(freq)
elif wavelen > low_freq_wavelen:
new_freqs.append(freq / scale_factor)
else:
assert low_freq_wavelen != high_freq_wavelen
smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)
return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
def precompute_freqs_cis(
dim: int,
end: int,
theta: float,
use_scaled: bool,
scale_factor: float,
high_freq_factor: float,
):
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device, dtype=torch.float32)
if use_scaled:
freqs = apply_scaling(freqs, scale_factor, high_freq_factor)
freqs = torch.outer(t, freqs)
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
return freqs_cis
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
ndim = x.ndim
assert 0 <= 1 < ndim
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
class Attention(nn.Module):
# TODO: this module needs to be moved into a separate file since it can be used by
# the vision encoder as well.
def __init__(
self,
args: ModelArgs,
use_qk_norm: bool,
use_rope: bool,
add_bias: bool = False,
):
super().__init__()
self.use_rope = use_rope
self.use_qk_norm = use_qk_norm
# For attention temperature tuning
self.attn_temperature_tuning = args.attn_temperature_tuning
self.floor_scale = args.floor_scale
self.attn_scale = args.attn_scale
self.n_heads = args.n_heads
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
world_size = fs_init.get_model_parallel_world_size()
self.n_local_heads = args.n_heads // world_size
self.n_local_kv_heads = self.n_kv_heads // world_size
self.n_rep = self.n_local_heads // self.n_local_kv_heads
self.head_dim = args.dim // args.n_heads
self.wq = ColumnParallelLinear(
args.dim,
args.n_heads * self.head_dim,
bias=add_bias,
gather_output=False,
init_method=lambda x: x,
)
self.wk = ColumnParallelLinear(
args.dim,
self.n_kv_heads * self.head_dim,
bias=add_bias,
gather_output=False,
init_method=lambda x: x,
)
self.wv = ColumnParallelLinear(
args.dim,
self.n_kv_heads * self.head_dim,
bias=add_bias,
gather_output=False,
init_method=lambda x: x,
)
self.wo = RowParallelLinear(
args.n_heads * self.head_dim,
args.dim,
bias=add_bias,
input_is_parallel=True,
init_method=lambda x: x,
)
self.cache_k = torch.zeros(
(
args.max_batch_size,
args.max_seq_len,
self.n_local_kv_heads,
self.head_dim,
)
).cuda()
self.cache_v = torch.zeros(
(
args.max_batch_size,
args.max_seq_len,
self.n_local_kv_heads,
self.head_dim,
)
).cuda()
self.norm_eps = args.norm_eps
self._register_load_state_dict_pre_hook(self.load_hook)
def load_hook(
self,
state_dict: Dict[str, Any],
prefix: str,
local_metadata: Dict[str, Any],
strict: bool,
missing_keys: List[str],
unexpected_keys: List[str],
error_msgs: List[str],
) -> None:
if prefix + "wqkv.weight" in state_dict:
wqkv = state_dict.pop(prefix + "wqkv.weight")
d, r = divmod(wqkv.shape[0], self.n_heads + 2 * self.n_kv_heads)
if r != 0:
raise ValueError(
f"shape={tuple(wqkv.shape)} is not divisible by "
f"n_heads ({self.n_heads}) + 2 * n_kv_heads ({self.n_kv_heads})"
)
wq, wk, wv = wqkv.split([d * self.n_heads, d * self.n_kv_heads, d * self.n_kv_heads], dim=0)
state_dict[prefix + "wq.weight"] = wq
state_dict[prefix + "wk.weight"] = wk
state_dict[prefix + "wv.weight"] = wv
def forward(
self,
x: torch.Tensor,
start_pos: int,
freqs_cis: torch.Tensor,
mask: Optional[torch.Tensor] = None,
):
bsz, seqlen, _ = x.shape
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
if self.use_rope:
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
if self.use_qk_norm:
xq = rmsnorm(xq, self.norm_eps)
xk = rmsnorm(xk, self.norm_eps)
# We are applying temperature tuning (https://arxiv.org/abs/2501.19399) to NoPE layers, where
# the inference-time temperature tuning function is customized to not affect short context
# while working at very long context
if self.attn_temperature_tuning and not self.use_rope:
seq_positions = torch.arange(start_pos, start_pos + seqlen, device=xq.device, dtype=torch.float32)
attn_scales = torch.log(torch.floor((seq_positions + 1.0) / self.floor_scale) + 1.0) * self.attn_scale + 1.0
# reshape for broadcasting [seqlen] -> [1, seqlen, 1, 1]
attn_scales = attn_scales.view(1, seqlen, 1, 1)
xq = xq * attn_scales
self.cache_k = self.cache_k.to(xq)
self.cache_v = self.cache_v.to(xq)
self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv
xk = self.cache_k[:bsz, : start_pos + seqlen]
xv = self.cache_v[:bsz, : start_pos + seqlen]
xq, xk, xv = [t.transpose(1, 2) for t in (xq, xk, xv)]
xk = xk.repeat_interleave(self.n_rep, dim=1)
xv = xv.repeat_interleave(self.n_rep, dim=1)
attn_output = F.scaled_dot_product_attention(xq, xk, xv, attn_mask=mask, dropout_p=0.0)
attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
output = self.wo(attn_output)
return output
class TransformerBlock(nn.Module):
def __init__(self, layer_id: int, args: ModelArgs):
super().__init__()
self.n_heads = args.n_heads
self.dim = args.dim
self.head_dim = args.dim // args.n_heads if args.head_dim is None else args.head_dim
self.is_nope_layer = args.nope_layer_interval is not None and (layer_id + 1) % args.nope_layer_interval == 0
use_rope = not self.is_nope_layer
use_qk_norm = args.use_qk_norm and not self.is_nope_layer
self.attention = Attention(args, use_rope=use_rope, use_qk_norm=use_qk_norm)
if args.moe_args and (layer_id + 1) % args.moe_args.interleave_moe_layer_step == 0:
self.feed_forward = MoE(
dim=args.dim,
hidden_dim=int(args.ffn_exp * args.dim),
ffn_dim_multiplier=args.ffn_dim_multiplier,
multiple_of=args.multiple_of,
moe_args=args.moe_args,
)
else:
hidden_dim = int(4 * args.dim)
hidden_dim = int(2 * hidden_dim / 3)
if args.ffn_dim_multiplier is not None:
hidden_dim = int(args.ffn_dim_multiplier * hidden_dim)
hidden_dim = args.multiple_of * ((hidden_dim + args.multiple_of - 1) // args.multiple_of)
self.feed_forward = FeedForward(
dim=args.dim,
hidden_dim=hidden_dim,
)
self.layer_id = layer_id
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
self._register_load_state_dict_pre_hook(self.load_hook)
def load_hook(
self,
state_dict: Dict[str, Any],
prefix: str,
local_metadata: Dict[str, Any],
strict: bool,
missing_keys: List[str],
unexpected_keys: List[str],
error_msgs: List[str],
) -> None:
if prefix + "attention.wqkv.layer_norm_weight" in state_dict:
state_dict[prefix + "attention_norm.weight"] = state_dict.pop(prefix + "attention.wqkv.layer_norm_weight")
if prefix + "feed_forward.mlp.layer_norm_weight" in state_dict:
state_dict[prefix + "ffn_norm.weight"] = state_dict.pop(prefix + "feed_forward.mlp.layer_norm_weight")
elif prefix + "feed_forward.norm.weight" in state_dict:
state_dict[prefix + "ffn_norm.weight"] = state_dict.pop(prefix + "feed_forward.norm.weight")
for k in (
"feed_forward.experts.mlp",
"feed_forward.mlp_shared",
"attention.wo",
"attention.wqkv",
):
if prefix + k + "._extra_state" in state_dict:
state_dict.pop(prefix + k + "._extra_state")
def forward(
self,
x: torch.Tensor,
start_pos: int,
freqs_cis: torch.Tensor,
global_attn_mask: Optional[torch.Tensor],
local_attn_mask: Optional[torch.Tensor],
):
# The iRoPE architecture uses global attention mask for NoPE layers or
# if chunked local attention is not used
if self.is_nope_layer or local_attn_mask is None:
mask = global_attn_mask
else:
mask = local_attn_mask
h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
out = h + self.feed_forward(self.ffn_norm(h))
return out
class Transformer(nn.Module):
def __init__(self, args: ModelArgs, **kwargs) -> None:
super().__init__()
self.args = args
self.vocab_size = args.vocab_size
self.n_layers = args.n_layers
self.tok_embeddings = VocabParallelEmbedding(args.vocab_size, args.dim, init_method=lambda x: x)
self.layers = torch.nn.ModuleList()
for layer_id in range(args.n_layers):
self.layers.append(TransformerBlock(layer_id, args))
self.norm = RMSNorm(args.dim, eps=args.norm_eps)
self.output = ColumnParallelLinear(args.dim, args.vocab_size, bias=False, init_method=lambda x: x)
self.freqs_cis = precompute_freqs_cis(
args.dim // args.n_heads,
args.max_seq_len * 2,
args.rope_theta,
args.use_scaled_rope,
args.rope_scaling_factor,
args.rope_high_freq_factor,
)
vision_args = self.args.vision_args
if vision_args:
# circular import otherwise until we refactor out Attention
from .vision.embedding import VisionEmbeddings
self.vision_embeddings = VisionEmbeddings(vision_args)
self.vision_projection = ColumnParallelLinear(
vision_args.output_dim,
args.dim,
bias=False,
init_method=lambda x: x,
)
self._register_load_state_dict_pre_hook(self.load_hook)
def load_hook(
self,
state_dict: Dict[str, Any],
prefix: str,
local_metadata: Dict[str, Any],
strict: bool,
missing_keys: List[str],
unexpected_keys: List[str],
error_msgs: List[str],
) -> None:
if prefix + "rope.freqs" in state_dict:
state_dict.pop(prefix + "rope.freqs")
@torch.inference_mode()
def forward(self, model_input: TransformerInput) -> TransformerOutput:
tokens = model_input.tokens
start_pos = model_input.tokens_position
assert isinstance(start_pos, int), (
"This implementation does not support different start positions per batch item"
)
_bsz, seqlen = tokens.shape
h = self.tok_embeddings(tokens)
if image_embedding := model_input.image_embedding:
h_image = self.vision_projection(image_embedding.embedding)
h = h * ~image_embedding.mask + h_image * image_embedding.mask
self.freqs_cis = self.freqs_cis.to(h.device)
freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
global_attn_mask, local_attn_mask = None, None
if seqlen > 1:
global_attn_mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device)
global_attn_mask = torch.triu(global_attn_mask, diagonal=1).type_as(h)
# https://github.com/pytorch/pytorch/issues/100005
# torch.triu is buggy when the device is mps: filled values are
# nan instead of 0.
if global_attn_mask.device.type == torch.device("mps").type:
global_attn_mask = torch.nan_to_num(global_attn_mask, nan=0.0)
if chunk_size := self.args.attention_chunk_size:
local_attn_mask = create_chunked_attention_mask(seqlen, chunk_size, tokens.device)
for layer in self.layers:
h = layer(h, start_pos, freqs_cis, global_attn_mask, local_attn_mask)
h = self.norm(h)
output = self.output(h).float()
return TransformerOutput(logits=output)
# tokens (0, K), (K, 2K), (2K, 3K) attend to each other when doing local chunked attention
# in the iRoPE architecture
def create_chunked_attention_mask(seq_len: int, attention_chunk_size: int, device: torch.device) -> torch.Tensor:
block_pos = torch.abs(
(torch.arange(seq_len).unsqueeze(0) // attention_chunk_size)
- (torch.arange(seq_len).unsqueeze(1) // attention_chunk_size)
)
token_pos = torch.arange(seq_len).unsqueeze(0) - torch.arange(seq_len).unsqueeze(1)
mask = (block_pos == 0) & (token_pos <= 0)
return mask.to(device)

View file

@ -0,0 +1,214 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# ruff: noqa: N806
# pyre-strict
from typing import Any, Dict, List
import fairscale.nn.model_parallel.initialize as fs_init
import torch
from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
from torch import Tensor, nn
from torch.nn import functional as F
from .args import MoEArgs
from .ffn import FeedForward
class Experts(nn.Module):
def __init__(
self,
num_local_experts: int,
dim: int,
hidden_dim: int,
) -> None:
super().__init__()
dtype = torch.get_default_dtype()
self.num_local_experts = num_local_experts
self.dim = dim
divide_factor = fs_init.get_model_parallel_world_size()
self.w1: nn.Parameter = nn.Parameter(
torch.empty(
num_local_experts,
dim,
divide_exact(hidden_dim, divide_factor),
dtype=dtype,
)
)
self.w2: nn.Parameter = nn.Parameter(
torch.empty(
num_local_experts,
divide_exact(hidden_dim, divide_factor),
dim,
dtype=dtype,
)
)
self.w3: nn.Parameter = nn.Parameter(
torch.empty(
num_local_experts,
dim,
divide_exact(hidden_dim, divide_factor),
dtype=dtype,
)
)
self._register_load_state_dict_pre_hook(self.load_hook)
def load_hook(
self,
state_dict: Dict[str, Any],
prefix: str,
local_metadata: Dict[str, Any],
strict: bool,
missing_keys: List[str],
unexpected_keys: List[str],
error_msgs: List[str],
) -> None:
self.prefix = prefix
if prefix + "moe_w_in_eD_F" in state_dict:
e = self.num_local_experts
D = self.dim
state_dict[prefix + "w1"] = state_dict.pop(prefix + "moe_w_in_eD_F").view(e, D, -1)
state_dict[prefix + "w2"] = state_dict.pop(prefix + "moe_w_out_eF_D").view(e, -1, D)
state_dict[prefix + "w3"] = state_dict.pop(prefix + "moe_w_swiglu_eD_F").view(e, D, -1)
def forward(
self,
routed_in_egD: torch.Tensor, # noqa: N803
) -> torch.Tensor:
e = self.num_local_experts
D = self.dim
x_egD = routed_in_egD.view(e, -1, D)
out_egD = self.batched_swiglu(x_egD, self.w1, self.w3, self.w2)
out_egD = out_egD.view(-1, D)
return out_egD
def batched_swiglu(self, x: Tensor, w1: Tensor, w3: Tensor, w2: Tensor) -> Tensor:
middle_out_egF = F.silu(torch.bmm(x, w1)) * torch.bmm(x, w3)
return torch.bmm(middle_out_egF, w2)
class MoE(torch.nn.Module):
"""
Tensors used in this module are annotated with the suffixes that indicate the shape of the tensor.
Several commonly used annotations include:
- a: bsz*slen
- E: number of experts
- e: number of local experts per ep (n_experts/ep)
- D: hidden dimension
- d: D/tp
- F: model dimension
- G: number of tokens per expert (a * capacity_factor / E)
- g: number of tokens per expert per TP rank (i.e., G/TP)
Examples:
x_aD [a, D]
routed_in_etG_D [et*G, D]
x_eGD: [e, G, D]
"""
def __init__(
self,
dim: int,
hidden_dim: int,
ffn_dim_multiplier: float,
multiple_of: int,
moe_args: MoEArgs,
) -> None:
super().__init__()
self.moe_args = moe_args
hidden_dim_denom: float = 1
if moe_args.auto_scale_F:
hidden_dim_denom = moe_args.capacity_factor + 1
hidden_dim = int(2 * hidden_dim / 3)
# custom dim factor multiplier
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
if moe_args.auto_scale_F:
hidden_dim = int(hidden_dim / hidden_dim_denom)
hidden_dim += -hidden_dim % multiple_of
num_local_experts: int = moe_args.num_experts
dtype: torch.dtype = torch.get_default_dtype()
self.experts = Experts(
num_local_experts,
dim,
hidden_dim,
)
self.router_DE: nn.Parameter = nn.Parameter(torch.empty(dim, moe_args.num_experts, dtype=dtype))
self.shared_expert = FeedForward(dim, hidden_dim, do_reduce=False)
self._register_load_state_dict_pre_hook(self.load_hook)
def load_hook(
self,
state_dict: Dict[str, Any],
prefix: str,
local_metadata: Dict[str, Any],
strict: bool,
missing_keys: List[str],
unexpected_keys: List[str],
error_msgs: List[str],
) -> None:
if prefix + "w_in_shared_FD.weight" in state_dict:
state_dict[prefix + "shared_expert.w1.weight"] = state_dict.pop(prefix + "w_in_shared_FD.weight")
state_dict[prefix + "shared_expert.w3.weight"] = state_dict.pop(prefix + "w_swiglu_FD.weight")
state_dict[prefix + "shared_expert.w2.weight"] = state_dict.pop(prefix + "w_out_shared_DF.weight")
def forward(self, x_bsD: Tensor) -> Tensor: # noqa: N803
_, slen, D = x_bsD.shape
x_aD = x_bsD.view(-1, D)
a = x_aD.shape[0]
router_scores: Tensor = torch.matmul(x_aD, self.router_DE).transpose(0, 1)
router_scores_aK, router_indices_aK = torch.topk(router_scores.transpose(0, 1), self.moe_args.top_k, dim=1)
router_scores = (
torch.full_like(router_scores.transpose(0, 1), float("-inf"))
.scatter_(1, router_indices_aK, router_scores_aK)
.transpose(0, 1)
)
router_indices = torch.arange(a, device=x_aD.device).view(1, -1).expand(router_scores.size(0), -1)
router_scores = torch.sigmoid(router_scores)
routed_in_EG_D: Tensor = torch.gather(
x_aD,
dim=0,
index=router_indices.reshape(-1, 1).expand(-1, D),
)
routed_in_EG_D = routed_in_EG_D * router_scores.reshape(-1, 1)
out_aD = self.shared_expert(x_aD)
routed_out_eg_D = self.experts(routed_in_EG_D.detach())
router_indices_EG_D = router_indices.reshape(-1, 1).expand(-1, D)
out_aD.scatter_add_(
dim=0,
index=router_indices_EG_D,
src=routed_out_eg_D.view(-1, D),
)
out_aD = reduce_from_model_parallel_region(out_aD)
return out_aD.view(-1, slen, D)
def divide_exact(numerator: int, denominator: int) -> int:
assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator)
return numerator // denominator

View file

@ -0,0 +1,436 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
import math
from collections import defaultdict
from typing import Optional, Set, Tuple
import torch
import torchvision.transforms as tv
from PIL import Image, ImageFile
from torchvision.transforms import functional as F
ImageFile.LOAD_TRUNCATED_IMAGES = True
IMAGE_RES = 448
class ResizeNormalizeImageTransform:
def __init__(
self,
size_width=None,
size_height=None,
) -> None:
self._size_width = size_width or IMAGE_RES
self._size_height = size_height or IMAGE_RES
self._mean = (0.5, 0.5, 0.5)
self._std = (0.5, 0.5, 0.5)
self.tv_transform = tv.Compose(
[
tv.Resize((self._size_height, self._size_width)),
tv.ToTensor(),
tv.Normalize(
mean=self._mean,
std=self._std,
inplace=True,
),
]
)
def __call__(self, image: Image.Image) -> torch.Tensor:
return self.tv_transform(image)
class VariableSizeImageTransform(object):
"""
This class accepts images of any size and dynamically resize, pads and chunks it
based on the image aspect ratio and the number of image chunks we allow.
The algorithm will NOT distort the image fit a certain aspect ratio, because
that leads to a significant degradation in image quality.
It can be summarized in 6 steps:
1. Find all possible canvas combinations of max_num_chunks;
2. Find the best canvas to fit the image;
3. Resize without distortion
4. Pad
5. Normalize
6. Chunk
For example, if an input image is of size 300x800, patch_size of 224,
and max_num_chunks = 8, it will find the closest aspect ratio that
is allowed within 8 image chunks, with some restrictions.
In this case, 2:4 = 2 horizontal patches and 4 vertical patches,
giving a total of 8 chunks.
If resize_to_max_canvas, the image will be resized (without distortion),
to the largest possible resolution. In this case, 388:896, and padded to 448:896,
where we maintain the original aspect ratio and pad with zeros value for the rest.
This approach minimizes the amount of padding required for any arbitrary resolution.
However, if limit_upscaling_to_patch_size is set to True,
the upscaling will be limited to the patch size. In the example above,
the image would remain 300x800 (no upscaling), and then padded to 448:896.
The final output will therefore be of shape (8, 3, 224, 224), where 2x4
patches are coming from the resizing and chunking.
"""
def __init__(self, size: int = IMAGE_RES) -> None:
self.size = size
self.to_tensor = tv.ToTensor()
self._mean = (0.5, 0.5, 0.5)
self._std = (0.5, 0.5, 0.5)
self.normalize = tv.Normalize(
mean=self._mean,
std=self._std,
inplace=True,
)
self.resample = tv.InterpolationMode.BILINEAR
@staticmethod
def get_factors(n: int) -> Set[int]:
"""
Calculate all factors of a given number, i.e. a dividor that leaves
no remainder. For example, if n=12, it will return {1, 2, 3, 4, 6, 12}.
Args:
n (int): The number to find factors for.
Returns:
set: A set containing all factors of the number.
"""
factors_set = set()
for i in range(1, int(n**0.5) + 1):
if n % i == 0:
factors_set.add(i)
factors_set.add(n // i)
return factors_set
def find_supported_resolutions(self, max_num_chunks: int, patch_size: int) -> torch.Tensor:
"""
Computes all of the allowed resoltuions for a fixed number of chunks
and patch_size. Useful for when dividing an image into chunks.
Args:
max_num_chunks (int): Maximum number of chunks for processing.
patch_size (int): Size of the side of the patch.
Returns:
torch.Tensor: List of possible resolutions as tuples (height, width).
Example:
>>> max_num_chunks = 5
>>> patch_size = 224
>>> find_supported_resolutions(max_num_chunks, patch_size)
tensor([(224, 896), (448, 448), (224, 224), (896, 224), (224, 672),
(672, 224), (224, 448), (448, 224)])
Given max_num_chunks=4, patch_size=224, it will create a dictionary:
{
0.25: [(1, 4)],
1.0: [(2, 2), (1, 1)],
4.0: [(4, 1)],
0.33: [(1, 3)],
3.0: [(3, 1)],
0.5: [(1, 2)],
2.0: [(2, 1)]
}
and return the resolutions multiplied by the patch_size:
[(1*224, 4*224), (2*224, 2*224), ..., (2*224, 1*224)]
"""
asp_dict = defaultdict(list)
for chunk_size in range(max_num_chunks, 0, -1):
_factors = sorted(self.get_factors(chunk_size))
_asp_ratios = [(factor, chunk_size // factor) for factor in _factors]
for height, width in _asp_ratios:
ratio_float = height / width
asp_dict[ratio_float].append((height, width))
# get the resolutions multiplied by the patch_size
possible_resolutions = []
for value in asp_dict.values():
for height, width in value:
possible_resolutions.append((height * patch_size, width * patch_size))
return possible_resolutions
@staticmethod
def get_max_res_without_distortion(
image_size: Tuple[int, int],
target_size: Tuple[int, int],
) -> Tuple[int, int]:
"""
Determines the maximum resolution to which an image can be resized to without distorting its
aspect ratio, based on the target resolution.
Args:
image_size (Tuple[int, int]): The original resolution of the image (height, width).
target_resolution (Tuple[int, int]): The desired resolution to fit the image into (height, width).
Returns:
Tuple[int, int]: The optimal dimensions (height, width) to which the image should be resized.
Example:
>>> _get_max_res_without_distortion([200, 300], target_size = [450, 200])
(134, 200)
>>> _get_max_res_without_distortion([800, 600], target_size = [450, 1300])
(450, 338)
"""
original_width, original_height = image_size
target_width, target_height = target_size
scale_w = target_width / original_width
scale_h = target_height / original_height
if scale_w < scale_h:
new_width = target_width
new_height = min(math.floor(original_height * scale_w), target_height)
else:
new_height = target_height
new_width = min(math.floor(original_width * scale_h), target_width)
return new_width, new_height
def _pad(self, image: Image.Image, target_size) -> Image.Image:
new_width, new_height = target_size
new_im = Image.new(mode="RGB", size=(new_width, new_height), color=(0, 0, 0)) # type: ignore
new_im.paste(image)
return new_im
def _split(self, image: torch.Tensor, ncw: int, nch: int) -> torch.Tensor:
# Split image into number of required tiles (width x height)
num_channels, height, width = image.size()
image = image.view(num_channels, nch, height // nch, ncw, width // ncw)
# Permute dimensions to reorder the axes
image = image.permute(1, 3, 0, 2, 4).contiguous()
# Reshape into the desired output shape (batch_size * 4, num_channels, width/2, height/2)
image = image.view(ncw * nch, num_channels, height // nch, width // ncw)
return image
def resize_without_distortion(
self,
image: torch.Tensor,
target_size: Tuple[int, int],
max_upscaling_size: Optional[int],
) -> torch.Tensor:
"""
Used to resize an image to target_resolution, without distortion.
If target_size requires upscaling the image, the user can set max_upscaling_size to
limit the upscaling to a maximum size. In this case, since we rescale without distortion,
modifying target_size works as a boundary for the image's largest side.
Args:
resample (str): Resampling method used when resizing images.
Supports "nearest", "nearest_exact", "bilinear", "bicubic".
max_upscaling_size (int): The maximum size to upscale the image to.
If None, there is no limit.
Examples:
>>> target_size = (1000, 1200)
>>> max_upscaling_size = 600
>>> image_size = (400, 200)
>>> resize_without_distortion(image_size, target_size, max_upscaling_size)
(600, 300) # new_size_without_distortion
>>> target_size = (1000, 1200)
>>> max_upscaling_size = 600
>>> image_size = (2000, 200)
>>> resize_without_distortion(image_size, target_size, max_upscaling_size)
(1000, 100) # new_size_without_distortion
>>> target_size = (1000, 1200)
>>> max_upscaling_size = 2000
>>> image_size = (400, 200)
>>> resize_without_distortion(image_size, target_size, max_upscaling_size)
(1000, 500) # new_size_without_distortion
>>> target_size = (1000, 1200)
>>> max_upscaling_size = None
>>> image_size = (400, 200)
>>> resize_without_distortion(image_size, target_size, max_upscaling_size)
(1000, 500) # new_size_without_distortion
"""
image_width, image_height = image.size
image_size = (image_width, image_height)
# If target_size requires upscaling, we might want to limit the upscaling to max_upscaling_size
if max_upscaling_size is not None:
new_target_width = min(max(image_width, max_upscaling_size), target_size[0])
new_target_height = min(max(image_height, max_upscaling_size), target_size[1])
target_size = (new_target_width, new_target_height)
# resize to target_size while preserving aspect ratio
new_size_without_distortion = self.get_max_res_without_distortion(image_size, target_size)
image = F.resize(
image,
(
max(new_size_without_distortion[1], 1),
max(new_size_without_distortion[0], 1),
),
interpolation=self.resample,
)
return image
def get_best_fit(
self,
image_size: Tuple[int, int],
possible_resolutions: torch.Tensor,
resize_to_max_canvas: bool = False,
) -> Tuple[int, int]:
"""
Determines the best canvas possible from a list of possible resolutions to, without distortion,
resize an image to.
For each possible resolution, calculates the scaling factors for
width and height, and selects the smallest one, which is the limiting side.
E.g. to match the canvas you can upscale height by 2x, and width by 1.5x,
therefore, the maximum upscaling you can do is min(2, 1.5) = 1.5.
If upscaling is possible (any of the scaling factors is greater than 1),
then picks the smallest upscaling factor > 1, unless resize_to_max_canvas is True.
If upscaling is not possible, then picks the largest scaling factor <= 1, i.e.
reduce downscaling as much as possible.
If there are multiple resolutions with the same max scale, we pick the one with the lowest area,
to minimize padding. E.g., the same image can be upscaled to 224x224 and 224x448, but the latter
has more padding.
Args:
image_size (Tuple[int, int]): A tuple containing the height and width of the image.
possible_resolutions (torch.Tensor): A tensor of shape (N, 2) where each
row represents a possible resolution (height, width).
use_max_upscaling (bool): If True, will return the largest upscaling resolution.
Returns:
List[int]: The best resolution [height, width] for the given image.
Example:
>>> image_size = (200, 300)
>>> possible_resolutions = torch.tensor([[224, 672],
... [672, 224],
... [224, 448],
... [448, 224],
... [224, 224]])
>>> _get_smallest_upscaling_possibility(image_size, possible_resolutions)
[224, 448]
We have:
scale_w = tensor([2.2400, 0.7467, 1.4933, 0.7467, 0.7467])
scale_h = tensor([1.1200, 3.3600, 1.1200, 2.2400, 1.1200])
scales = tensor([1.1200, 0.7467, 1.1200, 0.7467, 0.7467])
Only one of the scales > 1:
upscaling_possible = tensor([1.1200, 1.1200])
smallest_rescale = tensor(1.1200)
So we pick the resolution with the smallest smallest area:
areas = tensor([150528, 100352]) # [672, 224], [224, 448]
optimal_canvas = tensor([224, 448])
"""
original_width, original_height = image_size
# get all possible resolutions heights/widths
target_widths, target_heights = (
possible_resolutions[:, 0],
possible_resolutions[:, 1],
)
# get scaling factors to resize the image without distortion
scale_w = target_widths / original_width
scale_h = target_heights / original_height
# get the min scale between width and height (limiting side -> no distortion)
scales = torch.where(scale_w > scale_h, scale_h, scale_w)
# filter only scales that allow upscaling
upscaling_options = scales[scales >= 1]
if len(upscaling_options) > 0:
if resize_to_max_canvas:
selected_scale = torch.max(upscaling_options)
else:
selected_scale = torch.min(upscaling_options)
else:
# no upscaling possible,
# get the minimum downscaling (max scale for scales<1)
downscaling_options = scales[scales < 1]
selected_scale = torch.max(downscaling_options)
# get all resolutions that support this scaling factor,
# e.g. you can upscale to 224x224, 224x448, 224x672 without distortion
chosen_canvas = possible_resolutions[scales == selected_scale]
# if there are multiple resolutions,
# get the one with minimum area to reduce padding
if len(chosen_canvas) > 1:
areas = chosen_canvas[:, 0] * chosen_canvas[:, 1]
optimal_idx = torch.argmin(areas)
optimal_canvas = chosen_canvas[optimal_idx]
else:
optimal_canvas = chosen_canvas[0]
return tuple(optimal_canvas.tolist())
def __call__(
self,
image: Image.Image,
max_num_chunks: int,
normalize_img: bool = True,
resize_to_max_canvas: bool = False,
) -> Tuple[torch.Tensor, Tuple[int, int]]:
"""
Args:
image (PIL.Image): Image to be resized.
max_num_chunks (int): Maximum number of chunks to split the image into.
normalize_img (bool): Whether to normalize the image.
resize_to_max_canvas (bool): Whether to resize the image to the maximum canvas size.
If True, picks the canvas the allows the largest resizing without distortion.
If False, downsample as little as possible, including no resizing at all,
but never upsample, unless the image is smaller than the patch size.
"""
assert max_num_chunks > 0
assert isinstance(image, Image.Image), type(image)
w, h = image.size
possible_resolutions = self.find_supported_resolutions(max_num_chunks=max_num_chunks, patch_size=self.size)
possible_resolutions = torch.tensor(possible_resolutions)
best_resolution = self.get_best_fit(
image_size=(w, h),
possible_resolutions=possible_resolutions,
resize_to_max_canvas=resize_to_max_canvas,
)
max_upscaling_size = None if resize_to_max_canvas else self.size
image = self.resize_without_distortion(image, best_resolution, max_upscaling_size)
image = self._pad(image, best_resolution)
image = self.to_tensor(image)
if normalize_img:
image = self.normalize(image)
ratio_w, ratio_h = (
best_resolution[0] // self.size,
best_resolution[1] // self.size,
)
image = self._split(image, ratio_w, ratio_h) # type: ignore
ar = (ratio_h, ratio_w)
return image, ar

File diff suppressed because one or more lines are too long

View file

@ -0,0 +1,306 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import textwrap
from io import BytesIO
from pathlib import Path
from typing import List
from ..datatypes import RawMediaItem, RawMessage, RawTextItem
from ..prompt_format import (
Llama4UseCase,
TextCompletionContent,
UseCase,
)
THIS_DIR = Path(__file__).parent
def usecases(base_model: bool = False) -> List[UseCase | str]:
with open(THIS_DIR.parent / "resources/small_dog.jpg", "rb") as f:
img_small_dog = f.read()
with open(THIS_DIR.parent / "resources/dog.jpg", "rb") as f:
img_dog = f.read()
with open(THIS_DIR.parent / "resources/pasta.jpeg", "rb") as f:
img_pasta = f.read()
out = []
out.extend(
[
textwrap.dedent(
"""
# Llama 4 - Prompt Formats
## Tokens
Here is a list of special tokens that are supported by Llama 4:
- `<|begin_of_text|>`: Specifies the start of the prompt
- `<|end_of_text|>`: Model will cease to generate more tokens. This token is generated only by the base models.
- `<|header_start|>` and `<|header_end|>`: These tokens enclose the role for a particular message. The possible roles are: [system, user and assistant].
- `<|eot|>`: End of turn. Represents when the model has determined that it has finished interacting with the user message that initiated its response. This is used in two scenarios:
- at the end of a direct interaction between the model and the user
- at the end of multiple interactions between the model and any available tools
This token signals to the executor that the model has finished generating a response.
- `<|image_start|>` and `<|image_end|>`: These tokens enclose the image data in the prompt.
- `<|patch|>`: This token represents a piece of the tile/
- `<|tile_y_separator|>` and `<|tile_x_separator|>`: These tokens are used to separate the y and x tiles of an image
- `<|image|>`: In the new architecture, this token now separates the regular sized image information from a downsized version of it that fits in a single tile. The longer side is used for calculating the scale factor and the rest is padded to fit the tile.
"""
),
textwrap.dedent(
"""
There are 3 different roles that are supported by Llama 4
- `system`: Sets the context in which to interact with the AI model. It typically includes rules, guidelines, or necessary information that helps the model respond effectively.
- `user`: Represents the human interacting with the model. It includes the inputs, commands, and questions to the model.
- `assistant`: Represents the response generated by the AI model based on the context provided in the `system`, `tool` and `user` prompts.
"""
),
]
)
if base_model:
out.extend(
[
"# Llama 4 Base Model",
Llama4UseCase(
title="Text completion - Paris information",
description="Text completion for Llama 4 base model uses this format.",
dialogs=[TextCompletionContent(content="The capital of France is Paris")],
),
Llama4UseCase(
title="Text completion - The color of the sky",
description="Text completion for Llama 4 base model uses this format.",
dialogs=[
TextCompletionContent(content="The color of the sky is blue but sometimes it can also be")
],
notes="",
),
Llama4UseCase(
title="Text completion - Translation example",
description="Text completion for Llama 4 base model uses this format.",
dialogs=[
TextCompletionContent(
content="""apple is pomme,
bannana is banane,
cherry is"""
)
],
notes="",
),
]
)
out.extend(
[
"# Llama 4 Instruct Model",
Llama4UseCase(
title="Simple User and assistant conversation",
description="Here is a regular multi-turn user assistant conversation and how its formatted.",
dialogs=[
[
RawMessage(role="system", content="You are a helpful assistant"),
RawMessage(
role="user",
content="Answer who are you in the form of jeopardy?",
),
]
],
notes="",
max_gen_len=512,
),
"# Image prompt format",
Llama4UseCase(
title="Single image prompt format - small image",
description="This example passes an image that is smaller than the tile size, to show the tile separator tokens are not needed",
dialogs=[
[
RawMessage(
role="user",
content=[
RawMediaItem(data=BytesIO(img_small_dog)),
RawTextItem(text="Describe this image in two sentences"),
],
)
]
],
notes="""Notice the structure of the image section:
```
<|image_start|><|image|><|patch|>...<|patch|><|image_end|>
```
This is due to the image being smaller than the tile size.
""",
max_gen_len=512,
),
Llama4UseCase(
title="Single image prompt format",
description="Here is an example of how to pass an image to the model",
dialogs=[
[
RawMessage(
role="user",
content=[
RawMediaItem(data=BytesIO(img_dog)),
RawTextItem(text="Describe this image in two sentences"),
],
)
]
],
notes="""With a bigger image, the image will include the tile separator tokens. Additionally, the image tag now separates a scaled down version of the image from the regular sized image.
```
<|image_start|><|patch|>...<|patch|><|tile_x_separator|><|patch|>...<|patch|><|tile_y_separator|><|patch|>...<|patch|><|image|><|patch|>...<|patch|><|image_end|>
```
""",
max_gen_len=1024,
),
Llama4UseCase(
title="Multiple images prompt format",
description="Here is an example of how to pass an image to the model",
dialogs=[
[
RawMessage(
role="user",
content=[
RawMediaItem(data=BytesIO(img_dog)),
RawMediaItem(data=BytesIO(img_pasta)),
RawTextItem(text="Describe these images in two sentences"),
],
)
]
],
notes="With multiple images, each one is encapsulated in their corresponding image tags.",
max_gen_len=4096,
),
"# Tool calling\nWe are continuing the format for zero shot function calling used in previous versions of Llama. All available functions can be provided either in the system message or in the user message.",
Llama4UseCase(
title="Zero shot function calling - system message",
dialogs=[
[
RawMessage(
role="system",
content="""You are an expert in composing functions. You are given a question and a set of possible functions.
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
If none of the function can be used, point it out. If the given question lacks the parameters required by the function,
also point it out. You should only return the function call in tools call sections.
If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
You SHOULD NOT include any other text in the response.
Here is a list of functions in JSON format that you can invoke.
[
{
"name": "get_weather",
"description": "Get weather info for places",
"parameters": {
"type": "dict",
"required": [
"city"
],
"properties": {
"city": {
"type": "string",
"description": "The name of the city to get the weather for"
},
"metric": {
"type": "string",
"description": "The metric for weather. Options are: celsius, fahrenheit",
"default": "celsius"
}
}
}
}
""",
),
RawMessage(
role="user",
content="What is the weather in SF and Seattle?",
),
]
],
notes=textwrap.dedent(
"""
- The output supports multiple, and parallel tool calls natively
- JSON format for defining the functions in the system prompt is similar to Llama3.1
"""
),
),
Llama4UseCase(
title="Zero shot function calling - user message",
description=textwrap.dedent(
"""
Similar to the above example, you can also provide information for all the available tools in the user message.
"""
),
dialogs=[
[
RawMessage(
role="user",
content="""Questions: Can you retrieve the details for the user with the ID 7890, who has black as their special request?
Here is a list of functions in JSON format that you can invoke:
[
{
"name": "get_user_info",
"description": "Retrieve details for a specific user by their unique identifier. Note that the provided function is in Python 3 syntax.",
"parameters": {
"type": "dict",
"required": [
"user_id"
],
"properties": {
"user_id": {
"type": "integer",
"description": "The unique identifier of the user. It is used to fetch the specific user details from the database."
},
"special": {
"type": "string",
"description": "Any special information or parameters that need to be considered while fetching user details.",
"default": "none"
}
}
}
}
]
Should you decide to return the function call(s), put them in the format of [func1(params_name=params_value, params_name2=params_value2...), func2(params)]
You SHOULD NOT include any other text in the response.""",
),
]
],
notes=textwrap.dedent(
"""
- The tool call format for the model is the same whether your function calls are provided in the system or user message.
"""
),
),
Llama4UseCase(
title="Tool calling with custom formats",
description=textwrap.dedent(
"""
Here is an example of how you could also write custom instructions for model to do zero shot tool calling.
In this example, we define a custom tool calling format using the `<function>` tag.
"""
),
dialogs=[
[
RawMessage(
role="user",
content="""You have access to the following functions:\nUse the function 'trending_songs' to 'Returns the trending songs on a Music site':\n{"name": "trending_songs", "description": "Returns the trending songs on a Music site", "parameters": {"genre": {"description": "The genre of the songs to return", "param_type": "str", "required": false}, "n": {"description": "The number of songs to return", "param_type": "int", "required": true}}}\n\nThink very carefully before calling functions.\nIf you choose to call a function ONLY reply in the following format with no prefix or suffix:\n\n<function=example_function_name>{"example_name": "example_value"}</function>
Reminder:
- If looking for real time information use relevant functions before falling back to brave_search
- Function calls MUST follow the specified format, start with <function= and end with </function>
- Required parameters MUST be specified
- Only call one function at a time
- Put the entire function call reply on one line<|eot_id|>""",
),
RawMessage(
role="user",
content="Use tools to get latest trending songs",
),
]
],
),
]
)
return out

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

Some files were not shown because too many files have changed in this diff Show more