Merge branch 'main' into eval_api_final

This commit is contained in:
Xi Yan 2025-03-26 12:29:45 -07:00
commit bc0cd07008
79 changed files with 3257 additions and 2358 deletions

View file

@ -1,5 +1,76 @@
# Changelog
# v0.1.8
Published on: 2025-03-24T01:28:50Z
# v0.1.8 Release Notes
### Build and Test Agents
* Safety: Integrated NVIDIA as a safety provider.
* VectorDB: Added Qdrant as an inline provider.
* Agents: Added support for multiple tool groups in agents.
* Agents: Simplified imports for Agents in client package
### Agent Evals and Model Customization
* Introduced DocVQA and IfEval benchmarks.
### Deploying and Monitoring Agents
* Introduced a Containerfile and image workflow for the Playground.
* Implemented support for Bearer (API Key) authentication.
* Added attribute-based access control for resources.
* Fixes on docker deployments: use --pull always and standardized the default port to 8321
* Deprecated: /v1/inspect/providers use /v1/providers/ instead
### Better Engineering
* Consolidated scripts under the ./scripts directory.
* Addressed mypy violations in various modules.
* Added Dependabot scans for Python dependencies.
* Implemented a scheduled workflow to update the changelog automatically.
* Enforced concurrency to reduce CI loads.
### New Contributors
* @cmodi-meta made their first contribution in https://github.com/meta-llama/llama-stack/pull/1650
* @jeffmaury made their first contribution in https://github.com/meta-llama/llama-stack/pull/1671
* @derekhiggins made their first contribution in https://github.com/meta-llama/llama-stack/pull/1698
* @Bobbins228 made their first contribution in https://github.com/meta-llama/llama-stack/pull/1745
**Full Changelog**: https://github.com/meta-llama/llama-stack/compare/v0.1.7...v0.1.8
---
# v0.1.7
Published on: 2025-03-14T22:30:51Z
## 0.1.7 Release Notes
### Build and Test Agents
* Inference: ImageType is now refactored to LlamaStackImageType
* Inference: Added tests to measure TTFT
* Inference: Bring back usage metrics
* Agents: Added endpoint for get agent, list agents and list sessions
* Agents: Automated conversion of type hints in client tool for lite llm format
* Agents: Deprecated ToolResponseMessage in agent.resume API
* Added Provider API for listing and inspecting provider info
### Agent Evals and Model Customization
* Eval: Added new eval benchmarks Math 500 and BFCL v3
* Deploy and Monitoring of Agents
* Telemetry: Fix tracing to work across coroutines
### Better Engineering
* Display code coverage for unit tests
* Updated call sites (inference, tool calls, agents) to move to async non blocking calls
* Unit tests also run on Python 3.11, 3.12, and 3.13
* Added ollama inference to Integration tests CI
* Improved documentation across examples, testing, CLI, updated providers table )
---
# v0.1.6
Published on: 2025-03-08T04:35:08Z

View file

@ -81,7 +81,9 @@ Note that you can create a dotenv file `.env` that includes necessary environmen
LLAMA_STACK_BASE_URL=http://localhost:8321
LLAMA_STACK_CLIENT_LOG=debug
LLAMA_STACK_PORT=8321
LLAMA_STACK_CONFIG=
LLAMA_STACK_CONFIG=<provider-name>
TAVILY_SEARCH_API_KEY=
BRAVE_SEARCH_API_KEY=
```
And then use this dotenv file when running client SDK tests via the following:

View file

@ -367,6 +367,7 @@
"zmq"
],
"nvidia": [
"aiohttp",
"aiosqlite",
"blobfile",
"chardet",

View file

@ -818,14 +818,7 @@
"delete": {
"responses": {
"200": {
"description": "OK",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/FileResponse"
}
}
}
"description": "OK"
},
"400": {
"$ref": "#/components/responses/BadRequest400"
@ -4002,22 +3995,33 @@
"type": "object",
"properties": {
"strategy": {
"$ref": "#/components/schemas/SamplingStrategy"
"$ref": "#/components/schemas/SamplingStrategy",
"description": "The sampling strategy."
},
"max_tokens": {
"type": "integer",
"default": 0
"default": 0,
"description": "The maximum number of tokens that can be generated in the completion. The token count of your prompt plus max_tokens cannot exceed the model's context length."
},
"repetition_penalty": {
"type": "number",
"default": 1.0
"default": 1.0,
"description": "Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics."
},
"stop": {
"type": "array",
"items": {
"type": "string"
},
"description": "Up to 4 sequences where the API will stop generating further tokens. The returned text will not contain the stop sequence."
}
},
"additionalProperties": false,
"required": [
"strategy"
],
"title": "SamplingParams"
"title": "SamplingParams",
"description": "Sampling parameters."
},
"SamplingStrategy": {
"oneOf": [
@ -6078,46 +6082,6 @@
"title": "FileUploadResponse",
"description": "Response after initiating a file upload session."
},
"FileResponse": {
"type": "object",
"properties": {
"bucket": {
"type": "string",
"description": "Bucket under which the file is stored (valid chars: a-zA-Z0-9_-)"
},
"key": {
"type": "string",
"description": "Key under which the file is stored (valid chars: a-zA-Z0-9_-/.)"
},
"mime_type": {
"type": "string",
"description": "MIME type of the file"
},
"url": {
"type": "string",
"description": "Upload URL for the file contents"
},
"bytes": {
"type": "integer",
"description": "Size of the file in bytes"
},
"created_at": {
"type": "integer",
"description": "Timestamp of when the file was created"
}
},
"additionalProperties": false,
"required": [
"bucket",
"key",
"mime_type",
"url",
"bytes",
"created_at"
],
"title": "FileResponse",
"description": "Response representing a file entry."
},
"EmbeddingsRequest": {
"type": "object",
"properties": {
@ -6498,52 +6462,47 @@
"title": "URIDataSource",
"description": "A dataset that can be obtained from a URI."
},
"EqualityGrader": {
"FileResponse": {
"type": "object",
"properties": {
"type": {
"bucket": {
"type": "string",
"const": "equality",
"default": "equality"
"description": "Bucket under which the file is stored (valid chars: a-zA-Z0-9_-)"
},
"key": {
"type": "string",
"description": "Key under which the file is stored (valid chars: a-zA-Z0-9_-/.)"
},
"mime_type": {
"type": "string",
"description": "MIME type of the file"
},
"url": {
"type": "string",
"description": "Upload URL for the file contents"
},
"bytes": {
"type": "integer",
"description": "Size of the file in bytes"
},
"created_at": {
"type": "integer",
"description": "Timestamp of when the file was created"
}
},
"additionalProperties": false,
"required": [
"type"
"bucket",
"key",
"mime_type",
"url",
"bytes",
"created_at"
],
"title": "EqualityGrader"
"title": "FileResponse",
"description": "Response representing a file entry."
},
"FactualityGrader": {
"type": "object",
"properties": {
"type": {
"type": "string",
"const": "factuality",
"default": "factuality"
}
},
"additionalProperties": false,
"required": [
"type"
],
"title": "FactualityGrader"
},
"FaithfulnessGrader": {
"type": "object",
"properties": {
"type": {
"type": "string",
"const": "faithfulness",
"default": "faithfulness"
}
},
"additionalProperties": false,
"required": [
"type"
],
"title": "FaithfulnessGrader"
},
"Grader": {
"Model": {
"type": "object",
"properties": {
"identifier": {
@ -7895,6 +7854,31 @@
"title": "IterrowsResponse",
"description": "A paginated list of rows from a dataset."
},
"Job": {
"type": "object",
"properties": {
"job_id": {
"type": "string"
},
"status": {
"type": "string",
"enum": [
"completed",
"in_progress",
"failed",
"scheduled",
"cancelled"
],
"title": "JobStatus"
}
},
"additionalProperties": false,
"required": [
"job_id",
"status"
],
"title": "Job"
},
"ListAgentSessionsResponse": {
"type": "object",
"properties": {

View file

@ -557,10 +557,6 @@ paths:
responses:
'200':
description: OK
content:
application/json:
schema:
$ref: '#/components/schemas/FileResponse'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
@ -2764,16 +2760,33 @@ components:
properties:
strategy:
$ref: '#/components/schemas/SamplingStrategy'
description: The sampling strategy.
max_tokens:
type: integer
default: 0
description: >-
The maximum number of tokens that can be generated in the completion.
The token count of your prompt plus max_tokens cannot exceed the model's
context length.
repetition_penalty:
type: number
default: 1.0
description: >-
Number between -2.0 and 2.0. Positive values penalize new tokens based
on whether they appear in the text so far, increasing the model's likelihood
to talk about new topics.
stop:
type: array
items:
type: string
description: >-
Up to 4 sequences where the API will stop generating further tokens. The
returned text will not contain the stop sequence.
additionalProperties: false
required:
- strategy
title: SamplingParams
description: Sampling parameters.
SamplingStrategy:
oneOf:
- $ref: '#/components/schemas/GreedySamplingStrategy'
@ -4246,39 +4259,6 @@ components:
title: FileUploadResponse
description: >-
Response after initiating a file upload session.
FileResponse:
type: object
properties:
bucket:
type: string
description: >-
Bucket under which the file is stored (valid chars: a-zA-Z0-9_-)
key:
type: string
description: >-
Key under which the file is stored (valid chars: a-zA-Z0-9_-/.)
mime_type:
type: string
description: MIME type of the file
url:
type: string
description: Upload URL for the file contents
bytes:
type: integer
description: Size of the file in bytes
created_at:
type: integer
description: Timestamp of when the file was created
additionalProperties: false
required:
- bucket
- key
- mime_type
- url
- bytes
- created_at
title: FileResponse
description: Response representing a file entry.
EmbeddingsRequest:
type: object
properties:
@ -4550,40 +4530,40 @@ components:
title: URIDataSource
description: >-
A dataset that can be obtained from a URI.
EqualityGrader:
FileResponse:
type: object
properties:
type:
bucket:
type: string
const: equality
default: equality
description: >-
Bucket under which the file is stored (valid chars: a-zA-Z0-9_-)
key:
type: string
description: >-
Key under which the file is stored (valid chars: a-zA-Z0-9_-/.)
mime_type:
type: string
description: MIME type of the file
url:
type: string
description: Upload URL for the file contents
bytes:
type: integer
description: Size of the file in bytes
created_at:
type: integer
description: Timestamp of when the file was created
additionalProperties: false
required:
- type
title: EqualityGrader
FactualityGrader:
type: object
properties:
type:
type: string
const: factuality
default: factuality
additionalProperties: false
required:
- type
title: FactualityGrader
FaithfulnessGrader:
type: object
properties:
type:
type: string
const: faithfulness
default: faithfulness
additionalProperties: false
required:
- type
title: FaithfulnessGrader
Grader:
- bucket
- key
- mime_type
- url
- bytes
- created_at
title: FileResponse
description: Response representing a file entry.
Model:
type: object
properties:
identifier:
@ -5453,6 +5433,25 @@ components:
- data
title: IterrowsResponse
description: A paginated list of rows from a dataset.
Job:
type: object
properties:
job_id:
type: string
status:
type: string
enum:
- completed
- in_progress
- failed
- scheduled
- cancelled
title: JobStatus
additionalProperties: false
required:
- job_id
- status
title: Job
ListAgentSessionsResponse:
type: object
properties:

File diff suppressed because one or more lines are too long

View file

@ -963,16 +963,19 @@
"\n",
"client.benchmarks.register(\n",
" benchmark_id=\"meta-reference::mmmu\",\n",
" # Note: we can use any value as `dataset_id` because we'll be using the `evaluate_rows` API which accepts the \n",
" # `input_rows` argument and does not fetch data from the dataset.\n",
" dataset_id=f\"mmmu-{subset}-{split}\",\n",
" scoring_functions=[\"basic::regex_parser_multiple_choice_answer\"],\n",
" # Note: for the same reason as above, we can use any value as `scoring_functions`.\n",
" scoring_functions=[],\n",
")\n",
"\n",
"response = client.eval.evaluate_rows_alpha(\n",
"response = client.eval.evaluate_rows(\n",
" benchmark_id=\"meta-reference::mmmu\",\n",
" input_rows=eval_rows,\n",
" # Note: Here we define the actual scoring functions.\n",
" scoring_functions=[\"basic::regex_parser_multiple_choice_answer\"],\n",
" benchmark_config={\n",
" \"type\": \"benchmark\",\n",
" \"eval_candidate\": {\n",
" \"type\": \"model\",\n",
" \"model\": \"meta-llama/Llama-3.2-90B-Vision-Instruct\",\n",
@ -1139,12 +1142,11 @@
" scoring_functions=[\"llm-as-judge::405b-simpleqa\"],\n",
")\n",
"\n",
"response = client.eval.evaluate_rows_alpha(\n",
"response = client.eval.evaluate_rows(\n",
" benchmark_id=\"meta-reference::simpleqa\",\n",
" input_rows=eval_rows.data,\n",
" scoring_functions=[\"llm-as-judge::405b-simpleqa\"],\n",
" benchmark_config={\n",
" \"type\": \"benchmark\",\n",
" \"eval_candidate\": {\n",
" \"type\": \"model\",\n",
" \"model\": \"meta-llama/Llama-3.2-90B-Vision-Instruct\",\n",
@ -1288,12 +1290,11 @@
" \"enable_session_persistence\": False,\n",
"}\n",
"\n",
"response = client.eval.evaluate_rows_alpha(\n",
"response = client.eval.evaluate_rows(\n",
" benchmark_id=\"meta-reference::simpleqa\",\n",
" input_rows=eval_rows.data,\n",
" scoring_functions=[\"llm-as-judge::405b-simpleqa\"],\n",
" benchmark_config={\n",
" \"type\": \"benchmark\",\n",
" \"eval_candidate\": {\n",
" \"type\": \"agent\",\n",
" \"config\": agent_config,\n",

View file

@ -22,7 +22,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 13,
"metadata": {},
"outputs": [
{
@ -68,7 +68,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
@ -1395,6 +1395,349 @@
"pprint(session_response.turns)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 3.1 Improved RAG with Long Context\n",
"\n",
"- Instead of performing reteival tool, we send documents as attachments to the agent and let it use the entire document context. \n",
"- Note how that the model is able to understand the entire context from documentation and answers the question with better factuality with improved retrieval. "
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">Question:</span> What precision formats does torchtune support?\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1;36mQuestion:\u001b[0m What precision formats does torchtune support?\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #808000; text-decoration-color: #808000; font-weight: bold\">Agent Answer:</span> Torchtune supports two precision formats: `fp32` <span style=\"font-weight: bold\">(</span>full-precision<span style=\"font-weight: bold\">)</span> and `bfloat16` <span style=\"font-weight: bold\">(</span>half-precision<span style=\"font-weight: bold\">)</span>. \n",
"The `bfloat16` format uses <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span> bytes per model parameter, which is half the memory of `fp32`, and also improves \n",
"training speed.\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1;33mAgent Answer:\u001b[0m Torchtune supports two precision formats: `fp32` \u001b[1m(\u001b[0mfull-precision\u001b[1m)\u001b[0m and `bfloat16` \u001b[1m(\u001b[0mhalf-precision\u001b[1m)\u001b[0m. \n",
"The `bfloat16` format uses \u001b[1;36m2\u001b[0m bytes per model parameter, which is half the memory of `fp32`, and also improves \n",
"training speed.\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">Question:</span> What does DoRA stand for in torchtune?\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1;36mQuestion:\u001b[0m What does DoRA stand for in torchtune?\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #808000; text-decoration-color: #808000; font-weight: bold\">Agent Answer:</span> DoRA stands for Weight-Decomposed Low-Rank Adaptation. It is a variant of LoRA <span style=\"font-weight: bold\">(</span>Low-Rank Adaptation<span style=\"font-weight: bold\">)</span> \n",
"that further decomposes the pre-trained weights into two components: magnitude and direction. The magnitude \n",
"component is a scalar vector that adjusts the scale, while the direction component corresponds to the original LoRA\n",
"decomposition and updates the orientation of weights. DoRA adds a small overhead to LoRA training due to the \n",
"addition of the magnitude parameter, but it has been shown to improve the performance of LoRA, particularly at low \n",
"ranks.\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1;33mAgent Answer:\u001b[0m DoRA stands for Weight-Decomposed Low-Rank Adaptation. It is a variant of LoRA \u001b[1m(\u001b[0mLow-Rank Adaptation\u001b[1m)\u001b[0m \n",
"that further decomposes the pre-trained weights into two components: magnitude and direction. The magnitude \n",
"component is a scalar vector that adjusts the scale, while the direction component corresponds to the original LoRA\n",
"decomposition and updates the orientation of weights. DoRA adds a small overhead to LoRA training due to the \n",
"addition of the magnitude parameter, but it has been shown to improve the performance of LoRA, particularly at low \n",
"ranks.\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">Question:</span> How does the CPUOffloadOptimizer reduce GPU memory usage?\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1;36mQuestion:\u001b[0m How does the CPUOffloadOptimizer reduce GPU memory usage?\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #808000; text-decoration-color: #808000; font-weight: bold\">Agent Answer:</span> The CPUOffloadOptimizer reduces GPU memory usage by offloading optimizer states and gradients to the \n",
"CPU, and performing optimizer steps on the CPU. This can significantly reduce GPU memory usage at the cost of CPU \n",
"RAM and training speed.\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1;33mAgent Answer:\u001b[0m The CPUOffloadOptimizer reduces GPU memory usage by offloading optimizer states and gradients to the \n",
"CPU, and performing optimizer steps on the CPU. This can significantly reduce GPU memory usage at the cost of CPU \n",
"RAM and training speed.\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">Question:</span> How do I ensure only LoRA parameters are trainable when fine-tuning?\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1;36mQuestion:\u001b[0m How do I ensure only LoRA parameters are trainable when fine-tuning?\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #808000; text-decoration-color: #808000; font-weight: bold\">Agent Answer:</span> To ensure only LoRA parameters are trainable when fine-tuning, you can use the `set_trainable_params`\n",
"function from `torchtune.modules.peft.peft_utils` to set the `requires_grad` attribute of the LoRA parameters to \n",
"`<span style=\"color: #00ff00; text-decoration-color: #00ff00; font-style: italic\">True</span>` and the `requires_grad` attribute of the other parameters to `<span style=\"color: #ff0000; text-decoration-color: #ff0000; font-style: italic\">False</span>`.\n",
"\n",
"Here is an example:\n",
"```python\n",
"from torchtune.modules.peft.peft_utils import get_adapter_params, set_trainable_params\n",
"\n",
"# Get the LoRA parameters\n",
"lora_params = <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">get_adapter_params</span><span style=\"font-weight: bold\">(</span>model<span style=\"font-weight: bold\">)</span>\n",
"\n",
"# Set the LoRA parameters to trainable and the other parameters to non-trainable\n",
"<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">set_trainable_params</span><span style=\"font-weight: bold\">(</span>model, lora_params<span style=\"font-weight: bold\">)</span>\n",
"```\n",
"This will ensure that only the LoRA parameters are updated during fine-tuning, while the other parameters remain \n",
"frozen.\n",
"\n",
"Alternatively, you can also use the `lora_finetune` recipe in torchtune, which automatically sets the LoRA \n",
"parameters to trainable and the other parameters to non-trainable. You can run the recipe using the following \n",
"command:\n",
"```bash\n",
"tune run lora_finetune --config llama2/7B_lora\n",
"```\n",
"This will fine-tune the LoRA parameters of the Llama2 model using the default settings. You can modify the config \n",
"file to change the hyperparameters or the model architecture.\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1;33mAgent Answer:\u001b[0m To ensure only LoRA parameters are trainable when fine-tuning, you can use the `set_trainable_params`\n",
"function from `torchtune.modules.peft.peft_utils` to set the `requires_grad` attribute of the LoRA parameters to \n",
"`\u001b[3;92mTrue\u001b[0m` and the `requires_grad` attribute of the other parameters to `\u001b[3;91mFalse\u001b[0m`.\n",
"\n",
"Here is an example:\n",
"```python\n",
"from torchtune.modules.peft.peft_utils import get_adapter_params, set_trainable_params\n",
"\n",
"# Get the LoRA parameters\n",
"lora_params = \u001b[1;35mget_adapter_params\u001b[0m\u001b[1m(\u001b[0mmodel\u001b[1m)\u001b[0m\n",
"\n",
"# Set the LoRA parameters to trainable and the other parameters to non-trainable\n",
"\u001b[1;35mset_trainable_params\u001b[0m\u001b[1m(\u001b[0mmodel, lora_params\u001b[1m)\u001b[0m\n",
"```\n",
"This will ensure that only the LoRA parameters are updated during fine-tuning, while the other parameters remain \n",
"frozen.\n",
"\n",
"Alternatively, you can also use the `lora_finetune` recipe in torchtune, which automatically sets the LoRA \n",
"parameters to trainable and the other parameters to non-trainable. You can run the recipe using the following \n",
"command:\n",
"```bash\n",
"tune run lora_finetune --config llama2/7B_lora\n",
"```\n",
"This will fine-tune the LoRA parameters of the Llama2 model using the default settings. You can modify the config \n",
"file to change the hyperparameters or the model architecture.\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"urls = [\n",
" \"memory_optimizations.rst\",\n",
" \"chat.rst\",\n",
" \"llama3.rst\",\n",
" \"datasets.rst\",\n",
" \"qat_finetune.rst\",\n",
" \"lora_finetune.rst\",\n",
"]\n",
"\n",
"attachments = [\n",
" {\n",
" \"content\": {\n",
" \"uri\": f\"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}\",\n",
" },\n",
" \"mime_type\": \"text/plain\",\n",
" }\n",
"\n",
" for i, url in enumerate(urls)\n",
"]\n",
"\n",
"rag_attachment_agent = Agent(\n",
" client,\n",
" model=MODEL_ID,\n",
" instructions=\"You are a helpful assistant that can answer questions about the Torchtune project. Use context from attached documentation for Torchtune to answer questions.\",\n",
")\n",
"\n",
"for example in examples:\n",
" session_id = rag_attachment_agent.create_session(session_name=f\"rag_attachment_session_{uuid.uuid4()}\")\n",
" response = rag_attachment_agent.create_turn(\n",
" messages=[\n",
" {\n",
" \"role\": \"user\",\n",
" \"content\": example[\"input_query\"]\n",
" }\n",
" ],\n",
" session_id=session_id,\n",
" documents=attachments,\n",
" stream=False\n",
" )\n",
" rich.print(f\"[bold cyan]Question:[/bold cyan] {example['input_query']}\")\n",
" rich.print(f\"[bold yellow]Agent Answer:[/bold yellow] {response.output_message.content}\")\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">ScoringScoreResponse</span><span style=\"font-weight: bold\">(</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ </span><span style=\"color: #808000; text-decoration-color: #808000\">results</span>=<span style=\"font-weight: bold\">{</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'braintrust::factuality'</span>: <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">ScoringResult</span><span style=\"font-weight: bold\">(</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ </span><span style=\"color: #808000; text-decoration-color: #808000\">aggregated_results</span>=<span style=\"font-weight: bold\">{</span><span style=\"color: #008000; text-decoration-color: #008000\">'average'</span>: <span style=\"font-weight: bold\">{</span><span style=\"color: #008000; text-decoration-color: #008000\">'average'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.6</span><span style=\"font-weight: bold\">}}</span>,\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ </span><span style=\"color: #808000; text-decoration-color: #808000\">score_rows</span>=<span style=\"font-weight: bold\">[</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ </span><span style=\"font-weight: bold\">{</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'score'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.6</span>,\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'metadata'</span>: <span style=\"font-weight: bold\">{</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'choice'</span>: <span style=\"color: #008000; text-decoration-color: #008000\">'B'</span>,\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'rationale'</span>: <span style=\"color: #008000; text-decoration-color: #008000\">'1. Both the expert and the submitted answers mention that Torchtune supports two precision formats: `fp32` (full-precision) and `bfloat16` (half-precision).\\n2. The expert answer specifies that `fp32` uses 4 bytes per model and optimizer parameter, while `bfloat16` uses 2 bytes per model and optimizer parameter.\\n3. The submitted answer also mentions that `bfloat16` uses 2 bytes per model parameter, which is consistent with the expert answer.\\n4. The submitted answer adds that `bfloat16` improves training speed, which is additional information not present in the expert answer.\\n5. There is no conflict between the submitted answer and the expert answer; the submitted answer simply provides more information.\\n\\nBased on this analysis, the submitted answer is a superset of the expert answer and is fully consistent with it.'</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ </span><span style=\"font-weight: bold\">}</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ </span><span style=\"font-weight: bold\">}</span>,\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ </span><span style=\"font-weight: bold\">{</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'score'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.6</span>,\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'metadata'</span>: <span style=\"font-weight: bold\">{</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'choice'</span>: <span style=\"color: #008000; text-decoration-color: #008000\">'B'</span>,\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'rationale'</span>: <span style=\"color: #008000; text-decoration-color: #008000\">'1. The expert answer provides the definition of DoRA as \"Weight-Decomposed Low-Rank Adaptation.\"\\n2. The submitted answer also states that DoRA stands for \"Weight-Decomposed Low-Rank Adaptation,\" which matches the expert answer.\\n3. The submitted answer includes additional information about DoRA, explaining that it is a variant of LoRA and describing how it decomposes pre-trained weights into magnitude and direction components.\\n4. The submitted answer further explains the role of the magnitude component and the direction component, and mentions the performance improvement and overhead associated with DoRA.\\n5. The additional details in the submitted answer do not contradict the expert answer; instead, they expand upon it.\\n6. Therefore, the submitted answer is a superset of the expert answer and is fully consistent with it.'</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ </span><span style=\"font-weight: bold\">}</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ </span><span style=\"font-weight: bold\">}</span>,\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ </span><span style=\"font-weight: bold\">{</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'score'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.6</span>,\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'metadata'</span>: <span style=\"font-weight: bold\">{</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'choice'</span>: <span style=\"color: #008000; text-decoration-color: #008000\">'B'</span>,\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'rationale'</span>: <span style=\"color: #008000; text-decoration-color: #008000\">'1. The expert answer states that the CPUOffloadOptimizer reduces GPU memory usage by keeping optimizer states on CPU and performing optimizer steps on CPU. It also mentions the optional offloading of gradients to CPU with the parameter offload_gradients=True.\\n\\n2. The submitted answer states that the CPUOffloadOptimizer reduces GPU memory usage by offloading optimizer states and gradients to the CPU, and performing optimizer steps on the CPU. It adds that this can significantly reduce GPU memory usage at the cost of CPU RAM and training speed.\\n\\n3. Comparing both answers:\\n - Both answers agree on offloading optimizer states to the CPU and performing optimizer steps on the CPU.\\n - Both mention the offloading of gradients to the CPU, but the expert answer specifies it as optional with a parameter, while the submission does not specify this detail.\\n - The submission adds additional information about the trade-off involving CPU RAM and training speed, which is not mentioned in the expert answer.\\n\\n4. The submitted answer includes all the details from the expert answer and adds more information about the trade-offs, making it a superset of the expert answer.\\n\\nTherefore, the correct choice is (B) The submitted answer is a superset of the expert answer and is fully consistent with it.'</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ </span><span style=\"font-weight: bold\">}</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ </span><span style=\"font-weight: bold\">}</span>,\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ </span><span style=\"font-weight: bold\">{</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'score'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.6</span>,\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'metadata'</span>: <span style=\"font-weight: bold\">{</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'choice'</span>: <span style=\"color: #008000; text-decoration-color: #008000\">'B'</span>,\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'rationale'</span>: <span style=\"color: #008000; text-decoration-color: #008000\">\"1. **Expert Answer Analysis**: The expert answer provides a method to ensure only LoRA parameters are trainable by using torchtune's utility functions. It mentions fetching LoRA parameters with `get_adapter_params(lora_model)` and setting them as trainable with `set_trainable_params(lora_model, lora_params)`. It also notes that the LoRA recipe handles this automatically.\\n\\n2. **Submitted Answer Analysis**: The submitted answer provides a similar method using `set_trainable_params` to set the `requires_grad` attribute of LoRA parameters to `True` and other parameters to `False`. It includes a code example demonstrating this process. Additionally, it mentions using the `lora_finetune` recipe in torchtune, which automatically sets the LoRA parameters to trainable.\\n\\n3. **Comparison**: The submitted answer includes all the details from the expert answer regarding the use of `get_adapter_params` and `set_trainable_params`. It also provides additional information about setting the `requires_grad` attribute and using the `lora_finetune` recipe, which is not mentioned in the expert answer.\\n\\n4. **Conclusion**: The submitted answer is a superset of the expert answer as it contains all the information from the expert answer and additional details. There is no conflict between the two answers, and the additional information in the submission is consistent with the expert's explanation.\\n\\nTherefore, the correct choice is (B) The submitted answer is a superset of the expert answer and is fully consistent with it.\"</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ </span><span style=\"font-weight: bold\">}</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ </span><span style=\"font-weight: bold\">}</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ </span><span style=\"font-weight: bold\">]</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ </span><span style=\"font-weight: bold\">)</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ </span><span style=\"font-weight: bold\">}</span>\n",
"<span style=\"font-weight: bold\">)</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1;35mScoringScoreResponse\u001b[0m\u001b[1m(\u001b[0m\n",
"\u001b[2;32m│ \u001b[0m\u001b[33mresults\u001b[0m=\u001b[1m{\u001b[0m\n",
"\u001b[2;32m│ │ \u001b[0m\u001b[32m'braintrust::factuality'\u001b[0m: \u001b[1;35mScoringResult\u001b[0m\u001b[1m(\u001b[0m\n",
"\u001b[2;32m│ │ │ \u001b[0m\u001b[33maggregated_results\u001b[0m=\u001b[1m{\u001b[0m\u001b[32m'average'\u001b[0m: \u001b[1m{\u001b[0m\u001b[32m'average'\u001b[0m: \u001b[1;36m0.6\u001b[0m\u001b[1m}\u001b[0m\u001b[1m}\u001b[0m,\n",
"\u001b[2;32m│ │ │ \u001b[0m\u001b[33mscore_rows\u001b[0m=\u001b[1m[\u001b[0m\n",
"\u001b[2;32m│ │ │ │ \u001b[0m\u001b[1m{\u001b[0m\n",
"\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[32m'score'\u001b[0m: \u001b[1;36m0.6\u001b[0m,\n",
"\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[32m'metadata'\u001b[0m: \u001b[1m{\u001b[0m\n",
"\u001b[2;32m│ │ │ │ │ │ \u001b[0m\u001b[32m'choice'\u001b[0m: \u001b[32m'B'\u001b[0m,\n",
"\u001b[2;32m│ │ │ │ │ │ \u001b[0m\u001b[32m'rationale'\u001b[0m: \u001b[32m'1. Both the expert and the submitted answers mention that Torchtune supports two precision formats: `fp32` \u001b[0m\u001b[32m(\u001b[0m\u001b[32mfull-precision\u001b[0m\u001b[32m)\u001b[0m\u001b[32m and `bfloat16` \u001b[0m\u001b[32m(\u001b[0m\u001b[32mhalf-precision\u001b[0m\u001b[32m)\u001b[0m\u001b[32m.\\n2. The expert answer specifies that `fp32` uses 4 bytes per model and optimizer parameter, while `bfloat16` uses 2 bytes per model and optimizer parameter.\\n3. The submitted answer also mentions that `bfloat16` uses 2 bytes per model parameter, which is consistent with the expert answer.\\n4. The submitted answer adds that `bfloat16` improves training speed, which is additional information not present in the expert answer.\\n5. There is no conflict between the submitted answer and the expert answer; the submitted answer simply provides more information.\\n\\nBased on this analysis, the submitted answer is a superset of the expert answer and is fully consistent with it.'\u001b[0m\n",
"\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[1m}\u001b[0m\n",
"\u001b[2;32m│ │ │ │ \u001b[0m\u001b[1m}\u001b[0m,\n",
"\u001b[2;32m│ │ │ │ \u001b[0m\u001b[1m{\u001b[0m\n",
"\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[32m'score'\u001b[0m: \u001b[1;36m0.6\u001b[0m,\n",
"\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[32m'metadata'\u001b[0m: \u001b[1m{\u001b[0m\n",
"\u001b[2;32m│ │ │ │ │ │ \u001b[0m\u001b[32m'choice'\u001b[0m: \u001b[32m'B'\u001b[0m,\n",
"\u001b[2;32m│ │ │ │ │ │ \u001b[0m\u001b[32m'rationale'\u001b[0m: \u001b[32m'1. The expert answer provides the definition of DoRA as \"Weight-Decomposed Low-Rank Adaptation.\"\\n2. The submitted answer also states that DoRA stands for \"Weight-Decomposed Low-Rank Adaptation,\" which matches the expert answer.\\n3. The submitted answer includes additional information about DoRA, explaining that it is a variant of LoRA and describing how it decomposes pre-trained weights into magnitude and direction components.\\n4. The submitted answer further explains the role of the magnitude component and the direction component, and mentions the performance improvement and overhead associated with DoRA.\\n5. The additional details in the submitted answer do not contradict the expert answer; instead, they expand upon it.\\n6. Therefore, the submitted answer is a superset of the expert answer and is fully consistent with it.'\u001b[0m\n",
"\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[1m}\u001b[0m\n",
"\u001b[2;32m│ │ │ │ \u001b[0m\u001b[1m}\u001b[0m,\n",
"\u001b[2;32m│ │ │ │ \u001b[0m\u001b[1m{\u001b[0m\n",
"\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[32m'score'\u001b[0m: \u001b[1;36m0.6\u001b[0m,\n",
"\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[32m'metadata'\u001b[0m: \u001b[1m{\u001b[0m\n",
"\u001b[2;32m│ │ │ │ │ │ \u001b[0m\u001b[32m'choice'\u001b[0m: \u001b[32m'B'\u001b[0m,\n",
"\u001b[2;32m│ │ │ │ │ │ \u001b[0m\u001b[32m'rationale'\u001b[0m: \u001b[32m'1. The expert answer states that the CPUOffloadOptimizer reduces GPU memory usage by keeping optimizer states on CPU and performing optimizer steps on CPU. It also mentions the optional offloading of gradients to CPU with the parameter \u001b[0m\u001b[32moffload_gradients\u001b[0m\u001b[32m=\u001b[0m\u001b[32mTrue\u001b[0m\u001b[32m.\\n\\n2. The submitted answer states that the CPUOffloadOptimizer reduces GPU memory usage by offloading optimizer states and gradients to the CPU, and performing optimizer steps on the CPU. It adds that this can significantly reduce GPU memory usage at the cost of CPU RAM and training speed.\\n\\n3. Comparing both answers:\\n - Both answers agree on offloading optimizer states to the CPU and performing optimizer steps on the CPU.\\n - Both mention the offloading of gradients to the CPU, but the expert answer specifies it as optional with a parameter, while the submission does not specify this detail.\\n - The submission adds additional information about the trade-off involving CPU RAM and training speed, which is not mentioned in the expert answer.\\n\\n4. The submitted answer includes all the details from the expert answer and adds more information about the trade-offs, making it a superset of the expert answer.\\n\\nTherefore, the correct choice is \u001b[0m\u001b[32m(\u001b[0m\u001b[32mB\u001b[0m\u001b[32m)\u001b[0m\u001b[32m The submitted answer is a superset of the expert answer and is fully consistent with it.'\u001b[0m\n",
"\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[1m}\u001b[0m\n",
"\u001b[2;32m│ │ │ │ \u001b[0m\u001b[1m}\u001b[0m,\n",
"\u001b[2;32m│ │ │ │ \u001b[0m\u001b[1m{\u001b[0m\n",
"\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[32m'score'\u001b[0m: \u001b[1;36m0.6\u001b[0m,\n",
"\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[32m'metadata'\u001b[0m: \u001b[1m{\u001b[0m\n",
"\u001b[2;32m│ │ │ │ │ │ \u001b[0m\u001b[32m'choice'\u001b[0m: \u001b[32m'B'\u001b[0m,\n",
"\u001b[2;32m│ │ │ │ │ │ \u001b[0m\u001b[32m'rationale'\u001b[0m: \u001b[32m\"1. **Expert Answer Analysis**: The expert answer provides a method to ensure only LoRA parameters are trainable by using torchtune's utility functions. It mentions fetching LoRA parameters with `get_adapter_params\u001b[0m\u001b[32m(\u001b[0m\u001b[32mlora_model\u001b[0m\u001b[32m)\u001b[0m\u001b[32m` and setting them as trainable with `set_trainable_params\u001b[0m\u001b[32m(\u001b[0m\u001b[32mlora_model, lora_params\u001b[0m\u001b[32m)\u001b[0m\u001b[32m`. It also notes that the LoRA recipe handles this automatically.\\n\\n2. **Submitted Answer Analysis**: The submitted answer provides a similar method using `set_trainable_params` to set the `requires_grad` attribute of LoRA parameters to `True` and other parameters to `False`. It includes a code example demonstrating this process. Additionally, it mentions using the `lora_finetune` recipe in torchtune, which automatically sets the LoRA parameters to trainable.\\n\\n3. **Comparison**: The submitted answer includes all the details from the expert answer regarding the use of `get_adapter_params` and `set_trainable_params`. It also provides additional information about setting the `requires_grad` attribute and using the `lora_finetune` recipe, which is not mentioned in the expert answer.\\n\\n4. **Conclusion**: The submitted answer is a superset of the expert answer as it contains all the information from the expert answer and additional details. There is no conflict between the two answers, and the additional information in the submission is consistent with the expert's explanation.\\n\\nTherefore, the correct choice is \u001b[0m\u001b[32m(\u001b[0m\u001b[32mB\u001b[0m\u001b[32m)\u001b[0m\u001b[32m The submitted answer is a superset of the expert answer and is fully consistent with it.\"\u001b[0m\n",
"\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[1m}\u001b[0m\n",
"\u001b[2;32m│ │ │ │ \u001b[0m\u001b[1m}\u001b[0m\n",
"\u001b[2;32m│ │ │ \u001b[0m\u001b[1m]\u001b[0m\n",
"\u001b[2;32m│ │ \u001b[0m\u001b[1m)\u001b[0m\n",
"\u001b[2;32m│ \u001b[0m\u001b[1m}\u001b[0m\n",
"\u001b[1m)\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"eval_rows = []\n",
"for i, session_id in enumerate(rag_attachment_agent.sessions):\n",
" session_response = client.agents.session.retrieve(agent_id=rag_attachment_agent.agent_id, session_id=session_id)\n",
" for turn in session_response.turns:\n",
" eval_rows.append({\n",
" \"input_query\": examples[i][\"input_query\"],\n",
" \"expected_answer\": examples[i][\"expected_answer\"],\n",
" \"generated_answer\": turn.output_message.content,\n",
" })\n",
"\n",
"scoring_params = {\n",
" \"braintrust::factuality\": None,\n",
"}\n",
"scoring_response = client.scoring.score(\n",
" input_rows=eval_rows,\n",
" scoring_functions=scoring_params,\n",
")\n",
"pprint(scoring_response)"
]
},
{
"cell_type": "markdown",
"metadata": {},

View file

@ -21,7 +21,7 @@ from llama_stack.distribution.stack import LlamaStack # noqa: E402
from .pyopenapi.options import Options # noqa: E402
from .pyopenapi.specification import Info, Server # noqa: E402
from .pyopenapi.utility import Specification, validate_api_method_return_types # noqa: E402
from .pyopenapi.utility import Specification, validate_api # noqa: E402
def str_presenter(dumper, data):
@ -40,8 +40,7 @@ def main(output_dir: str):
raise ValueError(f"Directory {output_dir} does not exist")
# Validate API protocols before generating spec
print("Validating API method return types...")
return_type_errors = validate_api_method_return_types()
return_type_errors = validate_api()
if return_type_errors:
print("\nAPI Method Return Type Validation Errors:\n")
for error in return_type_errors:

View file

@ -7,10 +7,9 @@
import json
import typing
import inspect
import os
from pathlib import Path
from typing import TextIO
from typing import Any, Dict, List, Optional, Protocol, Type, Union, get_type_hints, get_origin, get_args
from typing import Any, List, Optional, Union, get_type_hints, get_origin, get_args
from llama_stack.strong_typing.schema import object_to_json, StrictJsonType
from llama_stack.distribution.resolver import api_protocol_map
@ -125,29 +124,59 @@ def is_optional_type(type_: Any) -> bool:
return origin is Optional or (origin is Union and type(None) in args)
def validate_api_method_return_types() -> List[str]:
"""Validate that all API methods have proper return types."""
def _validate_api_method_return_type(method) -> str | None:
hints = get_type_hints(method)
if 'return' not in hints:
return "has no return type annotation"
return_type = hints['return']
if is_optional_type(return_type):
return "returns Optional type"
def _validate_api_delete_method_returns_none(method) -> str | None:
hints = get_type_hints(method)
if 'return' not in hints:
return "has no return type annotation"
return_type = hints['return']
if return_type is not None and return_type is not type(None):
return "does not return None"
_VALIDATORS = {
"GET": [
_validate_api_method_return_type,
],
"DELETE": [
_validate_api_delete_method_returns_none,
],
}
def _get_methods_by_type(protocol, method_type: str):
members = inspect.getmembers(protocol, predicate=inspect.isfunction)
return {
method_name: method
for method_name, method in members
if (webmethod := getattr(method, '__webmethod__', None))
if webmethod and webmethod.method == method_type
}
def validate_api() -> List[str]:
"""Validate the API protocols."""
errors = []
protocols = api_protocol_map()
for protocol_name, protocol in protocols.items():
methods = inspect.getmembers(protocol, predicate=inspect.isfunction)
for method_name, method in methods:
if not hasattr(method, '__webmethod__'):
continue
# Only check GET methods
if method.__webmethod__.method != "GET":
continue
hints = get_type_hints(method)
if 'return' not in hints:
errors.append(f"Method {protocol_name}.{method_name} has no return type annotation")
else:
return_type = hints['return']
if is_optional_type(return_type):
errors.append(f"Method {protocol_name}.{method_name} returns Optional type")
for target, validators in _VALIDATORS.items():
for protocol_name, protocol in protocols.items():
for validator in validators:
for method_name, method in _get_methods_by_type(protocol, target).items():
err = validator(method)
if err:
errors.append(f"Method {protocol_name}.{method_name} {err}")
return errors

View file

@ -8,7 +8,7 @@ First, create a local Kubernetes cluster via Kind:
kind create cluster --image kindest/node:v1.32.0 --name llama-stack-test
```
Start vLLM server as a Kubernetes Pod and Service:
First, create a Kubernetes PVC and Secret for downloading and storing Hugging Face model:
```bash
cat <<EOF |kubectl apply -f -
@ -31,7 +31,12 @@ metadata:
type: Opaque
data:
token: $(HF_TOKEN)
---
```
Next, start the vLLM server as a Kubernetes Deployment and Service:
```bash
cat <<EOF |kubectl apply -f -
apiVersion: apps/v1
kind: Deployment
metadata:
@ -47,28 +52,23 @@ spec:
app.kubernetes.io/name: vllm
spec:
containers:
- name: llama-stack
image: $(VLLM_IMAGE)
command:
- bash
- -c
- |
MODEL="meta-llama/Llama-3.2-1B-Instruct"
MODEL_PATH=/app/model/$(basename $MODEL)
huggingface-cli login --token $HUGGING_FACE_HUB_TOKEN
huggingface-cli download $MODEL --local-dir $MODEL_PATH --cache-dir $MODEL_PATH
python3 -m vllm.entrypoints.openai.api_server --model $MODEL_PATH --served-model-name $MODEL --port 8000
- name: vllm
image: vllm/vllm-openai:latest
command: ["/bin/sh", "-c"]
args: [
"vllm serve meta-llama/Llama-3.2-1B-Instruct"
]
env:
- name: HUGGING_FACE_HUB_TOKEN
valueFrom:
secretKeyRef:
name: hf-token-secret
key: token
ports:
- containerPort: 8000
volumeMounts:
- name: llama-storage
mountPath: /app/model
env:
- name: HUGGING_FACE_HUB_TOKEN
valueFrom:
secretKeyRef:
name: hf-token-secret
key: token
mountPath: /root/.cache/huggingface
volumes:
- name: llama-storage
persistentVolumeClaim:

View file

@ -8,6 +8,7 @@ The `llamastack/distribution-nvidia` distribution consists of the following prov
| agents | `inline::meta-reference` |
| datasetio | `inline::localfs` |
| inference | `remote::nvidia` |
| post_training | `remote::nvidia` |
| safety | `remote::nvidia` |
| telemetry | `inline::meta-reference` |
| tool_runtime | `inline::rag-runtime` |
@ -19,6 +20,12 @@ The `llamastack/distribution-nvidia` distribution consists of the following prov
The following environment variables can be configured:
- `NVIDIA_API_KEY`: NVIDIA API Key (default: ``)
- `NVIDIA_USER_ID`: NVIDIA User ID (default: `llama-stack-user`)
- `NVIDIA_DATASET_NAMESPACE`: NVIDIA Dataset Namespace (default: `default`)
- `NVIDIA_ACCESS_POLICIES`: NVIDIA Access Policies (default: `{}`)
- `NVIDIA_PROJECT_ID`: NVIDIA Project ID (default: `test-project`)
- `NVIDIA_CUSTOMIZER_URL`: NVIDIA Customizer URL (default: `https://customizer.api.nvidia.com`)
- `NVIDIA_OUTPUT_MODEL_DIR`: NVIDIA Output Model Directory (default: `test-example-model@v1`)
- `GUARDRAILS_SERVICE_URL`: URL for the NeMo Guardrails Service (default: `http://0.0.0.0:7331`)
- `INFERENCE_MODEL`: Inference model (default: `Llama3.1-8B-Instruct`)
- `SAFETY_MODEL`: Name of the model to use for safety (default: `meta/llama-3.1-8b-instruct`)

View file

@ -32,6 +32,7 @@ class Api(Enum):
datasets = "datasets"
benchmarks = "benchmarks"
tool_groups = "tool_groups"
files = "files"
# built-in API
inspect = "inspect"

View file

@ -164,7 +164,7 @@ class Files(Protocol):
self,
bucket: str,
key: str,
) -> FileResponse:
) -> None:
"""
Delete a file identified by a bucket and key.

View file

@ -43,7 +43,7 @@ class StackRun(Subcommand):
self.parser.add_argument(
"--image-name",
type=str,
default=os.environ.get("CONDA_DEFAULT_ENV"),
default=None,
help="Name of the image to run. Defaults to the current conda environment",
)
self.parser.add_argument(

View file

@ -12,6 +12,7 @@ from llama_stack.apis.benchmarks import Benchmarks
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.evaluation import Evaluation
from llama_stack.apis.files import Files
from llama_stack.apis.inference import Inference
from llama_stack.apis.inspect import Inspect
from llama_stack.apis.models import Models
@ -74,6 +75,7 @@ def api_protocol_map() -> Dict[Api, Any]:
Api.tool_groups: ToolGroups,
Api.tool_runtime: ToolRuntime,
Api.evaluation: Evaluation,
Api.files: Files,
}
@ -107,7 +109,9 @@ async def resolve_impls(
2. Sorting them in dependency order.
3. Instantiating them with required dependencies.
"""
routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()}
routing_table_apis = {
x.routing_table_api for x in builtin_automatically_routed_apis()
}
router_apis = {x.router_api for x in builtin_automatically_routed_apis()}
providers_with_specs = validate_and_prepare_providers(
@ -115,7 +119,9 @@ async def resolve_impls(
)
apis_to_serve = run_config.apis or set(
list(providers_with_specs.keys()) + [x.value for x in routing_table_apis] + [x.value for x in router_apis]
list(providers_with_specs.keys())
+ [x.value for x in routing_table_apis]
+ [x.value for x in router_apis]
)
providers_with_specs.update(specs_for_autorouted_apis(apis_to_serve))
@ -180,17 +186,23 @@ def validate_and_prepare_providers(
for api_str, providers in run_config.providers.items():
api = Api(api_str)
if api in routing_table_apis:
raise ValueError(f"Provider for `{api_str}` is automatically provided and cannot be overridden")
raise ValueError(
f"Provider for `{api_str}` is automatically provided and cannot be overridden"
)
specs = {}
for provider in providers:
if not provider.provider_id or provider.provider_id == "__disabled__":
logger.warning(f"Provider `{provider.provider_type}` for API `{api}` is disabled")
logger.warning(
f"Provider `{provider.provider_type}` for API `{api}` is disabled"
)
continue
validate_provider(provider, api, provider_registry)
p = provider_registry[api][provider.provider_type]
p.deps__ = [a.value for a in p.api_dependencies] + [a.value for a in p.optional_api_dependencies]
p.deps__ = [a.value for a in p.api_dependencies] + [
a.value for a in p.optional_api_dependencies
]
spec = ProviderWithSpec(spec=p, **provider.model_dump())
specs[provider.provider_id] = spec
@ -200,10 +212,14 @@ def validate_and_prepare_providers(
return providers_with_specs
def validate_provider(provider: Provider, api: Api, provider_registry: ProviderRegistry):
def validate_provider(
provider: Provider, api: Api, provider_registry: ProviderRegistry
):
"""Validates if the provider is allowed and handles deprecations."""
if provider.provider_type not in provider_registry[api]:
raise ValueError(f"Provider `{provider.provider_type}` is not available for API `{api}`")
raise ValueError(
f"Provider `{provider.provider_type}` is not available for API `{api}`"
)
p = provider_registry[api][provider.provider_type]
if p.deprecation_error:
@ -278,7 +294,9 @@ async def instantiate_providers(
) -> Dict:
"""Instantiates providers asynchronously while managing dependencies."""
impls: Dict[Api, Any] = {}
inner_impls_by_provider_id: Dict[str, Dict[str, Any]] = {f"inner-{x.value}": {} for x in router_apis}
inner_impls_by_provider_id: Dict[str, Dict[str, Any]] = {
f"inner-{x.value}": {} for x in router_apis
}
for api_str, provider in sorted_providers:
deps = {a: impls[a] for a in provider.spec.api_dependencies}
for a in provider.spec.optional_api_dependencies:
@ -287,7 +305,9 @@ async def instantiate_providers(
inner_impls = {}
if isinstance(provider.spec, RoutingTableProviderSpec):
inner_impls = inner_impls_by_provider_id[f"inner-{provider.spec.router_api.value}"]
inner_impls = inner_impls_by_provider_id[
f"inner-{provider.spec.router_api.value}"
]
impl = await instantiate_provider(provider, deps, inner_impls, dist_registry)
@ -345,7 +365,9 @@ async def instantiate_provider(
provider_spec = provider.spec
if not hasattr(provider_spec, "module"):
raise AttributeError(f"ProviderSpec of type {type(provider_spec)} does not have a 'module' attribute")
raise AttributeError(
f"ProviderSpec of type {type(provider_spec)} does not have a 'module' attribute"
)
module = importlib.import_module(provider_spec.module)
args = []
@ -382,7 +404,10 @@ async def instantiate_provider(
# TODO: check compliance for special tool groups
# the impl should be for Api.tool_runtime, the name should be the special tool group, the protocol should be the special tool group protocol
check_protocol_compliance(impl, protocols[provider_spec.api])
if not isinstance(provider_spec, AutoRoutedProviderSpec) and provider_spec.api in additional_protocols:
if (
not isinstance(provider_spec, AutoRoutedProviderSpec)
and provider_spec.api in additional_protocols
):
additional_api, _, _ = additional_protocols[provider_spec.api]
check_protocol_compliance(impl, additional_api)
@ -410,12 +435,19 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
obj_params = set(obj_sig.parameters)
obj_params.discard("self")
if not (proto_params <= obj_params):
logger.error(f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
logger.error(
f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}"
)
missing_methods.append((name, "signature_mismatch"))
else:
# Check if the method is actually implemented in the class
method_owner = next((cls for cls in mro if name in cls.__dict__), None)
if method_owner is None or method_owner.__name__ == protocol.__name__:
method_owner = next(
(cls for cls in mro if name in cls.__dict__), None
)
if (
method_owner is None
or method_owner.__name__ == protocol.__name__
):
missing_methods.append((name, "not_actually_implemented"))
if missing_methods:

View file

@ -13,6 +13,7 @@ LLAMA_CHECKPOINT_DIR=${LLAMA_CHECKPOINT_DIR:-}
LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
PYPI_VERSION=${PYPI_VERSION:-}
VIRTUAL_ENV=${VIRTUAL_ENV:-}
set -euo pipefail
@ -69,22 +70,25 @@ while [[ $# -gt 0 ]]; do
;;
esac
done
PYTHON_BINARY="python"
case "$env_type" in
"venv")
# Activate virtual environment
if [ ! -d "$env_path_or_name" ]; then
echo -e "${RED}Error: Virtual environment not found at $env_path_or_name${NC}" >&2
exit 1
fi
if [ -n "$VIRTUAL_ENV" && "$VIRTUAL_ENV" == "$env_path_or_name" ]; then
echo -e "${GREEN}Virtual environment already activated${NC}" >&2
else
# Activate virtual environment
if [ ! -d "$env_path_or_name" ]; then
echo -e "${RED}Error: Virtual environment not found at $env_path_or_name${NC}" >&2
exit 1
fi
if [ ! -f "$env_path_or_name/bin/activate" ]; then
echo -e "${RED}Error: Virtual environment activate binary not found at $env_path_or_name/bin/activate" >&2
exit 1
fi
if [ ! -f "$env_path_or_name/bin/activate" ]; then
echo -e "${RED}Error: Virtual environment activate binary not found at $env_path_or_name/bin/activate" >&2
exit 1
fi
source "$env_path_or_name/bin/activate"
source "$env_path_or_name/bin/activate"
fi
;;
"conda")
if ! is_command_available conda; then

View file

@ -18,15 +18,19 @@ def preserve_contexts_async_generator(
This is needed because we start a new asyncio event loop for each streaming request,
and we need to preserve the context across the event loop boundary.
"""
# Capture initial context values
initial_context_values = {context_var.name: context_var.get() for context_var in context_vars}
async def wrapper() -> AsyncGenerator[T, None]:
while True:
try:
item = await gen.__anext__()
context_values = {context_var.name: context_var.get() for context_var in context_vars}
yield item
# Restore context values before any await
for context_var in context_vars:
_ = context_var.set(context_values[context_var.name])
context_var.set(initial_context_values[context_var.name])
item = await gen.__anext__()
yield item
except StopAsyncIteration:
break

View file

@ -195,10 +195,22 @@ register_schema(SamplingStrategy, name="SamplingStrategy")
@json_schema_type
class SamplingParams(BaseModel):
"""Sampling parameters.
:param strategy: The sampling strategy.
:param max_tokens: The maximum number of tokens that can be generated in the completion. The token count of
your prompt plus max_tokens cannot exceed the model's context length.
:param repetition_penalty: Number between -2.0 and 2.0. Positive values penalize new tokens
based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.
:param stop: Up to 4 sequences where the API will stop generating further tokens.
The returned text will not contain the stop sequence.
"""
strategy: SamplingStrategy = Field(default_factory=GreedySamplingStrategy)
max_tokens: Optional[int] = 0
repetition_penalty: Optional[float] = 1.0
stop: Optional[List[str]] = None
class CheckpointQuantizationFormat(Enum):

View file

@ -6,14 +6,12 @@
import copy
import json
import os
import re
import secrets
import string
import uuid
from datetime import datetime, timezone
from typing import AsyncGenerator, List, Optional, Union
from urllib.parse import urlparse
import httpx
@ -60,7 +58,6 @@ from llama_stack.apis.inference import (
)
from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import (
RAGDocument,
ToolGroups,
ToolInvocationResult,
ToolRuntime,
@ -453,8 +450,16 @@ class ChatAgent(ShieldRunnerMixin):
stream: bool = False,
documents: Optional[List[Document]] = None,
) -> AsyncGenerator:
# if document is passed in a turn, we parse the raw text of the document
# and sent it as a user message
if documents:
await self.handle_documents(session_id, documents, input_messages)
contexts = []
for document in documents:
raw_document_text = await get_raw_document_text(document)
contexts.append(raw_document_text)
attached_context = "\n".join(contexts)
input_messages[-1].context = attached_context
session_info = await self.storage.get_session_info(session_id)
# if the session has a memory bank id, let the memory tool use it
@ -829,7 +834,10 @@ class ChatAgent(ShieldRunnerMixin):
)
tool_name_to_args[tool_def.identifier] = toolgroup_to_args.get(toolgroup_name, {})
self.tool_defs, self.tool_name_to_args = list(tool_name_to_def.values()), tool_name_to_args
self.tool_defs, self.tool_name_to_args = (
list(tool_name_to_def.values()),
tool_name_to_args,
)
def _parse_toolgroup_name(self, toolgroup_name_with_maybe_tool_name: str) -> tuple[str, Optional[str]]:
"""Parse a toolgroup name into its components.
@ -880,144 +888,27 @@ class ChatAgent(ShieldRunnerMixin):
logger.debug(f"tool call {tool_name_str} completed with result: {result}")
return result
async def handle_documents(
self,
session_id: str,
documents: List[Document],
input_messages: List[Message],
) -> None:
memory_tool = any(tool_def.tool_name == MEMORY_QUERY_TOOL for tool_def in self.tool_defs)
code_interpreter_tool = any(tool_def.tool_name == BuiltinTool.code_interpreter for tool_def in self.tool_defs)
content_items = []
url_items = []
pattern = re.compile("^(https?://|file://|data:)")
for d in documents:
if isinstance(d.content, URL):
url_items.append(d.content)
elif pattern.match(d.content):
url_items.append(URL(uri=d.content))
else:
content_items.append(d)
# Save the contents to a tempdir and use its path as a URL if code interpreter is present
if code_interpreter_tool:
for c in content_items:
temp_file_path = os.path.join(self.tempdir, f"{make_random_string()}.txt")
with open(temp_file_path, "w") as temp_file:
temp_file.write(c.content)
url_items.append(URL(uri=f"file://{temp_file_path}"))
if memory_tool and code_interpreter_tool:
# if both memory and code_interpreter are available, we download the URLs
# and attach the data to the last message.
await attachment_message(self.tempdir, url_items, input_messages[-1])
# Since memory is present, add all the data to the memory bank
await self.add_to_session_vector_db(session_id, documents)
elif code_interpreter_tool:
# if only code_interpreter is available, we download the URLs to a tempdir
# and attach the path to them as a message to inference with the
# assumption that the model invokes the code_interpreter tool with the path
await attachment_message(self.tempdir, url_items, input_messages[-1])
elif memory_tool:
# if only memory is available, we load the data from the URLs and content items to the memory bank
await self.add_to_session_vector_db(session_id, documents)
else:
# if no memory or code_interpreter tool is available,
# we try to load the data from the URLs and content items as a message to inference
# and add it to the last message's context
input_messages[-1].context = "\n".join(
[doc.content for doc in content_items] + await load_data_from_urls(url_items)
)
async def _ensure_vector_db(self, session_id: str) -> str:
session_info = await self.storage.get_session_info(session_id)
if session_info is None:
raise ValueError(f"Session {session_id} not found")
if session_info.vector_db_id is None:
vector_db_id = f"vector_db_{session_id}"
# TODO: the semantic for registration is definitely not "creation"
# so we need to fix it if we expect the agent to create a new vector db
# for each session
await self.vector_io_api.register_vector_db(
vector_db_id=vector_db_id,
embedding_model="all-MiniLM-L6-v2",
)
await self.storage.add_vector_db_to_session(session_id, vector_db_id)
else:
vector_db_id = session_info.vector_db_id
return vector_db_id
async def add_to_session_vector_db(self, session_id: str, data: List[Document]) -> None:
vector_db_id = await self._ensure_vector_db(session_id)
documents = [
RAGDocument(
document_id=str(uuid.uuid4()),
content=a.content,
mime_type=a.mime_type,
metadata={},
)
for a in data
]
await self.tool_runtime_api.rag_tool.insert(
documents=documents,
vector_db_id=vector_db_id,
chunk_size_in_tokens=512,
)
async def load_data_from_url(url: str) -> str:
if url.startswith("http"):
async with httpx.AsyncClient() as client:
r = await client.get(url)
resp = r.text
return resp
raise ValueError(f"Unexpected URL: {type(url)}")
async def load_data_from_urls(urls: List[URL]) -> List[str]:
data = []
for url in urls:
uri = url.uri
if uri.startswith("file://"):
filepath = uri[len("file://") :]
with open(filepath, "r") as f:
data.append(f.read())
elif uri.startswith("http"):
async with httpx.AsyncClient() as client:
r = await client.get(uri)
resp = r.text
data.append(resp)
return data
async def attachment_message(tempdir: str, urls: List[URL], message: UserMessage) -> None:
contents = []
for url in urls:
uri = url.uri
if uri.startswith("file://"):
filepath = uri[len("file://") :]
elif uri.startswith("http"):
path = urlparse(uri).path
basename = os.path.basename(path)
filepath = f"{tempdir}/{make_random_string() + basename}"
logger.info(f"Downloading {url} -> {filepath}")
async with httpx.AsyncClient() as client:
r = await client.get(uri)
resp = r.text
with open(filepath, "w") as fp:
fp.write(resp)
else:
raise ValueError(f"Unsupported URL {url}")
contents.append(
TextContentItem(
text=f'# User provided a file accessible to you at "{filepath}"\nYou can use code_interpreter to load and inspect it.'
)
)
if isinstance(message.content, list):
message.content.extend(contents)
async def get_raw_document_text(document: Document) -> str:
if not document.mime_type.startswith("text/"):
raise ValueError(f"Unexpected document mime type: {document.mime_type}")
if isinstance(document.content, URL):
return await load_data_from_url(document.content.uri)
elif isinstance(document.content, str):
return document.content
elif isinstance(document.content, TextContentItem):
return document.content.text
else:
if isinstance(message.content, str):
message.content = [TextContentItem(text=message.content)] + contents
else:
message.content = [message.content] + contents
raise ValueError(f"Unexpected document content type: {type(document.content)}")
def _interpret_content_as_attachment(

View file

@ -28,6 +28,11 @@ class TelemetryConfig(BaseModel):
default="http://localhost:4318/v1/metrics",
description="The OpenTelemetry collector endpoint URL for metrics",
)
service_name: str = Field(
# service name is always the same, use zero-width space to avoid clutter
default="",
description="The service name to use for telemetry",
)
sinks: List[TelemetrySink] = Field(
default=[TelemetrySink.CONSOLE, TelemetrySink.SQLITE],
description="List of telemetry sinks to enable (possible values: otel, sqlite, console)",
@ -47,6 +52,7 @@ class TelemetryConfig(BaseModel):
@classmethod
def sample_run_config(cls, __distro_dir__: str, db_name: str = "trace_store.db") -> Dict[str, Any]:
return {
"service_name": "${env.OTEL_SERVICE_NAME:}",
"sinks": "${env.TELEMETRY_SINKS:console,sqlite}",
"sqlite_db_path": "${env.SQLITE_DB_PATH:" + __distro_dir__ + "/" + db_name + "}",
}

View file

@ -67,8 +67,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
resource = Resource.create(
{
# service name is always the same, use zero-width space to avoid clutter
ResourceAttributes.SERVICE_NAME: "",
ResourceAttributes.SERVICE_NAME: self.config.service_name,
}
)

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import hashlib
import logging
import sqlite3
@ -29,6 +30,15 @@ def serialize_vector(vector: List[float]) -> bytes:
return struct.pack(f"{len(vector)}f", *vector)
def _create_sqlite_connection(db_path):
"""Create a SQLite connection with sqlite_vec extension loaded."""
connection = sqlite3.connect(db_path)
connection.enable_load_extension(True)
sqlite_vec.load(connection)
connection.enable_load_extension(False)
return connection
class SQLiteVecIndex(EmbeddingIndex):
"""
An index implementation that stores embeddings in a SQLite virtual table using sqlite-vec.
@ -37,40 +47,56 @@ class SQLiteVecIndex(EmbeddingIndex):
- A virtual table (vec_chunks_{bank_id}) that holds the serialized vector.
"""
def __init__(self, dimension: int, connection: sqlite3.Connection, bank_id: str):
def __init__(self, dimension: int, db_path: str, bank_id: str):
self.dimension = dimension
self.connection = connection
self.db_path = db_path
self.bank_id = bank_id
self.metadata_table = f"chunks_{bank_id}".replace("-", "_")
self.vector_table = f"vec_chunks_{bank_id}".replace("-", "_")
@classmethod
async def create(cls, dimension: int, connection: sqlite3.Connection, bank_id: str):
instance = cls(dimension, connection, bank_id)
async def create(cls, dimension: int, db_path: str, bank_id: str):
instance = cls(dimension, db_path, bank_id)
await instance.initialize()
return instance
async def initialize(self) -> None:
cur = self.connection.cursor()
# Create the table to store chunk metadata.
cur.execute(f"""
CREATE TABLE IF NOT EXISTS {self.metadata_table} (
id TEXT PRIMARY KEY,
chunk TEXT
);
""")
# Create the virtual table for embeddings.
cur.execute(f"""
CREATE VIRTUAL TABLE IF NOT EXISTS {self.vector_table}
USING vec0(embedding FLOAT[{self.dimension}], id TEXT);
""")
self.connection.commit()
def _init_tables():
connection = _create_sqlite_connection(self.db_path)
cur = connection.cursor()
try:
# Create the table to store chunk metadata.
cur.execute(f"""
CREATE TABLE IF NOT EXISTS {self.metadata_table} (
id TEXT PRIMARY KEY,
chunk TEXT
);
""")
# Create the virtual table for embeddings.
cur.execute(f"""
CREATE VIRTUAL TABLE IF NOT EXISTS {self.vector_table}
USING vec0(embedding FLOAT[{self.dimension}], id TEXT);
""")
connection.commit()
finally:
cur.close()
connection.close()
async def delete(self):
cur = self.connection.cursor()
cur.execute(f"DROP TABLE IF EXISTS {self.metadata_table};")
cur.execute(f"DROP TABLE IF EXISTS {self.vector_table};")
self.connection.commit()
await asyncio.to_thread(_init_tables)
async def delete(self) -> None:
def _drop_tables():
connection = _create_sqlite_connection(self.db_path)
cur = connection.cursor()
try:
cur.execute(f"DROP TABLE IF EXISTS {self.metadata_table};")
cur.execute(f"DROP TABLE IF EXISTS {self.vector_table};")
connection.commit()
finally:
cur.close()
connection.close()
await asyncio.to_thread(_drop_tables)
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray, batch_size: int = 500):
"""
@ -81,44 +107,55 @@ class SQLiteVecIndex(EmbeddingIndex):
"""
assert all(isinstance(chunk.content, str) for chunk in chunks), "SQLiteVecIndex only supports text chunks"
cur = self.connection.cursor()
try:
# Start transaction
cur.execute("BEGIN TRANSACTION")
for i in range(0, len(chunks), batch_size):
batch_chunks = chunks[i : i + batch_size]
batch_embeddings = embeddings[i : i + batch_size]
# Prepare metadata inserts
metadata_data = [
(generate_chunk_id(chunk.metadata["document_id"], chunk.content), chunk.model_dump_json())
for chunk in batch_chunks
if isinstance(chunk.content, str)
]
# Insert metadata (ON CONFLICT to avoid duplicates)
cur.executemany(
f"""
INSERT INTO {self.metadata_table} (id, chunk)
VALUES (?, ?)
ON CONFLICT(id) DO UPDATE SET chunk = excluded.chunk;
""",
metadata_data,
)
# Prepare embeddings inserts
embedding_data = [
(generate_chunk_id(chunk.metadata["document_id"], chunk.content), serialize_vector(emb.tolist()))
for chunk, emb in zip(batch_chunks, batch_embeddings, strict=True)
if isinstance(chunk.content, str)
]
# Insert embeddings in batch
cur.executemany(f"INSERT INTO {self.vector_table} (id, embedding) VALUES (?, ?);", embedding_data)
self.connection.commit()
def _execute_all_batch_inserts():
connection = _create_sqlite_connection(self.db_path)
cur = connection.cursor()
except sqlite3.Error as e:
self.connection.rollback() # Rollback on failure
logger.error(f"Error inserting into {self.vector_table}: {e}")
try:
# Start transaction a single transcation for all batches
cur.execute("BEGIN TRANSACTION")
for i in range(0, len(chunks), batch_size):
batch_chunks = chunks[i : i + batch_size]
batch_embeddings = embeddings[i : i + batch_size]
# Prepare metadata inserts
metadata_data = [
(generate_chunk_id(chunk.metadata["document_id"], chunk.content), chunk.model_dump_json())
for chunk in batch_chunks
if isinstance(chunk.content, str)
]
# Insert metadata (ON CONFLICT to avoid duplicates)
cur.executemany(
f"""
INSERT INTO {self.metadata_table} (id, chunk)
VALUES (?, ?)
ON CONFLICT(id) DO UPDATE SET chunk = excluded.chunk;
""",
metadata_data,
)
# Prepare embeddings inserts
embedding_data = [
(
generate_chunk_id(chunk.metadata["document_id"], chunk.content),
serialize_vector(emb.tolist()),
)
for chunk, emb in zip(batch_chunks, batch_embeddings, strict=True)
if isinstance(chunk.content, str)
]
# Insert embeddings in batch
cur.executemany(f"INSERT INTO {self.vector_table} (id, embedding) VALUES (?, ?);", embedding_data)
connection.commit()
finally:
cur.close() # Ensure cursor is closed
except sqlite3.Error as e:
connection.rollback() # Rollback on failure
logger.error(f"Error inserting into {self.vector_table}: {e}")
raise
finally:
cur.close()
connection.close()
# Process all batches in a single thread
await asyncio.to_thread(_execute_all_batch_inserts)
async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
"""
@ -127,18 +164,28 @@ class SQLiteVecIndex(EmbeddingIndex):
"""
emb_list = embedding.tolist() if isinstance(embedding, np.ndarray) else list(embedding)
emb_blob = serialize_vector(emb_list)
cur = self.connection.cursor()
query_sql = f"""
SELECT m.id, m.chunk, v.distance
FROM {self.vector_table} AS v
JOIN {self.metadata_table} AS m ON m.id = v.id
WHERE v.embedding MATCH ? AND k = ?
ORDER BY v.distance;
"""
cur.execute(query_sql, (emb_blob, k))
rows = cur.fetchall()
chunks = []
scores = []
def _execute_query():
connection = _create_sqlite_connection(self.db_path)
cur = connection.cursor()
try:
query_sql = f"""
SELECT m.id, m.chunk, v.distance
FROM {self.vector_table} AS v
JOIN {self.metadata_table} AS m ON m.id = v.id
WHERE v.embedding MATCH ? AND k = ?
ORDER BY v.distance;
"""
cur.execute(query_sql, (emb_blob, k))
return cur.fetchall()
finally:
cur.close()
connection.close()
rows = await asyncio.to_thread(_execute_query)
chunks, scores = [], []
for _id, chunk_json, distance in rows:
try:
chunk = Chunk.model_validate_json(chunk_json)
@ -163,63 +210,81 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
self.config = config
self.inference_api = inference_api
self.cache: Dict[str, VectorDBWithIndex] = {}
self.connection: Optional[sqlite3.Connection] = None
async def initialize(self) -> None:
# Open a connection to the SQLite database (the file is specified in the config).
self.connection = sqlite3.connect(self.config.db_path)
self.connection.enable_load_extension(True)
sqlite_vec.load(self.connection)
self.connection.enable_load_extension(False)
cur = self.connection.cursor()
# Create a table to persist vector DB registrations.
cur.execute("""
CREATE TABLE IF NOT EXISTS vector_dbs (
id TEXT PRIMARY KEY,
metadata TEXT
);
""")
self.connection.commit()
# Load any existing vector DB registrations.
cur.execute("SELECT metadata FROM vector_dbs")
rows = cur.fetchall()
def _setup_connection():
# Open a connection to the SQLite database (the file is specified in the config).
connection = _create_sqlite_connection(self.config.db_path)
cur = connection.cursor()
try:
# Create a table to persist vector DB registrations.
cur.execute("""
CREATE TABLE IF NOT EXISTS vector_dbs (
id TEXT PRIMARY KEY,
metadata TEXT
);
""")
connection.commit()
# Load any existing vector DB registrations.
cur.execute("SELECT metadata FROM vector_dbs")
rows = cur.fetchall()
return rows
finally:
cur.close()
connection.close()
rows = await asyncio.to_thread(_setup_connection)
for row in rows:
vector_db_data = row[0]
vector_db = VectorDB.model_validate_json(vector_db_data)
index = await SQLiteVecIndex.create(vector_db.embedding_dimension, self.connection, vector_db.identifier)
index = await SQLiteVecIndex.create(
vector_db.embedding_dimension, self.config.db_path, vector_db.identifier
)
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
async def shutdown(self) -> None:
if self.connection:
self.connection.close()
self.connection = None
# nothing to do since we don't maintain a persistent connection
pass
async def register_vector_db(self, vector_db: VectorDB) -> None:
if self.connection is None:
raise RuntimeError("SQLite connection not initialized")
cur = self.connection.cursor()
cur.execute(
"INSERT OR REPLACE INTO vector_dbs (id, metadata) VALUES (?, ?)",
(vector_db.identifier, vector_db.model_dump_json()),
)
self.connection.commit()
index = await SQLiteVecIndex.create(vector_db.embedding_dimension, self.connection, vector_db.identifier)
def _register_db():
connection = _create_sqlite_connection(self.config.db_path)
cur = connection.cursor()
try:
cur.execute(
"INSERT OR REPLACE INTO vector_dbs (id, metadata) VALUES (?, ?)",
(vector_db.identifier, vector_db.model_dump_json()),
)
connection.commit()
finally:
cur.close()
connection.close()
await asyncio.to_thread(_register_db)
index = await SQLiteVecIndex.create(vector_db.embedding_dimension, self.config.db_path, vector_db.identifier)
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
async def list_vector_dbs(self) -> List[VectorDB]:
return [v.vector_db for v in self.cache.values()]
async def unregister_vector_db(self, vector_db_id: str) -> None:
if self.connection is None:
raise RuntimeError("SQLite connection not initialized")
if vector_db_id not in self.cache:
logger.warning(f"Vector DB {vector_db_id} not found")
return
await self.cache[vector_db_id].index.delete()
del self.cache[vector_db_id]
cur = self.connection.cursor()
cur.execute("DELETE FROM vector_dbs WHERE id = ?", (vector_db_id,))
self.connection.commit()
def _delete_vector_db_from_registry():
connection = _create_sqlite_connection(self.config.db_path)
cur = connection.cursor()
try:
cur.execute("DELETE FROM vector_dbs WHERE id = ?", (vector_db_id,))
connection.commit()
finally:
cur.close()
connection.close()
await asyncio.to_thread(_delete_vector_db_from_registry)
async def insert_chunks(self, vector_db_id: str, chunks: List[Chunk], ttl_seconds: Optional[int] = None) -> None:
if vector_db_id not in self.cache:

View file

@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.datatypes import ProviderSpec
def available_providers() -> list[ProviderSpec]:
return []

View file

@ -6,7 +6,7 @@
from typing import List
from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec
from llama_stack.providers.datatypes import AdapterSpec, Api, InlineProviderSpec, ProviderSpec, remote_provider_spec
def available_providers() -> List[ProviderSpec]:
@ -22,4 +22,13 @@ def available_providers() -> List[ProviderSpec]:
Api.datasets,
],
),
remote_provider_spec(
api=Api.post_training,
adapter=AdapterSpec(
adapter_type="nvidia",
pip_packages=["requests", "aiohttp"],
module="llama_stack.providers.remote.post_training.nvidia",
config_class="llama_stack.providers.remote.post_training.nvidia.NvidiaPostTrainingConfig",
),
),
]

View file

@ -55,7 +55,7 @@ from .openai_utils import (
convert_openai_completion_choice,
convert_openai_completion_stream,
)
from .utils import _is_nvidia_hosted, check_health
from .utils import _is_nvidia_hosted
logger = logging.getLogger(__name__)
@ -134,7 +134,9 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
if content_has_media(content):
raise NotImplementedError("Media is not supported")
await check_health(self._config) # this raises errors
# ToDo: check health of NeMo endpoints and enable this
# removing this health check as NeMo customizer endpoint health check is returning 404
# await check_health(self._config) # this raises errors
provider_model_id = self.get_provider_model_id(model_id)
request = convert_completion_request(
@ -236,7 +238,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
if tool_prompt_format:
warnings.warn("tool_prompt_format is not supported by NVIDIA NIM, ignoring", stacklevel=2)
await check_health(self._config) # this raises errors
# await check_health(self._config) # this raises errors
provider_model_id = self.get_provider_model_id(model_id)
request = await convert_chat_completion_request(

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,138 @@
# NVIDIA Post-Training Provider for LlamaStack
This provider enables fine-tuning of LLMs using NVIDIA's NeMo Customizer service.
## Features
- Supervised fine-tuning of Llama models
- LoRA fine-tuning support
- Job management and status tracking
## Getting Started
### Prerequisites
- LlamaStack with NVIDIA configuration
- Access to Hosted NVIDIA NeMo Customizer service
- Dataset registered in the Hosted NVIDIA NeMo Customizer service
- Base model downloaded and available in the Hosted NVIDIA NeMo Customizer service
### Setup
Build the NVIDIA environment:
```bash
llama stack build --template nvidia --image-type conda
```
### Basic Usage using the LlamaStack Python Client
### Create Customization Job
#### Initialize the client
```python
import os
os.environ["NVIDIA_API_KEY"] = "your-api-key"
os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test"
os.environ["NVIDIA_USER_ID"] = "llama-stack-user"
os.environ["NVIDIA_DATASET_NAMESPACE"] = "default"
os.environ["NVIDIA_PROJECT_ID"] = "test-project"
os.environ["NVIDIA_OUTPUT_MODEL_DIR"] = "test-example-model@v1"
from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
client = LlamaStackAsLibraryClient("nvidia")
client.initialize()
```
#### Configure fine-tuning parameters
```python
from llama_stack_client.types.post_training_supervised_fine_tune_params import (
TrainingConfig,
TrainingConfigDataConfig,
TrainingConfigOptimizerConfig,
)
from llama_stack_client.types.algorithm_config_param import LoraFinetuningConfig
```
#### Set up LoRA configuration
```python
algorithm_config = LoraFinetuningConfig(type="LoRA", adapter_dim=16)
```
#### Configure training data
```python
data_config = TrainingConfigDataConfig(
dataset_id="your-dataset-id", # Use client.datasets.list() to see available datasets
batch_size=16,
)
```
#### Configure optimizer
```python
optimizer_config = TrainingConfigOptimizerConfig(
lr=0.0001,
)
```
#### Set up training configuration
```python
training_config = TrainingConfig(
n_epochs=2,
data_config=data_config,
optimizer_config=optimizer_config,
)
```
#### Start fine-tuning job
```python
training_job = client.post_training.supervised_fine_tune(
job_uuid="unique-job-id",
model="meta-llama/Llama-3.1-8B-Instruct",
checkpoint_dir="",
algorithm_config=algorithm_config,
training_config=training_config,
logger_config={},
hyperparam_search_config={},
)
```
### List all jobs
```python
jobs = client.post_training.job.list()
```
### Check job status
```python
job_status = client.post_training.job.status(job_uuid="your-job-id")
```
### Cancel a job
```python
client.post_training.job.cancel(job_uuid="your-job-id")
```
### Inference with the fine-tuned model
```python
response = client.inference.completion(
content="Complete the sentence using one word: Roses are red, violets are ",
stream=False,
model_id="test-example-model@v1",
sampling_params={
"max_tokens": 50,
},
)
print(response.content)
```

View file

@ -0,0 +1,23 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import NvidiaPostTrainingConfig
async def get_adapter_impl(
config: NvidiaPostTrainingConfig,
_deps,
):
from .post_training import NvidiaPostTrainingAdapter
if not isinstance(config, NvidiaPostTrainingConfig):
raise RuntimeError(f"Unexpected config type: {type(config)}")
impl = NvidiaPostTrainingAdapter(config)
return impl
__all__ = ["get_adapter_impl", "NvidiaPostTrainingAdapter"]

View file

@ -0,0 +1,113 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field
# TODO: add default values for all fields
class NvidiaPostTrainingConfig(BaseModel):
"""Configuration for NVIDIA Post Training implementation."""
api_key: Optional[str] = Field(
default_factory=lambda: os.getenv("NVIDIA_API_KEY"),
description="The NVIDIA API key.",
)
dataset_namespace: Optional[str] = Field(
default_factory=lambda: os.getenv("NVIDIA_DATASET_NAMESPACE", "default"),
description="The NVIDIA dataset namespace.",
)
project_id: Optional[str] = Field(
default_factory=lambda: os.getenv("NVIDIA_PROJECT_ID", "test-example-model@v1"),
description="The NVIDIA project ID.",
)
# ToDO: validate this, add default value
customizer_url: Optional[str] = Field(
default_factory=lambda: os.getenv("NVIDIA_CUSTOMIZER_URL"),
description="Base URL for the NeMo Customizer API",
)
timeout: int = Field(
default=300,
description="Timeout for the NVIDIA Post Training API",
)
max_retries: int = Field(
default=3,
description="Maximum number of retries for the NVIDIA Post Training API",
)
# ToDo: validate this
output_model_dir: str = Field(
default_factory=lambda: os.getenv("NVIDIA_OUTPUT_MODEL_DIR", "test-example-model@v1"),
description="Directory to save the output model",
)
@classmethod
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
return {
"api_key": "${env.NVIDIA_API_KEY:}",
"dataset_namespace": "${env.NVIDIA_DATASET_NAMESPACE:default}",
"project_id": "${env.NVIDIA_PROJECT_ID:test-project}",
"customizer_url": "${env.NVIDIA_CUSTOMIZER_URL:http://nemo.test}",
}
class SFTLoRADefaultConfig(BaseModel):
"""NVIDIA-specific training configuration with default values."""
# ToDo: split into SFT and LoRA configs??
# General training parameters
n_epochs: int = 50
# NeMo customizer specific parameters
log_every_n_steps: Optional[int] = None
val_check_interval: float = 0.25
sequence_packing_enabled: bool = False
weight_decay: float = 0.01
lr: float = 0.0001
# SFT specific parameters
hidden_dropout: Optional[float] = None
attention_dropout: Optional[float] = None
ffn_dropout: Optional[float] = None
# LoRA default parameters
lora_adapter_dim: int = 8
lora_adapter_dropout: Optional[float] = None
lora_alpha: int = 16
# Data config
batch_size: int = 8
@classmethod
def sample_config(cls) -> Dict[str, Any]:
"""Return a sample configuration for NVIDIA training."""
return {
"n_epochs": 50,
"log_every_n_steps": 10,
"val_check_interval": 0.25,
"sequence_packing_enabled": False,
"weight_decay": 0.01,
"hidden_dropout": 0.1,
"attention_dropout": 0.1,
"lora_adapter_dim": 8,
"lora_alpha": 16,
"data_config": {
"dataset_id": "default",
"batch_size": 8,
},
"optimizer_config": {
"lr": 0.0001,
},
}

View file

@ -0,0 +1,24 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List
from llama_stack.models.llama.datatypes import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
ProviderModelEntry,
build_hf_repo_model_entry,
)
_MODEL_ENTRIES = [
build_hf_repo_model_entry(
"meta/llama-3.1-8b-instruct",
CoreModelId.llama3_1_8b_instruct.value,
)
]
def get_model_entries() -> List[ProviderModelEntry]:
return _MODEL_ENTRIES

View file

@ -0,0 +1,439 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import warnings
from datetime import datetime
from typing import Any, Dict, List, Literal, Optional
import aiohttp
from pydantic import BaseModel, ConfigDict
from llama_stack.apis.post_training import (
AlgorithmConfig,
DPOAlignmentConfig,
JobStatus,
PostTrainingJob,
PostTrainingJobArtifactsResponse,
PostTrainingJobStatusResponse,
TrainingConfig,
)
from llama_stack.providers.remote.post_training.nvidia.config import NvidiaPostTrainingConfig
from llama_stack.providers.remote.post_training.nvidia.utils import warn_unsupported_params
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from .models import _MODEL_ENTRIES
# Map API status to JobStatus enum
STATUS_MAPPING = {
"running": "in_progress",
"completed": "completed",
"failed": "failed",
"cancelled": "cancelled",
"pending": "scheduled",
}
class NvidiaPostTrainingJob(PostTrainingJob):
"""Parse the response from the Customizer API.
Inherits job_uuid from PostTrainingJob.
Adds status, created_at, updated_at parameters.
Passes through all other parameters from data field in the response.
"""
model_config = ConfigDict(extra="allow")
status: JobStatus
created_at: datetime
updated_at: datetime
class ListNvidiaPostTrainingJobs(BaseModel):
data: List[NvidiaPostTrainingJob]
class NvidiaPostTrainingJobStatusResponse(PostTrainingJobStatusResponse):
model_config = ConfigDict(extra="allow")
class NvidiaPostTrainingAdapter(ModelRegistryHelper):
def __init__(self, config: NvidiaPostTrainingConfig):
self.config = config
self.headers = {}
if config.api_key:
self.headers["Authorization"] = f"Bearer {config.api_key}"
self.timeout = aiohttp.ClientTimeout(total=config.timeout)
# TODO: filter by available models based on /config endpoint
ModelRegistryHelper.__init__(self, model_entries=_MODEL_ENTRIES)
self.session = aiohttp.ClientSession(headers=self.headers, timeout=self.timeout)
self.customizer_url = config.customizer_url
if not self.customizer_url:
warnings.warn("Customizer URL is not set, using default value: http://nemo.test", stacklevel=2)
self.customizer_url = "http://nemo.test"
async def _make_request(
self,
method: str,
path: str,
headers: Optional[Dict[str, Any]] = None,
params: Optional[Dict[str, Any]] = None,
json: Optional[Dict[str, Any]] = None,
**kwargs,
) -> Dict[str, Any]:
"""Helper method to make HTTP requests to the Customizer API."""
url = f"{self.customizer_url}{path}"
request_headers = self.headers.copy()
if headers:
request_headers.update(headers)
# Add content-type header for JSON requests
if json and "Content-Type" not in request_headers:
request_headers["Content-Type"] = "application/json"
for _ in range(self.config.max_retries):
async with self.session.request(method, url, params=params, json=json, **kwargs) as response:
if response.status >= 400:
error_data = await response.json()
raise Exception(f"API request failed: {error_data}")
return await response.json()
async def get_training_jobs(
self,
page: Optional[int] = 1,
page_size: Optional[int] = 10,
sort: Optional[Literal["created_at", "-created_at"]] = "created_at",
) -> ListNvidiaPostTrainingJobs:
"""Get all customization jobs.
Updated the base class return type from ListPostTrainingJobsResponse to ListNvidiaPostTrainingJobs.
Returns a ListNvidiaPostTrainingJobs object with the following fields:
- data: List[NvidiaPostTrainingJob] - List of NvidiaPostTrainingJob objects
ToDo: Support for schema input for filtering.
"""
params = {"page": page, "page_size": page_size, "sort": sort}
response = await self._make_request("GET", "/v1/customization/jobs", params=params)
jobs = []
for job in response.get("data", []):
job_id = job.pop("id")
job_status = job.pop("status", "unknown").lower()
mapped_status = STATUS_MAPPING.get(job_status, "unknown")
# Convert string timestamps to datetime objects
created_at = (
datetime.fromisoformat(job.pop("created_at"))
if "created_at" in job
else datetime.now(tz=datetime.timezone.utc)
)
updated_at = (
datetime.fromisoformat(job.pop("updated_at"))
if "updated_at" in job
else datetime.now(tz=datetime.timezone.utc)
)
# Create NvidiaPostTrainingJob instance
jobs.append(
NvidiaPostTrainingJob(
job_uuid=job_id,
status=JobStatus(mapped_status),
created_at=created_at,
updated_at=updated_at,
**job,
)
)
return ListNvidiaPostTrainingJobs(data=jobs)
async def get_training_job_status(self, job_uuid: str) -> NvidiaPostTrainingJobStatusResponse:
"""Get the status of a customization job.
Updated the base class return type from PostTrainingJobResponse to NvidiaPostTrainingJob.
Returns a NvidiaPostTrainingJob object with the following fields:
- job_uuid: str - Unique identifier for the job
- status: JobStatus - Current status of the job (in_progress, completed, failed, cancelled, scheduled)
- created_at: datetime - The time when the job was created
- updated_at: datetime - The last time the job status was updated
Additional fields that may be included:
- steps_completed: Optional[int] - Number of training steps completed
- epochs_completed: Optional[int] - Number of epochs completed
- percentage_done: Optional[float] - Percentage of training completed (0-100)
- best_epoch: Optional[int] - The epoch with the best performance
- train_loss: Optional[float] - Training loss of the best checkpoint
- val_loss: Optional[float] - Validation loss of the best checkpoint
- metrics: Optional[Dict] - Additional training metrics
- status_logs: Optional[List] - Detailed logs of status changes
"""
response = await self._make_request(
"GET",
f"/v1/customization/jobs/{job_uuid}/status",
params={"job_id": job_uuid},
)
api_status = response.pop("status").lower()
mapped_status = STATUS_MAPPING.get(api_status, "unknown")
return NvidiaPostTrainingJobStatusResponse(
status=JobStatus(mapped_status),
job_uuid=job_uuid,
started_at=datetime.fromisoformat(response.pop("created_at")),
updated_at=datetime.fromisoformat(response.pop("updated_at")),
**response,
)
async def cancel_training_job(self, job_uuid: str) -> None:
await self._make_request(
method="POST", path=f"/v1/customization/jobs/{job_uuid}/cancel", params={"job_id": job_uuid}
)
async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse:
raise NotImplementedError("Job artifacts are not implemented yet")
async def get_post_training_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse:
raise NotImplementedError("Job artifacts are not implemented yet")
async def supervised_fine_tune(
self,
job_uuid: str,
training_config: Dict[str, Any],
hyperparam_search_config: Dict[str, Any],
logger_config: Dict[str, Any],
model: str,
checkpoint_dir: Optional[str],
algorithm_config: Optional[AlgorithmConfig] = None,
extra_json: Optional[Dict[str, Any]] = None,
params: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, Any]] = None,
**kwargs,
) -> NvidiaPostTrainingJob:
"""
Fine-tunes a model on a dataset.
Currently only supports Lora finetuning for standlone docker container.
Assumptions:
- nemo microservice is running and endpoint is set in config.customizer_url
- dataset is registered separately in nemo datastore
- model checkpoint is downloaded as per nemo customizer requirements
Parameters:
training_config: TrainingConfig - Configuration for training
model: str - Model identifier
algorithm_config: Optional[AlgorithmConfig] - Algorithm-specific configuration
checkpoint_dir: Optional[str] - Directory containing model checkpoints, ignored atm
job_uuid: str - Unique identifier for the job, ignored atm
hyperparam_search_config: Dict[str, Any] - Configuration for hyperparameter search, ignored atm
logger_config: Dict[str, Any] - Configuration for logging, ignored atm
Environment Variables:
- NVIDIA_API_KEY: str - API key for the NVIDIA API
Default: None
- NVIDIA_DATASET_NAMESPACE: str - Namespace of the dataset
Default: "default"
- NVIDIA_CUSTOMIZER_URL: str - URL of the NeMo Customizer API
Default: "http://nemo.test"
- NVIDIA_PROJECT_ID: str - ID of the project
Default: "test-project"
- NVIDIA_OUTPUT_MODEL_DIR: str - Directory to save the output model
Default: "test-example-model@v1"
Supported models:
- meta/llama-3.1-8b-instruct
Supported algorithm configs:
- LoRA, SFT
Supported Parameters:
- TrainingConfig:
- n_epochs: int - Number of epochs to train
Default: 50
- data_config: DataConfig - Configuration for the dataset
- optimizer_config: OptimizerConfig - Configuration for the optimizer
- dtype: str - Data type for training
not supported (users are informed via warnings)
- efficiency_config: EfficiencyConfig - Configuration for efficiency
not supported
- max_steps_per_epoch: int - Maximum number of steps per epoch
Default: 1000
## NeMo customizer specific parameters
- log_every_n_steps: int - Log every n steps
Default: None
- val_check_interval: float - Validation check interval
Default: 0.25
- sequence_packing_enabled: bool - Sequence packing enabled
Default: False
## NeMo customizer specific SFT parameters
- hidden_dropout: float - Hidden dropout
Default: None (0.0-1.0)
- attention_dropout: float - Attention dropout
Default: None (0.0-1.0)
- ffn_dropout: float - FFN dropout
Default: None (0.0-1.0)
- DataConfig:
- dataset_id: str - Dataset ID
- batch_size: int - Batch size
Default: 8
- OptimizerConfig:
- lr: float - Learning rate
Default: 0.0001
## NeMo customizer specific parameter
- weight_decay: float - Weight decay
Default: 0.01
- LoRA config:
## NeMo customizer specific LoRA parameters
- adapter_dim: int - Adapter dimension
Default: 8 (supports powers of 2)
- adapter_dropout: float - Adapter dropout
Default: None (0.0-1.0)
- alpha: int - Scaling factor for the LoRA update
Default: 16
Note:
- checkpoint_dir, hyperparam_search_config, logger_config are not supported (users are informed via warnings)
- Some parameters from TrainingConfig, DataConfig, OptimizerConfig are not supported (users are informed via warnings)
User is informed about unsupported parameters via warnings.
"""
# Map model to nvidia model name
# ToDo: only supports llama-3.1-8b-instruct now, need to update this to support other models
nvidia_model = self.get_provider_model_id(model)
# Check for unsupported method parameters
unsupported_method_params = []
if checkpoint_dir:
unsupported_method_params.append(f"checkpoint_dir={checkpoint_dir}")
if hyperparam_search_config:
unsupported_method_params.append("hyperparam_search_config")
if logger_config:
unsupported_method_params.append("logger_config")
if unsupported_method_params:
warnings.warn(
f"Parameters: {', '.join(unsupported_method_params)} are not supported and will be ignored",
stacklevel=2,
)
# Define all supported parameters
supported_params = {
"training_config": {
"n_epochs",
"data_config",
"optimizer_config",
"log_every_n_steps",
"val_check_interval",
"sequence_packing_enabled",
"hidden_dropout",
"attention_dropout",
"ffn_dropout",
},
"data_config": {"dataset_id", "batch_size"},
"optimizer_config": {"lr", "weight_decay"},
"lora_config": {"type", "adapter_dim", "adapter_dropout", "alpha"},
}
# Validate all parameters at once
warn_unsupported_params(training_config, supported_params["training_config"], "TrainingConfig")
warn_unsupported_params(training_config["data_config"], supported_params["data_config"], "DataConfig")
warn_unsupported_params(
training_config["optimizer_config"], supported_params["optimizer_config"], "OptimizerConfig"
)
output_model = self.config.output_model_dir
# Prepare base job configuration
job_config = {
"config": nvidia_model,
"dataset": {
"name": training_config["data_config"]["dataset_id"],
"namespace": self.config.dataset_namespace,
},
"hyperparameters": {
"training_type": "sft",
"finetuning_type": "lora",
**{
k: v
for k, v in {
"epochs": training_config.get("n_epochs"),
"batch_size": training_config["data_config"].get("batch_size"),
"learning_rate": training_config["optimizer_config"].get("lr"),
"weight_decay": training_config["optimizer_config"].get("weight_decay"),
"log_every_n_steps": training_config.get("log_every_n_steps"),
"val_check_interval": training_config.get("val_check_interval"),
"sequence_packing_enabled": training_config.get("sequence_packing_enabled"),
}.items()
if v is not None
},
},
"project": self.config.project_id,
# TODO: ignored ownership, add it later
# "ownership": {"created_by": self.config.user_id, "access_policies": self.config.access_policies},
"output_model": output_model,
}
# Handle SFT-specific optional parameters
job_config["hyperparameters"]["sft"] = {
k: v
for k, v in {
"ffn_dropout": training_config.get("ffn_dropout"),
"hidden_dropout": training_config.get("hidden_dropout"),
"attention_dropout": training_config.get("attention_dropout"),
}.items()
if v is not None
}
# Remove the sft dictionary if it's empty
if not job_config["hyperparameters"]["sft"]:
job_config["hyperparameters"].pop("sft")
# Handle LoRA-specific configuration
if algorithm_config:
if isinstance(algorithm_config, dict) and algorithm_config.get("type") == "LoRA":
warn_unsupported_params(algorithm_config, supported_params["lora_config"], "LoRA config")
job_config["hyperparameters"]["lora"] = {
k: v
for k, v in {
"adapter_dim": algorithm_config.get("adapter_dim"),
"alpha": algorithm_config.get("alpha"),
"adapter_dropout": algorithm_config.get("adapter_dropout"),
}.items()
if v is not None
}
else:
raise NotImplementedError(f"Unsupported algorithm config: {algorithm_config}")
# Create the customization job
response = await self._make_request(
method="POST",
path="/v1/customization/jobs",
headers={"Accept": "application/json"},
json=job_config,
)
job_uuid = response["id"]
response.pop("status")
created_at = datetime.fromisoformat(response.pop("created_at"))
updated_at = datetime.fromisoformat(response.pop("updated_at"))
return NvidiaPostTrainingJob(
job_uuid=job_uuid, status=JobStatus.in_progress, created_at=created_at, updated_at=updated_at, **response
)
async def preference_optimize(
self,
job_uuid: str,
finetuned_model: str,
algorithm_config: DPOAlignmentConfig,
training_config: TrainingConfig,
hyperparam_search_config: Dict[str, Any],
logger_config: Dict[str, Any],
) -> PostTrainingJob:
"""Optimize a model based on preference data."""
raise NotImplementedError("Preference optimization is not implemented yet")
async def get_training_job_container_logs(self, job_uuid: str) -> PostTrainingJobStatusResponse:
raise NotImplementedError("Job logs are not implemented yet")

View file

@ -0,0 +1,63 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import warnings
from typing import Any, Dict, Set, Tuple
from pydantic import BaseModel
from llama_stack.apis.post_training import TrainingConfig
from llama_stack.providers.remote.post_training.nvidia.config import SFTLoRADefaultConfig
from .config import NvidiaPostTrainingConfig
logger = logging.getLogger(__name__)
def warn_unsupported_params(config_dict: Any, supported_keys: Set[str], config_name: str) -> None:
keys = set(config_dict.__annotations__.keys()) if isinstance(config_dict, BaseModel) else config_dict.keys()
unsupported_params = [k for k in keys if k not in supported_keys]
if unsupported_params:
warnings.warn(
f"Parameters: {unsupported_params} in `{config_name}` not supported and will be ignored.", stacklevel=2
)
def validate_training_params(
training_config: Dict[str, Any], supported_keys: Set[str], config_name: str = "TrainingConfig"
) -> None:
"""
Validates training parameters against supported keys.
Args:
training_config: Dictionary containing training configuration parameters
supported_keys: Set of supported parameter keys
config_name: Name of the configuration for warning messages
"""
sft_lora_fields = set(SFTLoRADefaultConfig.__annotations__.keys())
training_config_fields = set(TrainingConfig.__annotations__.keys())
# Check for not supported parameters:
# - not in either of configs
# - in TrainingConfig but not in SFTLoRADefaultConfig
unsupported_params = []
for key in training_config:
if isinstance(key, str) and key not in (supported_keys.union(sft_lora_fields)):
if key in (not sft_lora_fields or training_config_fields):
unsupported_params.append(key)
if unsupported_params:
warnings.warn(
f"Parameters: {unsupported_params} in `{config_name}` are not supported and will be ignored.", stacklevel=2
)
# ToDo: implement post health checks for customizer are enabled
async def _get_health(url: str) -> Tuple[bool, bool]: ...
async def check_health(config: NvidiaPostTrainingConfig) -> None: ...

View file

@ -147,6 +147,9 @@ def get_sampling_options(params: SamplingParams) -> dict:
if params.repetition_penalty is not None and params.repetition_penalty != 1.0:
options["repeat_penalty"] = params.repetition_penalty
if params.stop is not None:
options["stop"] = params.stop
return options

View file

@ -37,6 +37,7 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/bedrock/trace_store.db}
datasetio:

View file

@ -58,6 +58,7 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/cerebras/trace_store.db}
tool_runtime:

View file

@ -40,6 +40,7 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/ci-tests/trace_store.db}
datasetio:

View file

@ -43,6 +43,7 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/dell/trace_store.db}
datasetio:

View file

@ -39,6 +39,7 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/dell/trace_store.db}
datasetio:

View file

@ -69,6 +69,7 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/dev/trace_store.db}
datasetio:

View file

@ -48,6 +48,7 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/fireworks/trace_store.db}
datasetio:

View file

@ -43,6 +43,7 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/fireworks/trace_store.db}
datasetio:

View file

@ -43,6 +43,7 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/groq/trace_store.db}
datasetio:

View file

@ -48,6 +48,7 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/hf-endpoint/trace_store.db}
datasetio:

View file

@ -43,6 +43,7 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/hf-endpoint/trace_store.db}
datasetio:

View file

@ -48,6 +48,7 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/hf-serverless/trace_store.db}
datasetio:

View file

@ -43,6 +43,7 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/hf-serverless/trace_store.db}
datasetio:

View file

@ -50,6 +50,7 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/meta-reference-gpu/trace_store.db}
datasetio:

View file

@ -44,6 +44,7 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/meta-reference-gpu/trace_store.db}
datasetio:

View file

@ -46,6 +46,7 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/meta-reference-quantized-gpu/trace_store.db}
datasetio:

View file

@ -17,8 +17,8 @@ from llama_stack.providers.remote.inference.nvidia.models import MODEL_ENTRIES
from llama_stack.providers.remote.safety.nvidia import NVIDIASafetyConfig
from llama_stack.templates.template import (
DistributionTemplate,
RunConfigSettings,
get_model_registry,
RunConfigSettings,
)
@ -29,6 +29,7 @@ def get_distribution_template() -> DistributionTemplate:
"safety": ["remote::nvidia"],
"agents": ["inline::meta-reference"],
"telemetry": ["inline::meta-reference"],
"post_training": ["remote::nvidia"],
"datasetio": ["inline::localfs"],
"tool_runtime": ["inline::rag-runtime"],
}
@ -87,7 +88,9 @@ def get_distribution_template() -> DistributionTemplate:
]
},
default_models=[inference_model, safety_model],
default_shields=[ShieldInput(shield_id="${env.SAFETY_MODEL}", provider_id="nvidia")],
default_shields=[
ShieldInput(shield_id="${env.SAFETY_MODEL}", provider_id="nvidia")
],
default_tool_groups=default_tool_groups,
),
},
@ -96,6 +99,31 @@ def get_distribution_template() -> DistributionTemplate:
"",
"NVIDIA API Key",
),
## Nemo Customizer related variables
"NVIDIA_USER_ID": (
"llama-stack-user",
"NVIDIA User ID",
),
"NVIDIA_DATASET_NAMESPACE": (
"default",
"NVIDIA Dataset Namespace",
),
"NVIDIA_ACCESS_POLICIES": (
"{}",
"NVIDIA Access Policies",
),
"NVIDIA_PROJECT_ID": (
"test-project",
"NVIDIA Project ID",
),
"NVIDIA_CUSTOMIZER_URL": (
"https://customizer.api.nvidia.com",
"NVIDIA Customizer URL",
),
"NVIDIA_OUTPUT_MODEL_DIR": (
"test-example-model@v1",
"NVIDIA Output Model Directory",
),
"GUARDRAILS_SERVICE_URL": (
"http://0.0.0.0:7331",
"URL for the NeMo Guardrails Service",

View file

@ -4,6 +4,7 @@ apis:
- agents
- datasetio
- inference
- post_training
- safety
- telemetry
- tool_runtime
@ -46,6 +47,7 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/nvidia/trace_store.db}
datasetio:

View file

@ -4,6 +4,7 @@ apis:
- agents
- datasetio
- inference
- post_training
- safety
- telemetry
- tool_runtime
@ -41,6 +42,7 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/nvidia/trace_store.db}
datasetio:

View file

@ -41,6 +41,7 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/ollama/trace_store.db}
datasetio:

View file

@ -39,6 +39,7 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/ollama/trace_store.db}
datasetio:

View file

@ -0,0 +1,249 @@
version: '2'
image_name: open-benchmark
apis:
- agents
- datasetio
- eval
- inference
- safety
- scoring
- telemetry
- tool_runtime
- vector_io
providers:
inference:
- provider_id: openai
provider_type: remote::openai
config:
api_key: ${env.OPENAI_API_KEY:}
- provider_id: anthropic
provider_type: remote::anthropic
config:
api_key: ${env.ANTHROPIC_API_KEY:}
- provider_id: gemini
provider_type: remote::gemini
config:
api_key: ${env.GEMINI_API_KEY:}
- provider_id: groq
provider_type: remote::groq
config:
url: https://api.groq.com
api_key: ${env.GROQ_API_KEY:}
- provider_id: together
provider_type: remote::together
config:
url: https://api.together.xyz/v1
api_key: ${env.TOGETHER_API_KEY:}
vector_io:
- provider_id: sqlite-vec
provider_type: inline::sqlite-vec
config:
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/open-benchmark}/sqlite_vec.db
- provider_id: ${env.ENABLE_CHROMADB+chromadb}
provider_type: remote::chromadb
config:
url: ${env.CHROMADB_URL:}
- provider_id: ${env.ENABLE_PGVECTOR+pgvector}
provider_type: remote::pgvector
config:
host: ${env.PGVECTOR_HOST:localhost}
port: ${env.PGVECTOR_PORT:5432}
db: ${env.PGVECTOR_DB:}
user: ${env.PGVECTOR_USER:}
password: ${env.PGVECTOR_PASSWORD:}
safety:
- provider_id: llama-guard
provider_type: inline::llama-guard
config:
excluded_categories: []
agents:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
persistence_store:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/open-benchmark}/agents_store.db
telemetry:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/open-benchmark/trace_store.db}
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
kvstore:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/open-benchmark}/meta_reference_eval.db
datasetio:
- provider_id: huggingface
provider_type: remote::huggingface
config:
kvstore:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/open-benchmark}/huggingface_datasetio.db
- provider_id: localfs
provider_type: inline::localfs
config:
kvstore:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/open-benchmark}/localfs_datasetio.db
scoring:
- provider_id: basic
provider_type: inline::basic
config: {}
- provider_id: llm-as-judge
provider_type: inline::llm-as-judge
config: {}
- provider_id: braintrust
provider_type: inline::braintrust
config:
openai_api_key: ${env.OPENAI_API_KEY:}
tool_runtime:
- provider_id: brave-search
provider_type: remote::brave-search
config:
api_key: ${env.BRAVE_SEARCH_API_KEY:}
max_results: 3
- provider_id: tavily-search
provider_type: remote::tavily-search
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:}
max_results: 3
- provider_id: code-interpreter
provider_type: inline::code-interpreter
config: {}
- provider_id: rag-runtime
provider_type: inline::rag-runtime
config: {}
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
config: {}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/open-benchmark}/registry.db
models:
- metadata: {}
model_id: openai/gpt-4o
provider_id: openai
provider_model_id: openai/gpt-4o
model_type: llm
- metadata: {}
model_id: anthropic/claude-3-5-sonnet-latest
provider_id: anthropic
provider_model_id: anthropic/claude-3-5-sonnet-latest
model_type: llm
- metadata: {}
model_id: gemini/gemini-1.5-flash
provider_id: gemini
provider_model_id: gemini/gemini-1.5-flash
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.3-70B-Instruct
provider_id: groq
provider_model_id: groq/llama-3.3-70b-versatile
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.1-405B-Instruct
provider_id: together
provider_model_id: meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo
model_type: llm
shields:
- shield_id: meta-llama/Llama-Guard-3-8B
vector_dbs: []
datasets:
- purpose: eval/messages-answer
source:
type: uri
uri: huggingface://datasets/llamastack/simpleqa?split=train
metadata: {}
dataset_id: simpleqa
- purpose: eval/messages-answer
source:
type: uri
uri: huggingface://datasets/llamastack/mmlu_cot?split=test&name=all
metadata: {}
dataset_id: mmlu_cot
- purpose: eval/messages-answer
source:
type: uri
uri: huggingface://datasets/llamastack/gpqa_0shot_cot?split=test&name=gpqa_main
metadata: {}
dataset_id: gpqa_cot
- purpose: eval/messages-answer
source:
type: uri
uri: huggingface://datasets/llamastack/math_500?split=test
metadata: {}
dataset_id: math_500
- purpose: eval/messages-answer
source:
type: uri
uri: huggingface://datasets/llamastack/bfcl_v3?split=train
metadata: {}
dataset_id: bfcl
- purpose: eval/messages-answer
source:
type: uri
uri: huggingface://datasets/llamastack/IfEval?split=train
metadata: {}
dataset_id: ifeval
- purpose: eval/messages-answer
source:
type: uri
uri: huggingface://datasets/llamastack/docvqa?split=val
metadata: {}
dataset_id: docvqa
scoring_fns: []
benchmarks:
- dataset_id: simpleqa
scoring_functions:
- llm-as-judge::405b-simpleqa
metadata: {}
benchmark_id: meta-reference-simpleqa
- dataset_id: mmlu_cot
scoring_functions:
- basic::regex_parser_multiple_choice_answer
metadata: {}
benchmark_id: meta-reference-mmlu-cot
- dataset_id: gpqa_cot
scoring_functions:
- basic::regex_parser_multiple_choice_answer
metadata: {}
benchmark_id: meta-reference-gpqa-cot
- dataset_id: math_500
scoring_functions:
- basic::regex_parser_math_response
metadata: {}
benchmark_id: meta-reference-math-500
- dataset_id: bfcl
scoring_functions:
- basic::bfcl
metadata: {}
benchmark_id: meta-reference-bfcl
- dataset_id: ifeval
scoring_functions:
- basic::ifeval
metadata: {}
benchmark_id: meta-reference-ifeval
- dataset_id: docvqa
scoring_functions:
- basic::docvqa
metadata: {}
benchmark_id: meta-reference-docvqa
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
- toolgroup_id: builtin::code_interpreter
provider_id: code-interpreter
server:
port: 8321

View file

@ -48,6 +48,7 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/passthrough/trace_store.db}
datasetio:

View file

@ -43,6 +43,7 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/passthrough/trace_store.db}
datasetio:

View file

@ -67,6 +67,7 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/remote-vllm/trace_store.db}
tool_runtime:

View file

@ -60,6 +60,7 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/remote-vllm/trace_store.db}
tool_runtime:

View file

@ -51,6 +51,7 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/sambanova/trace_store.db}
tool_runtime:

View file

@ -43,6 +43,7 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/tgi/trace_store.db}
datasetio:

View file

@ -42,6 +42,7 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/tgi/trace_store.db}
datasetio:

View file

@ -48,6 +48,7 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/together/trace_store.db}
datasetio:

View file

@ -43,6 +43,7 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/together/trace_store.db}
datasetio:

View file

@ -47,6 +47,7 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/vllm-gpu/trace_store.db}
datasetio:

View file

@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "llama_stack"
version = "0.1.7"
version = "0.1.8"
authors = [{ name = "Meta Llama", email = "llama-oss@meta.com" }]
description = "Llama Stack"
readme = "README.md"
@ -27,7 +27,7 @@ dependencies = [
"huggingface-hub",
"jinja2>=3.1.6",
"jsonschema",
"llama-stack-client>=0.1.7",
"llama-stack-client>=0.1.8",
"prompt-toolkit",
"python-dotenv",
"pydantic>=2",
@ -56,7 +56,7 @@ dev = [
"ruamel.yaml", # needed for openapi generator
]
# These are the dependencies required for running unit tests.
unit = ["sqlite-vec", "openai", "aiosqlite", "pypdf", "chardet", "qdrant-client"]
unit = ["sqlite-vec", "openai", "aiosqlite", "aiohttp", "pypdf", "chardet", "qdrant-client"]
# These are the core dependencies required for running integration tests. They are shared across all
# providers. If a provider requires additional dependencies, please add them to your environment
# separately. If you are using "uv" to execute your tests, you can use the "--with" flag to specify extra
@ -64,6 +64,7 @@ unit = ["sqlite-vec", "openai", "aiosqlite", "pypdf", "chardet", "qdrant-client"
test = [
"openai",
"aiosqlite",
"aiohttp",
"torch>=2.6.0",
"torchvision>=0.21.0",
"opentelemetry-sdk",
@ -130,7 +131,6 @@ select = [
"F", # Pyflakes
"N", # Naming
"W", # Warnings
"I", # isort
"DTZ", # datetime rules
]
ignore = [
@ -152,6 +152,7 @@ ignore = [
[tool.mypy]
mypy_path = ["llama_stack"]
packages = ["llama_stack"]
plugins = ['pydantic.mypy']
disable_error_code = []
warn_return_any = true
# # honor excludes by not following there through imports
@ -262,6 +263,7 @@ exclude = [
"^llama_stack/providers/remote/tool_runtime/model_context_protocol/",
"^llama_stack/providers/remote/tool_runtime/tavily_search/",
"^llama_stack/providers/remote/tool_runtime/wolfram_alpha/",
"^llama_stack/providers/remote/post_training/nvidia/",
"^llama_stack/providers/remote/vector_io/chroma/",
"^llama_stack/providers/remote/vector_io/milvus/",
"^llama_stack/providers/remote/vector_io/pgvector/",
@ -303,3 +305,8 @@ exclude = [
# packages that lack typing annotations, do not have stubs, or are unavailable.
module = ["yaml", "fire"]
ignore_missing_imports = true
[tool.pydantic-mypy]
init_forbid_extra = true
init_typed = true
warn_required_dynamic_aliases = true

View file

@ -21,7 +21,7 @@ idna==3.10
jinja2==3.1.6
jsonschema==4.23.0
jsonschema-specifications==2024.10.1
llama-stack-client==0.1.7
llama-stack-client==0.1.8
lxml==5.3.1
markdown-it-py==3.0.0
markupsafe==3.0.2

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,60 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
POST_TRAINING_PROVIDER_TYPES = ["remote::nvidia"]
@pytest.mark.integration
@pytest.fixture(scope="session")
def post_training_provider_available(llama_stack_client):
providers = llama_stack_client.providers.list()
post_training_providers = [p for p in providers if p.provider_type in POST_TRAINING_PROVIDER_TYPES]
return len(post_training_providers) > 0
@pytest.mark.integration
def test_post_training_provider_registration(llama_stack_client, post_training_provider_available):
"""Check if post_training is in the api list.
This is a sanity check to ensure the provider is registered."""
if not post_training_provider_available:
pytest.skip("post training provider not available")
providers = llama_stack_client.providers.list()
post_training_providers = [p for p in providers if p.provider_type in POST_TRAINING_PROVIDER_TYPES]
assert len(post_training_providers) > 0
@pytest.mark.integration
def test_get_training_jobs(llama_stack_client, post_training_provider_available):
"""Test listing all training jobs."""
if not post_training_provider_available:
pytest.skip("post training provider not available")
jobs = llama_stack_client.post_training.get_training_jobs()
assert isinstance(jobs, dict)
assert "data" in jobs
assert isinstance(jobs["data"], list)
@pytest.mark.integration
def test_get_training_job_status(llama_stack_client, post_training_provider_available):
"""Test getting status of a specific training job."""
if not post_training_provider_available:
pytest.skip("post training provider not available")
jobs = llama_stack_client.post_training.get_training_jobs()
if not jobs["data"]:
pytest.skip("No training jobs available to check status")
job_uuid = jobs["data"][0]["job_uuid"]
job_status = llama_stack_client.post_training.get_training_job_status(job_uuid=job_uuid)
assert job_status is not None
assert "job_uuid" in job_status
assert "status" in job_status
assert job_status["job_uuid"] == job_uuid

View file

@ -99,6 +99,33 @@ def test_text_completion_streaming(client_with_models, text_model_id, test_case)
assert len(content_str) > 10
@pytest.mark.parametrize(
"test_case",
[
"inference:completion:stop_sequence",
],
)
def test_text_completion_stop_sequence(client_with_models, text_model_id, inference_provider_type, test_case):
skip_if_model_doesnt_support_completion(client_with_models, text_model_id)
# This is only supported/tested for remote vLLM: https://github.com/meta-llama/llama-stack/issues/1771
if inference_provider_type != "remote::vllm":
pytest.xfail(f"{inference_provider_type} doesn't support 'stop' parameter yet")
tc = TestCase(test_case)
response = client_with_models.inference.completion(
content=tc["content"],
stream=True,
model_id=text_model_id,
sampling_params={
"max_tokens": 50,
"stop": ["1963"],
},
)
streamed_content = [chunk.delta for chunk in response]
content_str = "".join(streamed_content).lower().strip()
assert "1963" not in content_str
@pytest.mark.parametrize(
"test_case",
[

View file

@ -10,6 +10,11 @@
"expected": "1963"
}
},
"stop_sequence": {
"data": {
"content": "Return the exact same sentence and don't add additional words): Michael Jordan was born in the year of 1963"
}
},
"streaming": {
"data": {
"content": "Roses are red,"

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,45 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
mock_session = MagicMock()
mock_session.closed = False
mock_session.close = AsyncMock()
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
mock_session.__aexit__ = AsyncMock()
@pytest.fixture(scope="session", autouse=True)
def patch_aiohttp_session():
with patch("aiohttp.ClientSession", return_value=mock_session):
yield
@pytest.fixture
def event_loop():
"""Create and provide a new event loop for each test."""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
yield loop
loop.close()
@pytest.fixture
def run_async():
"""Fixture to run async functions in tests."""
def _run_async(coro):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
return loop.run_until_complete(coro)
finally:
loop.close()
return _run_async

View file

@ -0,0 +1,271 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import unittest
import warnings
from unittest.mock import patch
import pytest
from llama_stack_client.types.algorithm_config_param import LoraFinetuningConfig
from llama_stack_client.types.post_training_supervised_fine_tune_params import (
TrainingConfig,
TrainingConfigDataConfig,
TrainingConfigOptimizerConfig,
TrainingConfigEfficiencyConfig,
)
from llama_stack.providers.remote.post_training.nvidia.post_training import (
NvidiaPostTrainingAdapter,
NvidiaPostTrainingConfig,
)
class TestNvidiaParameters(unittest.TestCase):
def setUp(self):
os.environ["NVIDIA_BASE_URL"] = "http://nemo.test"
os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test"
config = NvidiaPostTrainingConfig(
base_url=os.environ["NVIDIA_BASE_URL"], customizer_url=os.environ["NVIDIA_CUSTOMIZER_URL"], api_key=None
)
self.adapter = NvidiaPostTrainingAdapter(config)
self.make_request_patcher = patch(
"llama_stack.providers.remote.post_training.nvidia.post_training.NvidiaPostTrainingAdapter._make_request"
)
self.mock_make_request = self.make_request_patcher.start()
self.mock_make_request.return_value = {
"id": "job-123",
"status": "created",
"created_at": "2025-03-04T13:07:47.543605",
"updated_at": "2025-03-04T13:07:47.543605",
}
def tearDown(self):
self.make_request_patcher.stop()
def _assert_request_params(self, expected_json):
"""Helper method to verify parameters in the request JSON."""
call_args = self.mock_make_request.call_args
actual_json = call_args[1]["json"]
for key, value in expected_json.items():
if isinstance(value, dict):
for nested_key, nested_value in value.items():
assert actual_json[key][nested_key] == nested_value
else:
assert actual_json[key] == value
@pytest.fixture(autouse=True)
def inject_fixtures(self, run_async):
self.run_async = run_async
def test_customizer_parameters_passed(self):
"""Test scenario 1: When an optional parameter is passed and value is correctly set."""
custom_adapter_dim = 32 # Different from default of 8
algorithm_config = LoraFinetuningConfig(
type="LoRA",
adapter_dim=custom_adapter_dim,
adapter_dropout=0.2,
apply_lora_to_mlp=True,
apply_lora_to_output=True,
alpha=16,
rank=16,
lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)
data_config = TrainingConfigDataConfig(dataset_id="test-dataset", batch_size=16)
optimizer_config = TrainingConfigOptimizerConfig(lr=0.0002)
training_config = TrainingConfig(
n_epochs=3,
data_config=data_config,
optimizer_config=optimizer_config,
)
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
self.run_async(
self.adapter.supervised_fine_tune(
job_uuid="test-job",
model="meta-llama/Llama-3.1-8B-Instruct",
checkpoint_dir="",
algorithm_config=algorithm_config,
training_config=training_config,
logger_config={},
hyperparam_search_config={},
)
)
warning_texts = [str(warning.message) for warning in w]
fields = [
"apply_lora_to_output",
"lora_attn_modules",
"apply_lora_to_mlp",
]
for field in fields:
assert any(field in text for text in warning_texts)
self._assert_request_params(
{
"hyperparameters": {
"lora": {"adapter_dim": custom_adapter_dim, "adapter_dropout": 0.2, "alpha": 16},
"epochs": 3,
"learning_rate": 0.0002,
"batch_size": 16,
}
}
)
def test_required_parameters_passed(self):
"""Test scenario 2: When required parameters are passed."""
required_model = "meta-llama/Llama-3.1-8B-Instruct"
required_dataset_id = "required-dataset"
required_job_uuid = "required-job"
algorithm_config = LoraFinetuningConfig(
type="LoRA",
adapter_dim=16,
adapter_dropout=0.1,
apply_lora_to_mlp=True,
apply_lora_to_output=True,
alpha=16,
rank=16,
lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)
data_config = TrainingConfigDataConfig(
dataset_id=required_dataset_id, # Required parameter
batch_size=8,
)
optimizer_config = TrainingConfigOptimizerConfig(lr=0.0001)
training_config = TrainingConfig(
n_epochs=1,
data_config=data_config,
optimizer_config=optimizer_config,
)
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
self.run_async(
self.adapter.supervised_fine_tune(
job_uuid=required_job_uuid, # Required parameter
model=required_model, # Required parameter
checkpoint_dir="",
algorithm_config=algorithm_config,
training_config=training_config,
logger_config={},
hyperparam_search_config={},
)
)
warning_texts = [str(warning.message) for warning in w]
fields = [
"rank",
"apply_lora_to_output",
"lora_attn_modules",
"apply_lora_to_mlp",
]
for field in fields:
assert any(field in text for text in warning_texts)
self.mock_make_request.assert_called_once()
call_args = self.mock_make_request.call_args
assert call_args[1]["json"]["config"] == "meta/llama-3.1-8b-instruct"
assert call_args[1]["json"]["dataset"]["name"] == required_dataset_id
def test_unsupported_parameters_warning(self):
"""Test that warnings are raised for unsupported parameters."""
data_config = TrainingConfigDataConfig(
dataset_id="test-dataset",
batch_size=8,
# Unsupported parameters
shuffle=True,
data_format="instruct",
validation_dataset_id="val-dataset",
)
optimizer_config = TrainingConfigOptimizerConfig(
lr=0.0001,
weight_decay=0.01,
# Unsupported parameters
optimizer_type="adam",
num_warmup_steps=100,
)
efficiency_config = TrainingConfigEfficiencyConfig(
enable_activation_checkpointing=True # Unsupported parameter
)
training_config = TrainingConfig(
n_epochs=1,
data_config=data_config,
optimizer_config=optimizer_config,
# Unsupported parameters
efficiency_config=efficiency_config,
max_steps_per_epoch=1000,
gradient_accumulation_steps=4,
max_validation_steps=100,
dtype="bf16",
)
# Capture warnings
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
self.run_async(
self.adapter.supervised_fine_tune(
job_uuid="test-job",
model="meta-llama/Llama-3.1-8B-Instruct",
checkpoint_dir="test-dir", # Unsupported parameter
algorithm_config=LoraFinetuningConfig(
type="LoRA",
adapter_dim=16,
adapter_dropout=0.1,
apply_lora_to_mlp=True,
apply_lora_to_output=True,
alpha=16,
rank=16,
lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
),
training_config=training_config,
logger_config={"test": "value"}, # Unsupported parameter
hyperparam_search_config={"test": "value"}, # Unsupported parameter
)
)
assert len(w) >= 4
warning_texts = [str(warning.message) for warning in w]
fields = [
"checkpoint_dir",
"hyperparam_search_config",
"logger_config",
"TrainingConfig",
"DataConfig",
"OptimizerConfig",
"max_steps_per_epoch",
"gradient_accumulation_steps",
"max_validation_steps",
"dtype",
# required unsupported parameters
"rank",
"apply_lora_to_output",
"lora_attn_modules",
"apply_lora_to_mlp",
]
for field in fields:
assert any(field in text for text in warning_texts)
if __name__ == "__main__":
unittest.main()

View file

@ -0,0 +1,295 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import unittest
from unittest.mock import patch
import warnings
import pytest
from llama_stack_client.types.algorithm_config_param import LoraFinetuningConfig, QatFinetuningConfig
from llama_stack_client.types.post_training_supervised_fine_tune_params import (
TrainingConfig,
TrainingConfigDataConfig,
TrainingConfigOptimizerConfig,
)
from llama_stack.providers.remote.post_training.nvidia.post_training import (
NvidiaPostTrainingAdapter,
NvidiaPostTrainingConfig,
NvidiaPostTrainingJobStatusResponse,
ListNvidiaPostTrainingJobs,
NvidiaPostTrainingJob,
)
class TestNvidiaPostTraining(unittest.TestCase):
def setUp(self):
os.environ["NVIDIA_BASE_URL"] = "http://nemo.test" # needed for llm inference
os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test" # needed for nemo customizer
config = NvidiaPostTrainingConfig(
base_url=os.environ["NVIDIA_BASE_URL"], customizer_url=os.environ["NVIDIA_CUSTOMIZER_URL"], api_key=None
)
self.adapter = NvidiaPostTrainingAdapter(config)
self.make_request_patcher = patch(
"llama_stack.providers.remote.post_training.nvidia.post_training.NvidiaPostTrainingAdapter._make_request"
)
self.mock_make_request = self.make_request_patcher.start()
def tearDown(self):
self.make_request_patcher.stop()
@pytest.fixture(autouse=True)
def inject_fixtures(self, run_async):
self.run_async = run_async
def _assert_request(self, mock_call, expected_method, expected_path, expected_params=None, expected_json=None):
"""Helper method to verify request details in mock calls."""
call_args = mock_call.call_args
if expected_method and expected_path:
if isinstance(call_args[0], tuple) and len(call_args[0]) == 2:
assert call_args[0] == (expected_method, expected_path)
else:
assert call_args[1]["method"] == expected_method
assert call_args[1]["path"] == expected_path
if expected_params:
assert call_args[1]["params"] == expected_params
if expected_json:
for key, value in expected_json.items():
assert call_args[1]["json"][key] == value
def test_supervised_fine_tune(self):
"""Test the supervised fine-tuning API call."""
self.mock_make_request.return_value = {
"id": "cust-JGTaMbJMdqjJU8WbQdN9Q2",
"created_at": "2024-12-09T04:06:28.542884",
"updated_at": "2024-12-09T04:06:28.542884",
"config": {
"schema_version": "1.0",
"id": "af783f5b-d985-4e5b-bbb7-f9eec39cc0b1",
"created_at": "2024-12-09T04:06:28.542657",
"updated_at": "2024-12-09T04:06:28.569837",
"custom_fields": {},
"name": "meta-llama/Llama-3.1-8B-Instruct",
"base_model": "meta-llama/Llama-3.1-8B-Instruct",
"model_path": "llama-3_1-8b-instruct",
"training_types": [],
"finetuning_types": ["lora"],
"precision": "bf16",
"num_gpus": 4,
"num_nodes": 1,
"micro_batch_size": 1,
"tensor_parallel_size": 1,
"max_seq_length": 4096,
},
"dataset": {
"schema_version": "1.0",
"id": "dataset-XU4pvGzr5tvawnbVxeJMTb",
"created_at": "2024-12-09T04:06:28.542657",
"updated_at": "2024-12-09T04:06:28.542660",
"custom_fields": {},
"name": "sample-basic-test",
"version_id": "main",
"version_tags": [],
},
"hyperparameters": {
"finetuning_type": "lora",
"training_type": "sft",
"batch_size": 16,
"epochs": 2,
"learning_rate": 0.0001,
"lora": {"adapter_dim": 16, "adapter_dropout": 0.1},
},
"output_model": "default/job-1234",
"status": "created",
"project": "default",
"custom_fields": {},
"ownership": {"created_by": "me", "access_policies": {}},
}
algorithm_config = LoraFinetuningConfig(
type="LoRA",
adapter_dim=16,
adapter_dropout=0.1,
apply_lora_to_mlp=True,
apply_lora_to_output=True,
alpha=16,
rank=16,
lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)
data_config = TrainingConfigDataConfig(dataset_id="sample-basic-test", batch_size=16)
optimizer_config = TrainingConfigOptimizerConfig(
lr=0.0001,
)
training_config = TrainingConfig(
n_epochs=2,
data_config=data_config,
optimizer_config=optimizer_config,
)
with warnings.catch_warnings(record=True):
warnings.simplefilter("always")
training_job = self.run_async(
self.adapter.supervised_fine_tune(
job_uuid="1234",
model="meta-llama/Llama-3.1-8B-Instruct",
checkpoint_dir="",
algorithm_config=algorithm_config,
training_config=training_config,
logger_config={},
hyperparam_search_config={},
)
)
# check the output is a PostTrainingJob
assert isinstance(training_job, NvidiaPostTrainingJob)
assert training_job.job_uuid == "cust-JGTaMbJMdqjJU8WbQdN9Q2"
self.mock_make_request.assert_called_once()
self._assert_request(
self.mock_make_request,
"POST",
"/v1/customization/jobs",
expected_json={
"config": "meta/llama-3.1-8b-instruct",
"dataset": {"name": "sample-basic-test", "namespace": "default"},
"hyperparameters": {
"training_type": "sft",
"finetuning_type": "lora",
"epochs": 2,
"batch_size": 16,
"learning_rate": 0.0001,
"lora": {"alpha": 16, "adapter_dim": 16, "adapter_dropout": 0.1},
},
},
)
def test_supervised_fine_tune_with_qat(self):
algorithm_config = QatFinetuningConfig(type="QAT", quantizer_name="quantizer_name", group_size=1)
data_config = TrainingConfigDataConfig(dataset_id="sample-basic-test", batch_size=16)
optimizer_config = TrainingConfigOptimizerConfig(
lr=0.0001,
)
training_config = TrainingConfig(
n_epochs=2,
data_config=data_config,
optimizer_config=optimizer_config,
)
# This will raise NotImplementedError since QAT is not supported
with self.assertRaises(NotImplementedError):
self.run_async(
self.adapter.supervised_fine_tune(
job_uuid="1234",
model="meta-llama/Llama-3.1-8B-Instruct",
checkpoint_dir="",
algorithm_config=algorithm_config,
training_config=training_config,
logger_config={},
hyperparam_search_config={},
)
)
def test_get_training_job_status(self):
self.mock_make_request.return_value = {
"created_at": "2024-12-09T04:06:28.580220",
"updated_at": "2024-12-09T04:21:19.852832",
"status": "completed",
"steps_completed": 1210,
"epochs_completed": 2,
"percentage_done": 100.0,
"best_epoch": 2,
"train_loss": 1.718016266822815,
"val_loss": 1.8661999702453613,
}
job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
status = self.run_async(self.adapter.get_training_job_status(job_uuid=job_id))
assert isinstance(status, NvidiaPostTrainingJobStatusResponse)
assert status.status.value == "completed"
assert status.steps_completed == 1210
assert status.epochs_completed == 2
assert status.percentage_done == 100.0
assert status.best_epoch == 2
assert status.train_loss == 1.718016266822815
assert status.val_loss == 1.8661999702453613
self.mock_make_request.assert_called_once()
self._assert_request(
self.mock_make_request, "GET", f"/v1/customization/jobs/{job_id}/status", expected_params={"job_id": job_id}
)
def test_get_training_jobs(self):
job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
self.mock_make_request.return_value = {
"data": [
{
"id": job_id,
"created_at": "2024-12-09T04:06:28.542884",
"updated_at": "2024-12-09T04:21:19.852832",
"config": {
"name": "meta-llama/Llama-3.1-8B-Instruct",
"base_model": "meta-llama/Llama-3.1-8B-Instruct",
},
"dataset": {"name": "default/sample-basic-test"},
"hyperparameters": {
"finetuning_type": "lora",
"training_type": "sft",
"batch_size": 16,
"epochs": 2,
"learning_rate": 0.0001,
"lora": {"adapter_dim": 16, "adapter_dropout": 0.1},
},
"output_model": "default/job-1234",
"status": "completed",
"project": "default",
}
]
}
jobs = self.run_async(self.adapter.get_training_jobs())
assert isinstance(jobs, ListNvidiaPostTrainingJobs)
assert len(jobs.data) == 1
job = jobs.data[0]
assert job.job_uuid == job_id
assert job.status.value == "completed"
self.mock_make_request.assert_called_once()
self._assert_request(
self.mock_make_request,
"GET",
"/v1/customization/jobs",
expected_params={"page": 1, "page_size": 10, "sort": "created_at"},
)
def test_cancel_training_job(self):
self.mock_make_request.return_value = {} # Empty response for successful cancellation
job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
result = self.run_async(self.adapter.cancel_training_job(job_uuid=job_id))
assert result is None
self.mock_make_request.assert_called_once()
self._assert_request(
self.mock_make_request,
"POST",
f"/v1/customization/jobs/{job_id}/cancel",
expected_params={"job_id": job_id},
)
if __name__ == "__main__":
unittest.main()

View file

@ -14,17 +14,10 @@ from llama_stack.distribution.utils.dynamic import instantiate_class_type
class TestProviderConfigurations:
"""Test suite for testing provider configurations across all API types."""
def test_all_api_providers_exist(self):
provider_registry = get_provider_registry()
for api in providable_apis():
providers = provider_registry.get(api, {})
assert providers, f"No providers found for API type: {api}"
@pytest.mark.parametrize("api", providable_apis())
def test_api_providers(self, api):
provider_registry = get_provider_registry()
providers = provider_registry.get(api, {})
assert providers, f"No providers found for API type: {api}"
failures = []
for provider_type, provider_spec in providers.items():

View file

@ -5,17 +5,16 @@
# the root directory of this source tree.
import asyncio
import sqlite3
import numpy as np
import pytest
import pytest_asyncio
import sqlite_vec
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import (
SQLiteVecIndex,
SQLiteVecVectorIOAdapter,
_create_sqlite_connection,
generate_chunk_id,
)
@ -36,29 +35,25 @@ def loop():
return asyncio.new_event_loop()
@pytest.fixture(scope="session", autouse=True)
def sqlite_connection(loop):
conn = sqlite3.connect(":memory:")
try:
conn.enable_load_extension(True)
sqlite_vec.load(conn)
yield conn
finally:
conn.close()
@pytest_asyncio.fixture(scope="session", autouse=True)
async def sqlite_vec_index(sqlite_connection, embedding_dimension):
return await SQLiteVecIndex.create(dimension=embedding_dimension, connection=sqlite_connection, bank_id="test_bank")
async def sqlite_vec_index(embedding_dimension, tmp_path_factory):
temp_dir = tmp_path_factory.getbasetemp()
db_path = str(temp_dir / "test_sqlite.db")
index = await SQLiteVecIndex.create(dimension=embedding_dimension, db_path=db_path, bank_id="test_bank")
yield index
await index.delete()
@pytest.mark.asyncio
async def test_add_chunks(sqlite_vec_index, sample_chunks, sample_embeddings):
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings, batch_size=2)
cur = sqlite_vec_index.connection.cursor()
connection = _create_sqlite_connection(sqlite_vec_index.db_path)
cur = connection.cursor()
cur.execute(f"SELECT COUNT(*) FROM {sqlite_vec_index.metadata_table}")
count = cur.fetchone()[0]
assert count == len(sample_chunks)
cur.close()
connection.close()
@pytest.mark.asyncio
@ -79,13 +74,14 @@ async def test_chunk_id_conflict(sqlite_vec_index, sample_chunks, embedding_dime
sample_embeddings = np.random.rand(len(sample_chunks), embedding_dimension).astype(np.float32)
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings, batch_size=batch_size)
cur = sqlite_vec_index.connection.cursor()
connection = _create_sqlite_connection(sqlite_vec_index.db_path)
cur = connection.cursor()
# Retrieve all chunk IDs to check for duplicates
cur.execute(f"SELECT id FROM {sqlite_vec_index.metadata_table}")
chunk_ids = [row[0] for row in cur.fetchall()]
cur.close()
connection.close()
# Ensure all chunk IDs are unique
assert len(chunk_ids) == len(set(chunk_ids)), "Duplicate chunk IDs detected across batches!"

16
uv.lock generated
View file

@ -1,4 +1,5 @@
version = 1
revision = 1
requires-python = ">=3.10"
resolution-markers = [
"(python_full_version < '3.11' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version < '3.11' and sys_platform != 'darwin' and sys_platform != 'linux')",
@ -1315,7 +1316,7 @@ wheels = [
[[package]]
name = "llama-stack"
version = "0.1.7"
version = "0.1.8"
source = { editable = "." }
dependencies = [
{ name = "blobfile" },
@ -1370,6 +1371,7 @@ docs = [
{ name = "tomli" },
]
test = [
{ name = "aiohttp" },
{ name = "aiosqlite" },
{ name = "autoevals" },
{ name = "chardet" },
@ -1385,6 +1387,7 @@ test = [
{ name = "torchvision", version = "0.21.0+cpu", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
]
unit = [
{ name = "aiohttp" },
{ name = "aiosqlite" },
{ name = "chardet" },
{ name = "openai" },
@ -1395,6 +1398,8 @@ unit = [
[package.metadata]
requires-dist = [
{ name = "aiohttp", marker = "extra == 'test'" },
{ name = "aiohttp", marker = "extra == 'unit'" },
{ name = "aiosqlite", marker = "extra == 'test'" },
{ name = "aiosqlite", marker = "extra == 'unit'" },
{ name = "autoevals", marker = "extra == 'test'" },
@ -1410,7 +1415,7 @@ requires-dist = [
{ name = "jinja2", specifier = ">=3.1.6" },
{ name = "jinja2", marker = "extra == 'codegen'", specifier = ">=3.1.6" },
{ name = "jsonschema" },
{ name = "llama-stack-client", specifier = ">=0.1.7" },
{ name = "llama-stack-client", specifier = ">=0.1.8" },
{ name = "mcp", marker = "extra == 'test'" },
{ name = "myst-parser", marker = "extra == 'docs'" },
{ name = "nbval", marker = "extra == 'dev'" },
@ -1455,10 +1460,11 @@ requires-dist = [
{ name = "types-setuptools", marker = "extra == 'dev'" },
{ name = "uvicorn", marker = "extra == 'dev'" },
]
provides-extras = ["dev", "unit", "test", "docs", "codegen"]
[[package]]
name = "llama-stack-client"
version = "0.1.7"
version = "0.1.8"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "anyio" },
@ -1475,9 +1481,9 @@ dependencies = [
{ name = "tqdm" },
{ name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/12/82/cd3ac4cddfeb2c63dc25b2469ded66e39101a01586bb2535bfae4293dc49/llama_stack_client-0.1.7.tar.gz", hash = "sha256:8ed51d73a62848a72e1c862d8881e64785be1a9118bcea94aa910025571b34d0", size = 242071 }
sdist = { url = "https://files.pythonhosted.org/packages/c5/cb/b058f62f52c7a29cc7238866cac357550bfc8ff99afcd1805436ef906181/llama_stack_client-0.1.8.tar.gz", hash = "sha256:2f1ff22c891c4c03e9b8fa18adbc0fb4d6e7337181d90a897760ec368ef3269c", size = 242834 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/dc/ee/055282258899f25f61d248a335e6bcdceb76e3f6ed76ad3a28240ebfa0fe/llama_stack_client-0.1.7-py3-none-any.whl", hash = "sha256:db89a21811310f9249a951a25f96298affa3ca1d3610a518093c94e984105292", size = 273882 },
{ url = "https://files.pythonhosted.org/packages/a3/aa/d02c314b93be7ce0e447a7d7e3cbbea1955e06651abdf6a273714058e4aa/llama_stack_client-0.1.8-py3-none-any.whl", hash = "sha256:e34e42ff491b3146130f6cf741fd5e901f8cd66302e8608e514655f74cc3df73", size = 274304 },
]
[[package]]