Merge branch 'main' into eval_api_final

Xi Yan 2025-03-17 17:00:30 -07:00
commit 66cd83fb58
37 changed files with 1215 additions and 840 deletions

View file

@@ -1,9 +1,18 @@
 name: Integration tests

 on:
-  pull_request:
   push:
-    branches: [main]
+    branches: [ main ]
+  pull_request:
+    branches: [ main ]
+    paths:
+      - 'distributions/**'
+      - 'llama_stack/**'
+      - 'tests/integration/**'
+      - 'uv.lock'
+      - 'pyproject.toml'
+      - 'requirements.txt'
+      - '.github/workflows/integration-tests.yml' # This workflow

 jobs:
   ollama:
@@ -56,8 +65,7 @@ jobs:
           INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct"
         run: |
           source .venv/bin/activate
-          # TODO: use "llama stack run"
-          nohup uv run python -m llama_stack.distribution.server.server --yaml-config ./llama_stack/templates/ollama/run.yaml > server.log 2>&1 &
+          nohup uv run llama stack run ./llama_stack/templates/ollama/run.yaml --image-type venv > server.log 2>&1 &

       - name: Wait for Llama Stack server to be ready
         run: |
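The "wait for the server" step that follows simply polls until the freshly launched stack answers. As a rough illustration of what such a readiness check does — a minimal Python sketch, assuming the server exposes its health endpoint at `/v1/health` on the default port 8321; the timeout value is an arbitrary choice:

```python
import time

import httpx


def wait_for_server(base_url: str = "http://localhost:8321", timeout: float = 120.0) -> None:
    # Poll the health endpoint until the server launched by `llama stack run`
    # starts answering, or give up after `timeout` seconds.
    deadline = time.time() + timeout
    while time.time() < deadline:
        try:
            if httpx.get(f"{base_url}/v1/health").status_code == 200:
                return
        except httpx.TransportError:
            pass  # server not accepting connections yet
        time.sleep(1)
    raise TimeoutError(f"Llama Stack server at {base_url} did not become ready")
```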

View file

@@ -40,6 +40,7 @@ jobs:
       matrix:
         template: ${{ fromJson(needs.generate-matrix.outputs.templates) }}
         image-type: [venv, container]
+      fail-fast: false # We want to run all jobs even if some fail

     steps:
       - name: Checkout repository
@@ -67,7 +68,9 @@ jobs:
       - name: Run Llama Stack Build
         run: |
-          uv run llama stack build --template ${{ matrix.template }} --image-type ${{ matrix.image-type }} --image-name test
+          # USE_COPY_NOT_MOUNT is set to true since mounting is not supported by docker buildx, we use COPY instead
+          # LLAMA_STACK_DIR is set to the current directory so we are building from the source
+          USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. uv run llama stack build --template ${{ matrix.template }} --image-type ${{ matrix.image-type }} --image-name test

       - name: Print dependencies in the image
         if: matrix.image-type == 'venv'

View file

@@ -5,6 +5,14 @@ on:
     branches: [ main ]
   pull_request:
     branches: [ main ]
+    paths:
+      - 'distributions/**'
+      - 'llama_stack/**'
+      - 'tests/unit/**'
+      - 'uv.lock'
+      - 'pyproject.toml'
+      - 'requirements.txt'
+      - '.github/workflows/unit-tests.yml' # This workflow
   workflow_dispatch:

 jobs:

View file

@@ -77,7 +77,7 @@ repos:
         name: Distribution Template Codegen
         additional_dependencies:
           - uv==0.6.0
-        entry: uv run --extra codegen python -m llama_stack.scripts.distro_codegen
+        entry: uv run --extra codegen ./scripts/distro_codegen.py
         language: python
         pass_filenames: false
         require_serial: true

View file

@@ -159,7 +159,7 @@ LLAMA_STACK_DIR=$(pwd) LLAMA_STACK_CLIENT_DIR=../llama-stack-client-python llama

 ### Updating Provider Configurations

-If you have made changes to a provider's configuration in any form (introducing a new config key, or changing models, etc.), you should run `python llama_stack/scripts/distro_codegen.py` to re-generate various YAML files as well as the documentation. You should not change `docs/source/.../distributions/` files manually as they are auto-generated.
+If you have made changes to a provider's configuration in any form (introducing a new config key, or changing models, etc.), you should run `./scripts/distro_codegen.py` to re-generate various YAML files as well as the documentation. You should not change `docs/source/.../distributions/` files manually as they are auto-generated.

 ### Building the Documentation

View file

@@ -401,16 +401,13 @@
     ],
     "nvidia": [
       "aiosqlite",
-      "autoevals",
       "blobfile",
       "chardet",
-      "datasets",
       "faiss-cpu",
       "fastapi",
       "fire",
       "httpx",
       "matplotlib",
-      "mcp",
       "nltk",
       "numpy",
       "openai",

View file

@@ -2233,6 +2233,67 @@
       }
     },
     "/v1/datasetio/iterrows/{dataset_id}": {
+      "get": {
+        "responses": {
+          "200": {
+            "description": "OK",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/IterrowsResponse"
+                }
+              }
+            }
+          },
+          "400": {
+            "$ref": "#/components/responses/BadRequest400"
+          },
+          "429": {
+            "$ref": "#/components/responses/TooManyRequests429"
+          },
+          "500": {
+            "$ref": "#/components/responses/InternalServerError500"
+          },
+          "default": {
+            "$ref": "#/components/responses/DefaultError"
+          }
+        },
+        "tags": [
+          "DatasetIO"
+        ],
+        "description": "Get a paginated list of rows from a dataset. Uses cursor-based pagination.",
+        "parameters": [
+          {
+            "name": "dataset_id",
+            "in": "path",
+            "description": "The ID of the dataset to get the rows from.",
+            "required": true,
+            "schema": {
+              "type": "string"
+            }
+          },
+          {
+            "name": "start_index",
+            "in": "query",
+            "description": "Index into dataset for the first row to get. Get all rows if None.",
+            "required": false,
+            "schema": {
+              "type": "integer"
+            }
+          },
+          {
+            "name": "limit",
+            "in": "query",
+            "description": "The number of rows to get.",
+            "required": false,
+            "schema": {
+              "type": "integer"
+            }
+          }
+        ]
+      }
+    },
+    "/v1/eval/benchmarks/{benchmark_id}/jobs/{job_id}": {
       "get": {
         "responses": {
           "200": {
@@ -6552,100 +6613,14 @@
           "const": "factuality",
           "default": "factuality"
         },
-        "factuality": {
-          "type": "object",
-          "properties": {
-            "aggregation_functions": {
-              "type": "array",
-              "items": {
-                "type": "string",
-                "enum": [
-                  "average",
-                  "median",
-                  "categorical_count",
-                  "accuracy"
-                ],
-                "title": "AggregationFunctionType",
-                "description": "A type of aggregation function."
-              }
-            }
-          },
-          "additionalProperties": false,
-          "required": [
-            "aggregation_functions"
-          ],
-          "title": "BasicGraderParams"
-        }
-      },
-      "additionalProperties": false,
-      "required": [
-        "type",
-        "factuality"
-      ],
-      "title": "FactualityGrader"
-    },
-    "FaithfulnessGrader": {
-      "type": "object",
-      "properties": {
-        "type": {
-          "type": "string",
-          "const": "faithfulness",
-          "default": "faithfulness"
-        },
-        "faithfulness": {
-          "type": "object",
-          "properties": {
-            "aggregation_functions": {
-              "type": "array",
-              "items": {
-                "type": "string",
-                "enum": [
-                  "average",
-                  "median",
-                  "categorical_count",
-                  "accuracy"
-                ],
-                "title": "AggregationFunctionType",
-                "description": "A type of aggregation function."
-              }
-            }
-          },
-          "additionalProperties": false,
-          "required": [
-            "aggregation_functions"
-          ],
-          "title": "BasicGraderParams"
-        }
-      },
-      "additionalProperties": false,
-      "required": [
-        "type",
-        "faithfulness"
-      ],
-      "title": "FaithfulnessGrader"
-    },
-    "Grader": {
-      "type": "object",
-      "properties": {
-        "identifier": {
+        "dataset_id": {
           "type": "string"
         },
-        "provider_resource_id": {
-          "type": "string"
-        },
-        "provider_id": {
-          "type": "string"
-        },
-        "type": {
-          "type": "string",
-          "const": "grader",
-          "default": "grader"
-        },
-        "grader": {
-          "$ref": "#/components/schemas/GraderDefinition"
-        },
-        "description": {
-          "type": "string"
+        "scoring_functions": {
+          "type": "array",
+          "items": {
+            "type": "string"
+          }
         },
         "metadata": {
           "type": "object",
@@ -6679,98 +6654,163 @@
         "provider_resource_id",
         "provider_id",
         "type",
-        "grader",
+        "dataset_id",
+        "scoring_functions",
         "metadata"
       ],
-      "title": "Grader"
+      "title": "Benchmark"
     },
-    "GraderDefinition": {
+    "DataSource": {
       "oneOf": [
         {
-          "$ref": "#/components/schemas/LlmGrader"
+          "$ref": "#/components/schemas/URIDataSource"
         },
         {
-          "$ref": "#/components/schemas/RegexParserGrader"
-        },
-        {
-          "$ref": "#/components/schemas/EqualityGrader"
-        },
-        {
-          "$ref": "#/components/schemas/SubsetOfGrader"
-        },
-        {
-          "$ref": "#/components/schemas/FactualityGrader"
-        },
-        {
-          "$ref": "#/components/schemas/FaithfulnessGrader"
+          "$ref": "#/components/schemas/RowsDataSource"
         }
       ],
       "discriminator": {
         "propertyName": "type",
         "mapping": {
-          "llm": "#/components/schemas/LlmGrader",
-          "regex_parser": "#/components/schemas/RegexParserGrader",
-          "equality": "#/components/schemas/EqualityGrader",
-          "subset_of": "#/components/schemas/SubsetOfGrader",
-          "factuality": "#/components/schemas/FactualityGrader",
-          "faithfulness": "#/components/schemas/FaithfulnessGrader"
+          "uri": "#/components/schemas/URIDataSource",
+          "rows": "#/components/schemas/RowsDataSource"
         }
       }
     },
-    "LlmGrader": {
+    "Dataset": {
+      "type": "object",
+      "properties": {
+        "identifier": {
+          "type": "string"
+        },
+        "provider_resource_id": {
+          "type": "string"
+        },
+        "provider_id": {
+          "type": "string"
+        },
+        "type": {
+          "type": "string",
+          "const": "dataset",
+          "default": "dataset"
+        },
+        "purpose": {
+          "type": "string",
+          "enum": [
+            "post-training/messages",
+            "eval/question-answer",
+            "eval/messages-answer"
+          ],
+          "title": "DatasetPurpose",
+          "description": "Purpose of the dataset. Each purpose has a required input data schema."
+        },
+        "source": {
+          "$ref": "#/components/schemas/DataSource"
+        },
+        "metadata": {
+          "type": "object",
+          "additionalProperties": {
+            "oneOf": [
+              {
+                "type": "null"
+              },
+              {
+                "type": "boolean"
+              },
+              {
+                "type": "number"
+              },
+              {
+                "type": "string"
+              },
+              {
+                "type": "array"
+              },
+              {
+                "type": "object"
+              }
+            ]
+          }
+        }
+      },
+      "additionalProperties": false,
+      "required": [
+        "identifier",
+        "provider_resource_id",
+        "provider_id",
+        "type",
+        "purpose",
+        "source",
+        "metadata"
+      ],
+      "title": "Dataset"
+    },
+    "RowsDataSource": {
       "type": "object",
       "properties": {
         "type": {
           "type": "string",
-          "const": "llm",
-          "default": "llm"
+          "const": "rows",
+          "default": "rows"
         },
-        "llm": {
-          "type": "object",
-          "properties": {
-            "model": {
-              "type": "string"
-            },
-            "prompt": {
-              "type": "string"
-            },
-            "score_regexes": {
-              "type": "array",
-              "items": {
-                "type": "string"
-              }
-            },
-            "aggregation_functions": {
-              "type": "array",
-              "items": {
-                "type": "string",
-                "enum": [
-                  "average",
-                  "median",
-                  "categorical_count",
-                  "accuracy"
-                ],
-                "title": "AggregationFunctionType",
-                "description": "A type of aggregation function."
-              }
-            }
-          },
-          "additionalProperties": false,
-          "required": [
-            "model",
-            "prompt",
-            "score_regexes",
-            "aggregation_functions"
-          ],
-          "title": "LlmGraderParams"
+        "rows": {
+          "type": "array",
+          "items": {
+            "type": "object",
+            "additionalProperties": {
+              "oneOf": [
+                {
+                  "type": "null"
+                },
+                {
+                  "type": "boolean"
+                },
+                {
+                  "type": "number"
+                },
+                {
+                  "type": "string"
+                },
+                {
+                  "type": "array"
+                },
+                {
+                  "type": "object"
+                }
+              ]
+            }
+          },
+          "description": "The dataset is stored in rows. E.g. - [ {\"messages\": [{\"role\": \"user\", \"content\": \"Hello, world!\"}, {\"role\": \"assistant\", \"content\": \"Hello, world!\"}]} ]"
         }
       },
       "additionalProperties": false,
       "required": [
         "type",
-        "llm"
+        "rows"
       ],
-      "title": "LlmGrader"
+      "title": "RowsDataSource",
+      "description": "A dataset stored in rows."
+    },
+    "URIDataSource": {
+      "type": "object",
+      "properties": {
+        "type": {
+          "type": "string",
+          "const": "uri",
+          "default": "uri"
+        },
+        "uri": {
+          "type": "string",
+          "description": "The dataset can be obtained from a URI. E.g. - \"https://mywebsite.com/mydata.jsonl\" - \"lsfs://mydata.jsonl\" - \"data:csv;base64,{base64_content}\""
+        }
+      },
+      "additionalProperties": false,
+      "required": [
+        "type",
+        "uri"
+      ],
+      "title": "URIDataSource",
+      "description": "A dataset that can be obtained from a URI."
     },
     "RegexParserGrader": {
       "type": "object",
@@ -6819,45 +6859,182 @@
       ],
       "title": "RegexParserGrader"
     },
-    "SubsetOfGrader": {
-      "type": "object",
-      "properties": {
-        "type": {
-          "type": "string",
-          "const": "subset_of",
-          "default": "subset_of"
-        },
-        "subset_of": {
-          "type": "object",
-          "properties": {
-            "aggregation_functions": {
-              "type": "array",
-              "items": {
-                "type": "string",
-                "enum": [
-                  "average",
-                  "median",
-                  "categorical_count",
-                  "accuracy"
-                ],
-                "title": "AggregationFunctionType",
-                "description": "A type of aggregation function."
-              }
-            }
-          },
-          "additionalProperties": false,
-          "required": [
-            "aggregation_functions"
-          ],
-          "title": "BasicGraderParams"
-        }
-      },
-      "additionalProperties": false,
-      "required": [
-        "type",
-        "subset_of"
-      ],
-      "title": "SubsetOfGrader"
+    "ModelType": {
+      "type": "string",
+      "enum": [
+        "llm",
+        "embedding"
+      ],
+      "title": "ModelType"
+    },
+    "AgentTurnInputType": {
+      "type": "object",
+      "properties": {
+        "type": {
+          "type": "string",
+          "const": "agent_turn_input",
+          "default": "agent_turn_input"
+        }
+      },
+      "additionalProperties": false,
+      "required": [
+        "type"
+      ],
+      "title": "AgentTurnInputType"
+    },
+    "ArrayType": {
+      "type": "object",
+      "properties": {
+        "type": {
+          "type": "string",
+          "const": "array",
+          "default": "array"
+        }
+      },
+      "additionalProperties": false,
+      "required": [
+        "type"
+      ],
+      "title": "ArrayType"
+    },
+    "BooleanType": {
+      "type": "object",
+      "properties": {
+        "type": {
+          "type": "string",
+          "const": "boolean",
+          "default": "boolean"
+        }
+      },
+      "additionalProperties": false,
+      "required": [
+        "type"
+      ],
+      "title": "BooleanType"
+    },
+    "ChatCompletionInputType": {
+      "type": "object",
+      "properties": {
+        "type": {
+          "type": "string",
+          "const": "chat_completion_input",
+          "default": "chat_completion_input"
+        }
+      },
+      "additionalProperties": false,
+      "required": [
+        "type"
+      ],
+      "title": "ChatCompletionInputType"
+    },
+    "CompletionInputType": {
+      "type": "object",
+      "properties": {
+        "type": {
+          "type": "string",
+          "const": "completion_input",
+          "default": "completion_input"
+        }
+      },
+      "additionalProperties": false,
+      "required": [
+        "type"
+      ],
+      "title": "CompletionInputType"
+    },
+    "JsonType": {
+      "type": "object",
+      "properties": {
+        "type": {
+          "type": "string",
+          "const": "json",
+          "default": "json"
+        }
+      },
+      "additionalProperties": false,
+      "required": [
+        "type"
+      ],
+      "title": "JsonType"
+    },
+    "NumberType": {
+      "type": "object",
+      "properties": {
+        "type": {
+          "type": "string",
+          "const": "number",
+          "default": "number"
+        }
+      },
+      "additionalProperties": false,
+      "required": [
+        "type"
+      ],
+      "title": "NumberType"
+    },
+    "ObjectType": {
+      "type": "object",
+      "properties": {
+        "type": {
+          "type": "string",
+          "const": "object",
+          "default": "object"
+        }
+      },
+      "additionalProperties": false,
+      "required": [
+        "type"
+      ],
+      "title": "ObjectType"
+    },
+    "ParamType": {
+      "oneOf": [
+        {
+          "$ref": "#/components/schemas/StringType"
+        },
+        {
+          "$ref": "#/components/schemas/NumberType"
+        },
+        {
+          "$ref": "#/components/schemas/BooleanType"
+        },
+        {
+          "$ref": "#/components/schemas/ArrayType"
+        },
+        {
+          "$ref": "#/components/schemas/ObjectType"
+        },
+        {
+          "$ref": "#/components/schemas/JsonType"
+        },
+        {
+          "$ref": "#/components/schemas/UnionType"
+        },
+        {
+          "$ref": "#/components/schemas/ChatCompletionInputType"
+        },
+        {
+          "$ref": "#/components/schemas/CompletionInputType"
+        },
+        {
+          "$ref": "#/components/schemas/AgentTurnInputType"
+        }
+      ],
+      "discriminator": {
+        "propertyName": "type",
+        "mapping": {
+          "string": "#/components/schemas/StringType",
+          "number": "#/components/schemas/NumberType",
+          "boolean": "#/components/schemas/BooleanType",
+          "array": "#/components/schemas/ArrayType",
+          "object": "#/components/schemas/ObjectType",
+          "json": "#/components/schemas/JsonType",
+          "union": "#/components/schemas/UnionType",
+          "chat_completion_input": "#/components/schemas/ChatCompletionInputType",
+          "completion_input": "#/components/schemas/CompletionInputType",
+          "agent_turn_input": "#/components/schemas/AgentTurnInputType"
+        }
+      }
     },
     "Model": {
       "type": "object",
@@ -6913,17 +7090,39 @@
         "provider_id",
         "type",
         "metadata",
-        "model_type"
+        "return_type"
       ],
-      "title": "Model"
+      "title": "ScoringFn"
     },
-    "ModelType": {
-      "type": "string",
-      "enum": [
-        "llm",
-        "embedding"
+    "StringType": {
+      "type": "object",
+      "properties": {
+        "type": {
+          "type": "string",
+          "const": "string",
+          "default": "string"
+        }
+      },
+      "additionalProperties": false,
+      "required": [
+        "type"
       ],
-      "title": "ModelType"
+      "title": "StringType"
+    },
+    "UnionType": {
+      "type": "object",
+      "properties": {
+        "type": {
+          "type": "string",
+          "const": "union",
+          "default": "union"
+        }
+      },
+      "additionalProperties": false,
+      "required": [
+        "type"
+      ],
+      "title": "UnionType"
     },
     "Shield": {
       "type": "object",
@@ -8131,7 +8330,7 @@
         },
         "description": "The rows in the current page."
       },
-      "next_index": {
+      "next_start_index": {
         "type": "integer",
         "description": "Index into dataset for the first row in the next page. None if there are no more rows."
       }
@@ -9440,7 +9639,7 @@
         },
         "source": {
           "$ref": "#/components/schemas/DataSource",
-          "description": "The data source of the dataset. Examples: - { \"type\": \"uri\", \"uri\": \"https://mywebsite.com/mydata.jsonl\" } - { \"type\": \"uri\", \"uri\": \"lsfs://mydata.jsonl\" } - { \"type\": \"uri\", \"uri\": \"data:csv;base64,{base64_content}\" } - { \"type\": \"uri\", \"uri\": \"huggingface://llamastack/simpleqa?split=train\" } - { \"type\": \"rows\", \"rows\": [ { \"messages\": [ {\"role\": \"user\", \"content\": \"Hello, world!\"}, {\"role\": \"assistant\", \"content\": \"Hello, world!\"}, ] } ] }"
+          "description": "The data source of the dataset. Ensure that the data source schema is compatible with the purpose of the dataset. Examples: - { \"type\": \"uri\", \"uri\": \"https://mywebsite.com/mydata.jsonl\" } - { \"type\": \"uri\", \"uri\": \"lsfs://mydata.jsonl\" } - { \"type\": \"uri\", \"uri\": \"data:csv;base64,{base64_content}\" } - { \"type\": \"uri\", \"uri\": \"huggingface://llamastack/simpleqa?split=train\" } - { \"type\": \"rows\", \"rows\": [ { \"messages\": [ {\"role\": \"user\", \"content\": \"Hello, world!\"}, {\"role\": \"assistant\", \"content\": \"Hello, world!\"}, ] } ] }"
         },
         "metadata": {
           "type": "object",
@@ -9478,50 +9677,6 @@
         "purpose",
         "source"
       ],
       "title": "RegisterDatasetRequest"
     },
-    "RegisterGraderRequest": {
-      "type": "object",
-      "properties": {
-        "grader": {
-          "$ref": "#/components/schemas/GraderDefinition",
-          "description": "The grader definition, E.g. - { \"type\": \"llm\", \"llm\": { \"model\": \"llama-405b\", \"prompt\": \"You are a judge. Score the answer based on the question. {question} {answer}\", } }"
-        },
-        "grader_id": {
-          "type": "string",
-          "description": "(Optional) The ID of the grader. If not provided, a random ID will be generated."
-        },
-        "metadata": {
-          "type": "object",
-          "additionalProperties": {
-            "oneOf": [
-              {
-                "type": "null"
-              },
-              {
-                "type": "boolean"
-              },
-              {
-                "type": "number"
-              },
-              {
-                "type": "string"
-              },
-              {
-                "type": "array"
-              },
-              {
-                "type": "object"
-              }
-            ]
-          },
-          "description": "(Optional) Any additional metadata for this grader. - E.g. { \"description\": \"A grader that scores the answer based on the question.\", }"
-        }
-      },
-      "additionalProperties": false,
-      "required": [
-        "grader"
-      ],
-      "title": "RegisterGraderRequest"
-    },
     "RegisterModelRequest": {
@@ -10199,9 +10354,6 @@
     {
       "name": "Files"
     },
-    {
-      "name": "Graders"
-    },
     {
       "name": "Inference",
      "description": "This API provides the raw interface to the underlying models. Two kinds of models are supported:\n- LLM models: these models generate \"raw\" and \"chat\" (conversational) completions.\n- Embedding models: these models generate embeddings to be used for semantic search.",
@@ -10254,9 +10406,8 @@
         "Benchmarks",
         "DatasetIO",
         "Datasets",
-        "Evaluation",
+        "Eval",
         "Files",
-        "Graders",
         "Inference",
         "Inspect",
         "Models",

View file

@@ -1507,6 +1507,50 @@ paths:
             $ref: '#/components/schemas/InvokeToolRequest'
         required: true
   /v1/datasetio/iterrows/{dataset_id}:
+    get:
+      responses:
+        '200':
+          description: OK
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/IterrowsResponse'
+        '400':
+          $ref: '#/components/responses/BadRequest400'
+        '429':
+          $ref: >-
+            #/components/responses/TooManyRequests429
+        '500':
+          $ref: >-
+            #/components/responses/InternalServerError500
+        default:
+          $ref: '#/components/responses/DefaultError'
+      tags:
+        - DatasetIO
+      description: >-
+        Get a paginated list of rows from a dataset. Uses cursor-based pagination.
+      parameters:
+        - name: dataset_id
+          in: path
+          description: >-
+            The ID of the dataset to get the rows from.
+          required: true
+          schema:
+            type: string
+        - name: start_index
+          in: query
+          description: >-
+            Index into dataset for the first row to get. Get all rows if None.
+          required: false
+          schema:
+            type: integer
+        - name: limit
+          in: query
+          description: The number of rows to get.
+          required: false
+          schema:
+            type: integer
+  /v1/eval/benchmarks/{benchmark_id}/jobs/{job_id}:
     get:
       responses:
         '200':
@@ -4527,255 +4571,6 @@ components:
       title: URIDataSource
       description: >-
         A dataset that can be obtained from a URI.
-    EqualityGrader:
-      type: object
-      properties:
-        type:
-          type: string
-          const: equality
-          default: equality
-        equality:
-          type: object
-          properties:
-            aggregation_functions:
-              type: array
-              items:
-                type: string
-                enum:
-                  - average
-                  - median
-                  - categorical_count
-                  - accuracy
-                title: AggregationFunctionType
-                description: A type of aggregation function.
-          additionalProperties: false
-          required:
-            - aggregation_functions
-          title: BasicGraderParams
-      additionalProperties: false
-      required:
-        - type
-        - equality
-      title: EqualityGrader
-    FactualityGrader:
-      type: object
-      properties:
-        type:
-          type: string
-          const: factuality
-          default: factuality
-        factuality:
-          type: object
-          properties:
-            aggregation_functions:
-              type: array
-              items:
-                type: string
-                enum:
-                  - average
-                  - median
-                  - categorical_count
-                  - accuracy
-                title: AggregationFunctionType
-                description: A type of aggregation function.
-          additionalProperties: false
-          required:
-            - aggregation_functions
-          title: BasicGraderParams
-      additionalProperties: false
-      required:
-        - type
-        - factuality
-      title: FactualityGrader
-    FaithfulnessGrader:
-      type: object
-      properties:
-        type:
-          type: string
-          const: faithfulness
-          default: faithfulness
-        faithfulness:
-          type: object
-          properties:
-            aggregation_functions:
-              type: array
-              items:
-                type: string
-                enum:
-                  - average
-                  - median
-                  - categorical_count
-                  - accuracy
-                title: AggregationFunctionType
-                description: A type of aggregation function.
-          additionalProperties: false
-          required:
-            - aggregation_functions
-          title: BasicGraderParams
-      additionalProperties: false
-      required:
-        - type
-        - faithfulness
-      title: FaithfulnessGrader
-    Grader:
-      type: object
-      properties:
-        identifier:
-          type: string
-        provider_resource_id:
-          type: string
-        provider_id:
-          type: string
-        type:
-          type: string
-          const: grader
-          default: grader
-        grader:
-          $ref: '#/components/schemas/GraderDefinition'
-        description:
-          type: string
-        metadata:
-          type: object
-          additionalProperties:
-            oneOf:
-              - type: 'null'
-              - type: boolean
-              - type: number
-              - type: string
-              - type: array
-              - type: object
-      additionalProperties: false
-      required:
-        - identifier
-        - provider_resource_id
-        - provider_id
-        - type
-        - grader
-        - metadata
-      title: Grader
-    GraderDefinition:
-      oneOf:
-        - $ref: '#/components/schemas/LlmGrader'
-        - $ref: '#/components/schemas/RegexParserGrader'
-        - $ref: '#/components/schemas/EqualityGrader'
-        - $ref: '#/components/schemas/SubsetOfGrader'
-        - $ref: '#/components/schemas/FactualityGrader'
-        - $ref: '#/components/schemas/FaithfulnessGrader'
-      discriminator:
-        propertyName: type
-        mapping:
-          llm: '#/components/schemas/LlmGrader'
-          regex_parser: '#/components/schemas/RegexParserGrader'
-          equality: '#/components/schemas/EqualityGrader'
-          subset_of: '#/components/schemas/SubsetOfGrader'
-          factuality: '#/components/schemas/FactualityGrader'
-          faithfulness: '#/components/schemas/FaithfulnessGrader'
-    LlmGrader:
-      type: object
-      properties:
-        type:
-          type: string
-          const: llm
-          default: llm
-        llm:
-          type: object
-          properties:
-            model:
-              type: string
-            prompt:
-              type: string
-            score_regexes:
-              type: array
-              items:
-                type: string
-            aggregation_functions:
-              type: array
-              items:
-                type: string
-                enum:
-                  - average
-                  - median
-                  - categorical_count
-                  - accuracy
-                title: AggregationFunctionType
-                description: A type of aggregation function.
-          additionalProperties: false
-          required:
-            - model
-            - prompt
-            - score_regexes
-            - aggregation_functions
-          title: LlmGraderParams
-      additionalProperties: false
-      required:
-        - type
-        - llm
-      title: LlmGrader
-    RegexParserGrader:
-      type: object
-      properties:
-        type:
-          type: string
-          const: regex_parser
-          default: regex_parser
-        regex_parser:
-          type: object
-          properties:
-            parsing_regexes:
-              type: array
-              items:
-                type: string
-            aggregation_functions:
-              type: array
-              items:
-                type: string
-                enum:
-                  - average
-                  - median
-                  - categorical_count
-                  - accuracy
-                title: AggregationFunctionType
-                description: A type of aggregation function.
-          additionalProperties: false
-          required:
-            - parsing_regexes
-            - aggregation_functions
-          title: RegexParserGraderParams
-      additionalProperties: false
-      required:
-        - type
-        - regex_parser
-      title: RegexParserGrader
-    SubsetOfGrader:
-      type: object
-      properties:
-        type:
-          type: string
-          const: subset_of
-          default: subset_of
-        subset_of:
-          type: object
-          properties:
-            aggregation_functions:
-              type: array
-              items:
-                type: string
-                enum:
-                  - average
-                  - median
-                  - categorical_count
-                  - accuracy
-                title: AggregationFunctionType
-                description: A type of aggregation function.
-          additionalProperties: false
-          required:
-            - aggregation_functions
-          title: BasicGraderParams
-      additionalProperties: false
-      required:
-        - type
-        - subset_of
-      title: SubsetOfGrader
     Model:
       type: object
       properties:
@@ -4817,6 +4612,224 @@ components:
         - llm
         - embedding
       title: ModelType
+    AgentTurnInputType:
+      type: object
+      properties:
+        type:
+          type: string
+          const: agent_turn_input
+          default: agent_turn_input
+      additionalProperties: false
+      required:
+        - type
+      title: AgentTurnInputType
+    ArrayType:
+      type: object
+      properties:
+        type:
+          type: string
+          const: array
+          default: array
+      additionalProperties: false
+      required:
+        - type
+      title: ArrayType
+    BooleanType:
+      type: object
+      properties:
+        type:
+          type: string
+          const: boolean
+          default: boolean
+      additionalProperties: false
+      required:
+        - type
+      title: BooleanType
+    ChatCompletionInputType:
+      type: object
+      properties:
+        type:
+          type: string
+          const: chat_completion_input
+          default: chat_completion_input
+      additionalProperties: false
+      required:
+        - type
+      title: ChatCompletionInputType
+    CompletionInputType:
+      type: object
+      properties:
+        type:
+          type: string
+          const: completion_input
+          default: completion_input
+      additionalProperties: false
+      required:
+        - type
+      title: CompletionInputType
+    JsonType:
+      type: object
+      properties:
+        type:
+          type: string
+          const: json
+          default: json
+      additionalProperties: false
+      required:
+        - type
+      title: JsonType
+    RowsDataSource:
+      type: object
+      properties:
+        type:
+          type: string
+          const: rows
+          default: rows
+        rows:
+          type: array
+          items:
+            type: object
+            additionalProperties:
+              oneOf:
+                - type: 'null'
+                - type: boolean
+                - type: number
+                - type: string
+                - type: array
+                - type: object
+          description: >-
+            The dataset is stored in rows. E.g. - [ {"messages": [{"role": "user",
+            "content": "Hello, world!"}, {"role": "assistant", "content": "Hello,
+            world!"}]} ]
+      additionalProperties: false
+      required:
+        - type
+        - rows
+      title: RowsDataSource
+      description: A dataset stored in rows.
+    URIDataSource:
+      type: object
+      properties:
+        type:
+          type: string
+          const: uri
+          default: uri
+        uri:
+          type: string
+          description: >-
+            The dataset can be obtained from a URI. E.g. - "https://mywebsite.com/mydata.jsonl"
+            - "lsfs://mydata.jsonl" - "data:csv;base64,{base64_content}"
+      additionalProperties: false
+      required:
+        - type
+        - uri
+      title: URIDataSource
+      description: >-
+        A dataset that can be obtained from a URI.
+    NumberType:
+      type: object
+      properties:
+        type:
+          type: string
+          const: number
+          default: number
+      additionalProperties: false
+      required:
+        - type
+      title: NumberType
+    ObjectType:
+      type: object
+      properties:
+        type:
+          type: string
+          const: object
+          default: object
+      additionalProperties: false
+      required:
+        - type
+      title: ObjectType
+    ParamType:
+      oneOf:
+        - $ref: '#/components/schemas/StringType'
+        - $ref: '#/components/schemas/NumberType'
+        - $ref: '#/components/schemas/BooleanType'
+        - $ref: '#/components/schemas/ArrayType'
+        - $ref: '#/components/schemas/ObjectType'
+        - $ref: '#/components/schemas/JsonType'
+        - $ref: '#/components/schemas/UnionType'
+        - $ref: '#/components/schemas/ChatCompletionInputType'
+        - $ref: '#/components/schemas/CompletionInputType'
+        - $ref: '#/components/schemas/AgentTurnInputType'
+      discriminator:
+        propertyName: type
+        mapping:
+          string: '#/components/schemas/StringType'
+          number: '#/components/schemas/NumberType'
+          boolean: '#/components/schemas/BooleanType'
+          array: '#/components/schemas/ArrayType'
+          object: '#/components/schemas/ObjectType'
+          json: '#/components/schemas/JsonType'
+          union: '#/components/schemas/UnionType'
+          chat_completion_input: '#/components/schemas/ChatCompletionInputType'
+          completion_input: '#/components/schemas/CompletionInputType'
+          agent_turn_input: '#/components/schemas/AgentTurnInputType'
+    ScoringFn:
+      type: object
+      properties:
+        identifier:
+          type: string
+        provider_resource_id:
+          type: string
+        provider_id:
+          type: string
+        type:
+          type: string
+          const: scoring_function
+          default: scoring_function
+        description:
+          type: string
+        metadata:
+          type: object
+          additionalProperties:
+            oneOf:
+              - type: 'null'
+              - type: boolean
+              - type: number
+              - type: string
+              - type: array
+              - type: object
+        return_type:
+          $ref: '#/components/schemas/ParamType'
+        params:
+          $ref: '#/components/schemas/ScoringFnParams'
+      additionalProperties: false
+      required:
+        - identifier
+        - provider_resource_id
+        - provider_id
+        - type
+        - metadata
+        - return_type
+      title: ScoringFn
+    StringType:
+      type: object
+      properties:
+        type:
+          type: string
+          const: string
+          default: string
+      additionalProperties: false
+      required:
+        - type
+      title: StringType
+    UnionType:
+      type: object
+      properties:
+        type:
+          type: string
+          const: union
+          default: union
+      additionalProperties: false
+      required:
+        - type
+      title: UnionType
     Shield:
       type: object
       properties:
@@ -5580,7 +5593,7 @@ components:
             - type: array
             - type: object
         description: The rows in the current page.
-      next_index:
+      next_start_index:
         type: integer
         description: >-
           Index into dataset for the first row in the next page. None if there are
@@ -6461,12 +6474,14 @@ components:
       source:
         $ref: '#/components/schemas/DataSource'
         description: >-
-          The data source of the dataset. Examples: - { "type": "uri", "uri": "https://mywebsite.com/mydata.jsonl"
-          } - { "type": "uri", "uri": "lsfs://mydata.jsonl" } - { "type": "uri",
-          "uri": "data:csv;base64,{base64_content}" } - { "type": "uri", "uri":
-          "huggingface://llamastack/simpleqa?split=train" } - { "type": "rows",
-          "rows": [ { "messages": [ {"role": "user", "content": "Hello, world!"},
-          {"role": "assistant", "content": "Hello, world!"}, ] } ] }
+          The data source of the dataset. Ensure that the data source schema is
+          compatible with the purpose of the dataset. Examples: - { "type": "uri",
+          "uri": "https://mywebsite.com/mydata.jsonl" } - { "type": "uri", "uri":
+          "lsfs://mydata.jsonl" } - { "type": "uri", "uri": "data:csv;base64,{base64_content}"
+          } - { "type": "uri", "uri": "huggingface://llamastack/simpleqa?split=train"
+          } - { "type": "rows", "rows": [ { "messages": [ {"role": "user", "content":
+          "Hello, world!"}, {"role": "assistant", "content": "Hello, world!"}, ]
+          } ] }
       metadata:
         type: object
         additionalProperties:
@@ -6488,37 +6503,6 @@ components:
         - purpose
         - source
       title: RegisterDatasetRequest
-    RegisterGraderRequest:
-      type: object
-      properties:
-        grader:
-          $ref: '#/components/schemas/GraderDefinition'
-          description: >-
-            The grader definition, E.g. - { "type": "llm", "llm": { "model": "llama-405b",
-            "prompt": "You are a judge. Score the answer based on the question. {question}
-            {answer}", } }
-        grader_id:
-          type: string
-          description: >-
-            (Optional) The ID of the grader. If not provided, a random ID will be
-            generated.
-        metadata:
-          type: object
-          additionalProperties:
-            oneOf:
-              - type: 'null'
-              - type: boolean
-              - type: number
-              - type: string
-              - type: array
-              - type: object
-          description: >-
-            (Optional) Any additional metadata for this grader. - E.g. { "description":
-            "A grader that scores the answer based on the question.", }
-      additionalProperties: false
-      required:
-        - grader
-      title: RegisterGraderRequest
     RegisterModelRequest:
       type: object
       properties:
@@ -6951,9 +6935,10 @@ tags:
   - name: Benchmarks
   - name: DatasetIO
   - name: Datasets
-  - name: Evaluation
+  - name: Eval
+    x-displayName: >-
+      Llama Stack Evaluation API for running evaluations on model and agent candidates.
   - name: Files
-  - name: Graders
   - name: Inference
     description: >-
       This API provides the raw interface to the underlying models. Two kinds of models
@@ -6988,9 +6973,8 @@ x-tagGroups:
       - Benchmarks
       - DatasetIO
      - Datasets
-      - Evaluation
+      - Eval
      - Files
-      - Graders
      - Inference
      - Inspect
      - Models

View file

@@ -6,7 +6,7 @@ This guide will walk you through the process of adding a new API provider to Lla
 - Begin by reviewing the [core concepts](../concepts/index.md) of Llama Stack and choose the API your provider belongs to (Inference, Safety, VectorIO, etc.)
 - Determine the provider type ({repopath}`Remote::llama_stack/providers/remote` or {repopath}`Inline::llama_stack/providers/inline`). Remote providers make requests to external services, while inline providers execute implementation locally.
 - Add your provider to the appropriate {repopath}`Registry::llama_stack/providers/registry/`. Specify pip dependencies necessary.
-- Update any distribution {repopath}`Templates::llama_stack/templates/` build.yaml and run.yaml files if they should include your provider by default. Run {repopath}`llama_stack/scripts/distro_codegen.py` if necessary. Note that `distro_codegen.py` will fail if the new provider causes any distribution template to attempt to import provider-specific dependencies. This usually means the distribution's `get_distribution_template()` code path should only import any necessary Config or model alias definitions from each provider and not the provider's actual implementation.
+- Update any distribution {repopath}`Templates::llama_stack/templates/` build.yaml and run.yaml files if they should include your provider by default. Run {repopath}`./scripts/distro_codegen.py` if necessary. Note that `distro_codegen.py` will fail if the new provider causes any distribution template to attempt to import provider-specific dependencies. This usually means the distribution's `get_distribution_template()` code path should only import any necessary Config or model alias definitions from each provider and not the provider's actual implementation.

 Here are some example PRs to help you get started:

View file

@@ -6,13 +6,13 @@ The `llamastack/distribution-nvidia` distribution consists of the following prov
 | API | Provider(s) |
 |-----|-------------|
 | agents | `inline::meta-reference` |
-| datasetio | `remote::huggingface`, `inline::localfs` |
+| datasetio | `inline::localfs` |
 | eval | `inline::meta-reference` |
 | inference | `remote::nvidia` |
-| safety | `inline::llama-guard` |
-| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
+| safety | `remote::nvidia` |
+| scoring | `inline::basic` |
 | telemetry | `inline::meta-reference` |
-| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol` |
+| tool_runtime | `inline::rag-runtime` |
 | vector_io | `inline::faiss` |
@@ -20,8 +20,10 @@ The `llamastack/distribution-nvidia` distribution consists of the following prov
 The following environment variables can be configured:

-- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
 - `NVIDIA_API_KEY`: NVIDIA API Key (default: ``)
+- `GUARDRAILS_SERVICE_URL`: URL for the NeMo Guardrails Service (default: `http://0.0.0.0:7331`)
+- `INFERENCE_MODEL`: Inference model (default: `Llama3.1-8B-Instruct`)
+- `SAFETY_MODEL`: Name of the model to use for safety (default: `meta/llama-3.1-8b-instruct`)

 ### Models
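Since safety is now routed through `remote::nvidia`, a shield backed by the NeMo Guardrails service can be exercised from the client. A minimal sketch, assuming the `llama-stack-client` Python package, a stack running on localhost:8321, and that `run_shield` carries this signature; the shield ID and message are illustrative:

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

# Assumed: a shield registered against the remote::nvidia safety provider,
# i.e. the SAFETY_MODEL configured above (meta/llama-3.1-8b-instruct).
response = client.safety.run_shield(
    shield_id="meta/llama-3.1-8b-instruct",  # illustrative ID
    messages=[{"role": "user", "content": "Tell me how to pick a lock."}],
    params={},
)
print(response.violation)
```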

View file

@@ -6,17 +6,32 @@ The `llama-stack-client` CLI allows you to query information about the distribut
 ### `llama-stack-client`
 ```bash
-llama-stack-client -h
+llama-stack-client
+Usage: llama-stack-client [OPTIONS] COMMAND [ARGS]...

-usage: llama-stack-client [-h] {models,memory_banks,shields} ...
+  Welcome to the LlamaStackClient CLI

-Welcome to the LlamaStackClient CLI
+Options:
+  --version        Show the version and exit.
+  --endpoint TEXT  Llama Stack distribution endpoint
+  --api-key TEXT   Llama Stack distribution API key
+  --config TEXT    Path to config file
+  --help           Show this message and exit.

-options:
-  -h, --help  show this help message and exit
-
-subcommands:
-  {models,memory_banks,shields}
+Commands:
+  configure          Configure Llama Stack Client CLI.
+  datasets           Manage datasets.
+  eval               Run evaluation tasks.
+  eval_tasks         Manage evaluation tasks.
+  inference          Inference (chat).
+  inspect            Inspect server configuration.
+  models             Manage GenAI models.
+  post_training      Post-training.
+  providers          Manage API providers.
+  scoring_functions  Manage scoring functions.
+  shields            Manage safety shield services.
+  toolgroups         Manage available tool groups.
+  vector_dbs         Manage vector databases.
 ```

 ### `llama-stack-client configure`
@@ -127,11 +142,11 @@ llama-stack-client vector_dbs list
 llama-stack-client vector_dbs register <vector-db-id> [--provider-id <provider-id>] [--provider-vector-db-id <provider-vector-db-id>] [--embedding-model <embedding-model>] [--embedding-dimension <embedding-dimension>]
 ```

-Options:
-- `--provider-id`: Optional. Provider ID for the vector db
-- `--provider-vector-db-id`: Optional. Provider's vector db ID
-- `--embedding-model`: Optional. Embedding model to use. Default: "all-MiniLM-L6-v2"
-- `--embedding-dimension`: Optional. Dimension of embeddings. Default: 384
+Optional arguments:
+- `--provider-id`: Provider ID for the vector db
+- `--provider-vector-db-id`: Provider's vector db ID
+- `--embedding-model`: Embedding model to use. Default: "all-MiniLM-L6-v2"
+- `--embedding-dimension`: Dimension of embeddings. Default: 384

 ### `llama-stack-client vector_dbs unregister`
 ```bash
@@ -157,11 +172,13 @@ llama-stack-client shields list
 llama-stack-client shields register --shield-id <shield-id> [--provider-id <provider-id>] [--provider-shield-id <provider-shield-id>] [--params <params>]
 ```

-Options:
-- `--shield-id`: Required. ID of the shield
-- `--provider-id`: Optional. Provider ID for the shield
-- `--provider-shield-id`: Optional. Provider's shield ID
-- `--params`: Optional. JSON configuration parameters for the shield
+Required arguments:
+- `--shield-id`: ID of the shield
+
+Optional arguments:
+- `--provider-id`: Provider ID for the shield
+- `--provider-shield-id`: Provider's shield ID
+- `--params`: JSON configuration parameters for the shield

 ## Eval Task Management
@@ -175,13 +192,15 @@ llama-stack-client benchmarks list
 llama-stack-client benchmarks register --eval-task-id <eval-task-id> --dataset-id <dataset-id> --scoring-functions <function1> [<function2> ...] [--provider-id <provider-id>] [--provider-eval-task-id <provider-eval-task-id>] [--metadata <metadata>]
 ```

-Options:
-- `--eval-task-id`: Required. ID of the eval task
-- `--dataset-id`: Required. ID of the dataset to evaluate
-- `--scoring-functions`: Required. One or more scoring functions to use for evaluation
-- `--provider-id`: Optional. Provider ID for the eval task
-- `--provider-eval-task-id`: Optional. Provider's eval task ID
-- `--metadata`: Optional. Metadata for the eval task in JSON format
+Required arguments:
+- `--eval-task-id`: ID of the eval task
+- `--dataset-id`: ID of the dataset to evaluate
+- `--scoring-functions`: One or more scoring functions to use for evaluation
+
+Optional arguments:
+- `--provider-id`: Provider ID for the eval task
+- `--provider-eval-task-id`: Provider's eval task ID
+- `--metadata`: Metadata for the eval task in JSON format

 ## Eval execution
 ### `llama-stack-client eval run-benchmark`
@@ -189,11 +208,13 @@ Options:
 llama-stack-client eval run-benchmark <eval-task-id1> [<eval-task-id2> ...] --eval-task-config <config-file> --output-dir <output-dir> [--num-examples <num>] [--visualize]
 ```

-Options:
-- `--eval-task-config`: Required. Path to the eval task config file in JSON format
-- `--output-dir`: Required. Path to the directory where evaluation results will be saved
-- `--num-examples`: Optional. Number of examples to evaluate (useful for debugging)
-- `--visualize`: Optional flag. If set, visualizes evaluation results after completion
+Required arguments:
+- `--eval-task-config`: Path to the eval task config file in JSON format
+- `--output-dir`: Path to the directory where evaluation results will be saved
+
+Optional arguments:
+- `--num-examples`: Number of examples to evaluate (useful for debugging)
+- `--visualize`: If set, visualizes evaluation results after completion

 Example benchmark_config.json:
 ```json
@@ -214,11 +235,13 @@ Example benchmark_config.json:
 llama-stack-client eval run-scoring <eval-task-id> --eval-task-config <config-file> --output-dir <output-dir> [--num-examples <num>] [--visualize]
 ```

-Options:
-- `--eval-task-config`: Required. Path to the eval task config file in JSON format
-- `--output-dir`: Required. Path to the directory where scoring results will be saved
-- `--num-examples`: Optional. Number of examples to evaluate (useful for debugging)
-- `--visualize`: Optional flag. If set, visualizes scoring results after completion
+Required arguments:
+- `--eval-task-config`: Path to the eval task config file in JSON format
+- `--output-dir`: Path to the directory where scoring results will be saved
+
+Optional arguments:
+- `--num-examples`: Number of examples to evaluate (useful for debugging)
+- `--visualize`: If set, visualizes scoring results after completion

 ## Tool Group Management
@@ -230,11 +253,11 @@ llama-stack-client toolgroups list
 +---------------------------+------------------+------+---------------+
 | identifier                | provider_id      | args | mcp_endpoint  |
 +===========================+==================+======+===============+
 | builtin::code_interpreter | code-interpreter | None | None          |
 +---------------------------+------------------+------+---------------+
 | builtin::rag              | rag-runtime      | None | None          |
 +---------------------------+------------------+------+---------------+
 | builtin::websearch        | tavily-search    | None | None          |
 +---------------------------+------------------+------+---------------+
 ```
@@ -250,11 +273,11 @@ Shows detailed information about a specific toolgroup. If the toolgroup is not f
 llama-stack-client toolgroups register <toolgroup_id> [--provider-id <provider-id>] [--provider-toolgroup-id <provider-toolgroup-id>] [--mcp-config <mcp-config>] [--args <args>]
 ```

-Options:
-- `--provider-id`: Optional. Provider ID for the toolgroup
-- `--provider-toolgroup-id`: Optional. Provider's toolgroup ID
-- `--mcp-config`: Optional. JSON configuration for the MCP endpoint
-- `--args`: Optional. JSON arguments for the toolgroup
+Optional arguments:
+- `--provider-id`: Provider ID for the toolgroup
+- `--provider-toolgroup-id`: Provider's toolgroup ID
+- `--mcp-config`: JSON configuration for the MCP endpoint
+- `--args`: JSON arguments for the toolgroup

 ### `llama-stack-client toolgroups unregister`
 ```bash

View file

@@ -18,11 +18,11 @@ class IterrowsResponse(BaseModel):
     A paginated list of rows from a dataset.

     :param data: The rows in the current page.
-    :param next_index: Index into dataset for the first row in the next page. None if there are no more rows.
+    :param next_start_index: Index into dataset for the first row in the next page. None if there are no more rows.
     """

     data: List[Dict[str, Any]]
-    next_index: Optional[int] = None
+    next_start_index: Optional[int] = None


 class DatasetStore(Protocol):
@@ -46,9 +46,11 @@ class DatasetIO(Protocol):
         :param dataset_id: The ID of the dataset to get the rows from.
         :param start_index: Index into dataset for the first row to get. Get all rows if None.
-        :param limit: The number of rows to get per page.
+        :param limit: The number of rows to get.
         """
         ...

     @webmethod(route="/datasetio/append-rows/{dataset_id:path}", method="POST")
-    async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: ...
+    async def append_rows(
+        self, dataset_id: str, rows: List[Dict[str, Any]]
+    ) -> None: ...
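The renamed `next_start_index` field is the cursor: a caller keeps invoking `iterrows` with the returned value until it comes back as `None`. A minimal sketch of that loop against any `DatasetIO` implementation — the helper name and page size are illustrative, not part of the API:

```python
from typing import Any, AsyncIterator, Dict


async def iter_all_rows(datasetio, dataset_id: str, page_size: int = 100) -> AsyncIterator[Dict[str, Any]]:
    """Drain a dataset page by page, following next_start_index until None."""
    start_index = 0
    while start_index is not None:
        # `datasetio` is anything implementing the DatasetIO protocol above.
        page = await datasetio.iterrows(dataset_id, start_index=start_index, limit=page_size)
        for row in page.data:
            yield row
        start_index = page.next_start_index
```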

View file

@@ -163,7 +163,7 @@ class Datasets(Protocol):
                 ],
                 "answer": "John Doe"
             }
-        :param source: The data source of the dataset. Examples:
+        :param source: The data source of the dataset. Ensure that the data source schema is compatible with the purpose of the dataset. Examples:
             - {
                 "type": "uri",
                 "uri": "https://mywebsite.com/mydata.jsonl"

View file

@@ -38,7 +38,7 @@ from llama_stack.distribution.distribution import get_provider_registry
 from llama_stack.distribution.resolver import InvalidProviderError
 from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
 from llama_stack.distribution.utils.dynamic import instantiate_class_type
-from llama_stack.distribution.utils.exec import formulate_run_args, run_with_pty
+from llama_stack.distribution.utils.exec import formulate_run_args, run_command
 from llama_stack.distribution.utils.image_types import LlamaStackImageType
 from llama_stack.providers.datatypes import Api
@@ -213,7 +213,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
         config = parse_and_maybe_upgrade_config(config_dict)
         run_args = formulate_run_args(args.image_type, args.image_name, config, args.template)
         run_args.extend([run_config, str(os.getenv("LLAMA_STACK_PORT", 8321))])
-        run_with_pty(run_args)
+        run_command(run_args)


 def _generate_run_config(
def _generate_run_config( def _generate_run_config(

View file

@@ -82,7 +82,7 @@ class StackRun(Subcommand):
         from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
         from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
-        from llama_stack.distribution.utils.exec import formulate_run_args, run_with_pty
+        from llama_stack.distribution.utils.exec import formulate_run_args, run_command

         config_file = Path(args.config)
         has_yaml_suffix = args.config.endswith(".yaml")
@@ -136,4 +136,4 @@ class StackRun(Subcommand):
         if args.tls_keyfile and args.tls_certfile:
             run_args.extend(["--tls-keyfile", args.tls_keyfile, "--tls-certfile", args.tls_certfile])

-        run_with_pty(run_args)
+        run_command(run_args)

View file

@@ -6,7 +6,6 @@

 import importlib.resources
 import logging
-import sys
 from pathlib import Path
 from typing import Dict, List
@@ -15,7 +14,7 @@ from termcolor import cprint

 from llama_stack.distribution.datatypes import BuildConfig, Provider
 from llama_stack.distribution.distribution import get_provider_registry
-from llama_stack.distribution.utils.exec import run_command, run_with_pty
+from llama_stack.distribution.utils.exec import run_command
 from llama_stack.distribution.utils.image_types import LlamaStackImageType
 from llama_stack.providers.datatypes import Api
@@ -123,11 +122,7 @@ def build_image(
     if special_deps:
         args.append("#".join(special_deps))

-    is_terminal = sys.stdin.isatty()
-    if is_terminal:
-        return_code = run_with_pty(args)
-    else:
-        return_code = run_command(args)
+    return_code = run_command(args)

     if return_code != 0:
         log.error(

View file

@@ -43,7 +43,7 @@ RED='\033[0;31m'
 NC='\033[0m' # No Color

 CONTAINER_BINARY=${CONTAINER_BINARY:-docker}
-CONTAINER_OPTS=${CONTAINER_OPTS:-}
+CONTAINER_OPTS=${CONTAINER_OPTS:---progress=plain}

 TEMP_DIR=$(mktemp -d)
@@ -253,8 +253,7 @@ $CONTAINER_BINARY build \
   "${CLI_ARGS[@]}" \
   -t "$image_tag" \
   -f "$TEMP_DIR/Containerfile" \
-  "." \
-  --progress=plain
+  "."

 # clean up tmp/configs
 set +x

View file

@@ -8,10 +8,13 @@

 from pydantic import BaseModel

 from llama_stack.apis.providers import ListProvidersResponse, ProviderInfo, Providers
+from llama_stack.log import get_logger

 from .datatypes import StackRunConfig
 from .stack import redact_sensitive_fields

+logger = get_logger(name=__name__, category="core")
+

 class ProviderImplConfig(BaseModel):
     run_config: StackRunConfig
@@ -31,6 +34,10 @@ class ProviderImpl(Providers):
     async def initialize(self) -> None:
         pass

+    async def shutdown(self) -> None:
+        logger.debug("ProviderImpl.shutdown")
+        pass
+
     async def list_providers(self) -> ListProvidersResponse:
         run_config = self.config.run_config
         safe_config = StackRunConfig(**redact_sensitive_fields(run_config.model_dump()))

View file

@ -4,13 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import errno
import logging import logging
import os import os
import select
import signal import signal
import subprocess import subprocess
import sys
from termcolor import cprint from termcolor import cprint
@ -88,13 +85,6 @@ def formulate_run_args(image_type, image_name, config, template_name) -> list:
return run_args return run_args
def run_with_pty(command):
if sys.platform.startswith("win"):
return _run_with_pty_win(command)
else:
return _run_with_pty_unix(command)
def in_notebook(): def in_notebook():
try: try:
from IPython import get_ipython from IPython import get_ipython
@ -108,19 +98,19 @@ def in_notebook():
return True return True
# run a command in a pseudo-terminal, with interrupt handling, def run_command(command: list[str]) -> int:
# useful when you want to run interactive things """
def _run_with_pty_unix(command): Run a command with interrupt handling and output capture.
import pty Uses subprocess.run with direct stream piping for better performance.
import termios
master, slave = pty.openpty() Args:
command (list): The command to run.
old_settings = termios.tcgetattr(sys.stdin) Returns:
int: The return code of the command.
"""
original_sigint = signal.getsignal(signal.SIGINT) original_sigint = signal.getsignal(signal.SIGINT)
ctrl_c_pressed = False ctrl_c_pressed = False
process = None
def sigint_handler(signum, frame): def sigint_handler(signum, frame):
nonlocal ctrl_c_pressed nonlocal ctrl_c_pressed
@ -131,106 +121,19 @@ def _run_with_pty_unix(command):
# Set up the signal handler # Set up the signal handler
signal.signal(signal.SIGINT, sigint_handler) signal.signal(signal.SIGINT, sigint_handler)
new_settings = termios.tcgetattr(sys.stdin) # Run the command with stdout/stderr piped directly to system streams
new_settings[3] = new_settings[3] & ~termios.ECHO # Disable echo result = subprocess.run(
new_settings[3] = new_settings[3] & ~termios.ICANON # Disable canonical mode
termios.tcsetattr(sys.stdin, termios.TCSADRAIN, new_settings)
process = subprocess.Popen(
command, command,
stdin=slave, text=True,
stdout=slave, check=False,
stderr=slave,
universal_newlines=True,
preexec_fn=os.setsid,
) )
return result.returncode
# Close the slave file descriptor as it's now owned by the subprocess except subprocess.SubprocessError as e:
os.close(slave) log.error(f"Subprocess error: {e}")
return 1
def handle_io():
while not ctrl_c_pressed:
try:
rlist, _, _ = select.select([sys.stdin, master], [], [], 0.1)
if sys.stdin in rlist:
data = os.read(sys.stdin.fileno(), 1024)
if not data:
break
os.write(master, data)
if master in rlist:
data = os.read(master, 1024)
if not data:
break
sys.stdout.buffer.write(data)
sys.stdout.flush()
except KeyboardInterrupt:
# This will be raised when Ctrl+C is pressed
break
if process.poll() is not None:
break
handle_io()
except (EOFError, KeyboardInterrupt):
pass
except OSError as e:
if e.errno != errno.EIO:
raise
finally:
# Clean up
termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old_settings)
signal.signal(signal.SIGINT, original_sigint)
os.close(master)
if process and process.poll() is None:
process.terminate()
process.wait()
return process.returncode
# run a command in a pseudo-terminal in windows, with interrupt handling,
def _run_with_pty_win(command):
"""
Runs a command with interactive support using subprocess directly.
"""
try:
# For shell scripts on Windows, use appropriate shell
if isinstance(command, (list, tuple)):
if command[0].endswith(".sh"):
if os.path.exists("/usr/bin/bash"): # WSL
command = ["bash"] + command
else:
# Use cmd.exe with bash while preserving all arguments
command = ["cmd.exe", "/c", "bash"] + command
process = subprocess.Popen(
command,
shell=True,
universal_newlines=True,
)
process.wait()
except Exception as e: except Exception as e:
print(f"Error: {str(e)}") log.exception(f"Unexpected error: {e}")
return 1 return 1
finally: finally:
if process and process.poll() is None: # Restore the original signal handler
process.terminate() signal.signal(signal.SIGINT, original_sigint)
process.wait()
return process.returncode
def run_command(command):
try:
result = subprocess.run(command, capture_output=True, text=True, check=True)
print("Script Output\n", result.stdout)
return result.returncode
except subprocess.CalledProcessError as e:
print("Error running script:", e)
print("Error output:", e.stderr)
return e.returncode
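The replacement above trades all of the pty plumbing for a plain subprocess.run that inherits the parent's stdio, wrapped in a save/restore of the SIGINT handler. A minimal, self-contained sketch of that pattern (the function name and print-based logging here are illustrative, not the module's):

import signal
import subprocess

def run_command_sketch(command: list[str]) -> int:
    original_sigint = signal.getsignal(signal.SIGINT)

    def sigint_handler(signum, frame):
        print("\nCtrl-C detected. Aborting...")

    try:
        signal.signal(signal.SIGINT, sigint_handler)
        # No capture_output: the child's stdout/stderr stream straight through.
        result = subprocess.run(command, text=True, check=False)
        return result.returncode
    except subprocess.SubprocessError as e:
        print(f"Subprocess error: {e}")
        return 1
    finally:
        # Always restore the original handler, even on error.
        signal.signal(signal.SIGINT, original_sigint)

print(run_command_sketch(["echo", "hello"]))

The trade-off: dropping the pty means no raw-mode keyboard handling, but unlike the old capture_output-then-print run_command below, output is no longer buffered until the child exits.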

View file

@@ -44,7 +44,9 @@ class PandasDataframeDataset:
elif self.dataset_def.source.type == "rows": elif self.dataset_def.source.type == "rows":
self.df = pandas.DataFrame(self.dataset_def.source.rows) self.df = pandas.DataFrame(self.dataset_def.source.rows)
else: else:
raise ValueError(f"Unsupported dataset source type: {self.dataset_def.source.type}") raise ValueError(
f"Unsupported dataset source type: {self.dataset_def.source.type}"
)
if self.df is None: if self.df is None:
raise ValueError(f"Failed to load dataset from {self.dataset_def.url}") raise ValueError(f"Failed to load dataset from {self.dataset_def.url}")
@@ -108,7 +110,7 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
return IterrowsResponse( return IterrowsResponse(
data=rows, data=rows,
next_index=end if end < len(dataset_impl) else None, next_start_index=end if end < len(dataset_impl) else None,
) )
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
@@ -117,4 +119,6 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
dataset_impl.load() dataset_impl.load()
new_rows_df = pandas.DataFrame(rows) new_rows_df = pandas.DataFrame(rows)
dataset_impl.df = pandas.concat([dataset_impl.df, new_rows_df], ignore_index=True) dataset_impl.df = pandas.concat(
[dataset_impl.df, new_rows_df], ignore_index=True
)
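The localfs backend (and, below, the Hugging Face one) now reports the continuation point as next_start_index, with None marking the last page. A hedged consumption sketch, assuming iterrows accepts start_index and limit keyword arguments (the signature is not shown in this diff):

async def read_all_rows(datasetio, dataset_id: str) -> list:
    rows = []
    start = 0
    while start is not None:
        page = await datasetio.iterrows(dataset_id, start_index=start, limit=100)
        rows.extend(page.data)
        start = page.next_start_index  # None once the final page is consumed
    return rows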

View file

@@ -55,4 +55,13 @@ def available_providers() -> List[ProviderSpec]:
config_class="llama_stack.providers.remote.safety.bedrock.BedrockSafetyConfig", config_class="llama_stack.providers.remote.safety.bedrock.BedrockSafetyConfig",
), ),
), ),
remote_provider_spec(
api=Api.safety,
adapter=AdapterSpec(
adapter_type="nvidia",
pip_packages=["requests"],
module="llama_stack.providers.remote.safety.nvidia",
config_class="llama_stack.providers.remote.safety.nvidia.NVIDIASafetyConfig",
),
),
] ]
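The new registry entry is resolved the same way as the Bedrock one above it: the stack imports the named module and hands a parsed config to its get_adapter_impl. A rough illustration of that contract (not the stack's actual resolver code; the deps argument is stubbed out):

import importlib

async def load_safety_adapter(module_path: str, config):
    # e.g. module_path = "llama_stack.providers.remote.safety.nvidia"
    module = importlib.import_module(module_path)
    return await module.get_adapter_impl(config, None)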

View file

@@ -86,7 +86,7 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
return IterrowsResponse( return IterrowsResponse(
data=rows, data=rows,
next_index=end if end < len(loaded_dataset) else None, next_start_index=end if end < len(loaded_dataset) else None,
) )
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
@@ -98,9 +98,13 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
new_dataset = hf_datasets.Dataset.from_list(rows) new_dataset = hf_datasets.Dataset.from_list(rows)
# Concatenate the new rows with existing dataset # Concatenate the new rows with existing dataset
updated_dataset = hf_datasets.concatenate_datasets([loaded_dataset, new_dataset]) updated_dataset = hf_datasets.concatenate_datasets(
[loaded_dataset, new_dataset]
)
if dataset_def.metadata.get("path", None): if dataset_def.metadata.get("path", None):
updated_dataset.push_to_hub(dataset_def.metadata["path"]) updated_dataset.push_to_hub(dataset_def.metadata["path"])
else: else:
raise NotImplementedError("Uploading to URL-based datasets is not supported yet") raise NotImplementedError(
"Uploading to URL-based datasets is not supported yet"
)
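The "path" stored in the dataset's metadata is treated as a Hub repo id here. A sketch of the same append-and-push flow against the datasets library directly, with a placeholder repo id and row schema:

import datasets as hf_datasets

loaded = hf_datasets.load_dataset("org/my-dataset", split="train")  # placeholder repo
new = hf_datasets.Dataset.from_list([{"input": "hi", "output": "hello"}])
updated = hf_datasets.concatenate_datasets([loaded, new])
updated.push_to_hub("org/my-dataset")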

View file

@@ -12,6 +12,7 @@ from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
ChatCompletionResponse, ChatCompletionResponse,
ChatCompletionResponseStreamChunk, ChatCompletionResponseStreamChunk,
CompletionMessage,
EmbeddingsResponse, EmbeddingsResponse,
EmbeddingTaskType, EmbeddingTaskType,
Inference, Inference,
@@ -160,12 +161,14 @@ class PassthroughInferenceAdapter(Inference):
client = self._get_client() client = self._get_client()
response = await client.inference.chat_completion(**json_params) response = await client.inference.chat_completion(**json_params)
response = response.to_dict() return ChatCompletionResponse(
completion_message=CompletionMessage(
# temporary hack to remove the metrics from the response content=response.completion_message.content.text,
response["metrics"] = [] stop_reason=response.completion_message.stop_reason,
tool_calls=response.completion_message.tool_calls,
return convert_to_pydantic(ChatCompletionResponse, response) ),
logprobs=response.logprobs,
)
async def _stream_chat_completion(self, json_params: Dict[str, Any]) -> AsyncGenerator: async def _stream_chat_completion(self, json_params: Dict[str, Any]) -> AsyncGenerator:
client = self._get_client() client = self._get_client()

View file

@@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from .config import NVIDIASafetyConfig
async def get_adapter_impl(config: NVIDIASafetyConfig, _deps) -> Any:
from .nvidia import NVIDIASafetyAdapter
impl = NVIDIASafetyAdapter(config)
await impl.initialize()
return impl

View file

@@ -0,0 +1,37 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class NVIDIASafetyConfig(BaseModel):
"""
Configuration for the NVIDIA Guardrail microservice endpoint.
Attributes:
guardrails_service_url (str): A base url for accessing the NVIDIA guardrail endpoint, e.g. http://0.0.0.0:7331
config_id (str): The ID of the guardrails configuration to use from the configuration store
(https://developer.nvidia.com/docs/nemo-microservices/guardrails/source/guides/configuration-store-guide.html)
"""
guardrails_service_url: str = Field(
default_factory=lambda: os.getenv("GUARDRAILS_SERVICE_URL", "http://0.0.0.0:7331"),
description="The url for accessing the guardrails service",
)
config_id: Optional[str] = Field(default="self-check", description="Config ID to use from the config store")
@classmethod
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
return {
"guardrails_service_url": "${env.GUARDRAILS_SERVICE_URL:http://localhost:7331}",
"config_id": "self-check",
}
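Given the default_factory and sample_run_config above, the config resolves sensibly with nothing set in the environment. A quick sanity check (assumes GUARDRAILS_SERVICE_URL is unset; the import path follows the config_class string registered for this provider):

import os

from llama_stack.providers.remote.safety.nvidia import NVIDIASafetyConfig

os.environ.pop("GUARDRAILS_SERVICE_URL", None)
config = NVIDIASafetyConfig()
print(config.guardrails_service_url)  # http://0.0.0.0:7331
print(config.config_id)               # self-check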

View file

@@ -0,0 +1,154 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from typing import Any, List, Optional
import requests
from llama_stack.apis.inference import Message
from llama_stack.apis.safety import RunShieldResponse, Safety, SafetyViolation, ViolationLevel
from llama_stack.apis.shields import Shield
from llama_stack.distribution.library_client import convert_pydantic_to_json_value
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from .config import NVIDIASafetyConfig
logger = logging.getLogger(__name__)
class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
def __init__(self, config: NVIDIASafetyConfig) -> None:
"""
Initialize the NVIDIASafetyAdapter with a given safety configuration.
Args:
config (NVIDIASafetyConfig): The configuration containing the guardrails service URL and config ID.
"""
print(f"Initializing NVIDIASafetyAdapter({config.guardrails_service_url})...")
self.config = config
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
async def register_shield(self, shield: Shield) -> None:
if not shield.provider_resource_id:
raise ValueError("Shield model not provided.")
async def run_shield(
self, shield_id: str, messages: List[Message], params: Optional[dict[str, Any]] = None
) -> RunShieldResponse:
"""
Run a safety shield check against the provided messages.
Args:
shield_id (str): The unique identifier for the shield to be used.
messages (List[Message]): A list of Message objects representing the conversation history.
params (Optional[dict[str, Any]]): Additional parameters for the shield check.
Returns:
RunShieldResponse: The response containing safety violation details if any.
Raises:
ValueError: If the shield with the provided shield_id is not found.
"""
shield = await self.shield_store.get_shield(shield_id)
if not shield:
raise ValueError(f"Shield {shield_id} not found")
self.shield = NeMoGuardrails(self.config, shield.shield_id)
return await self.shield.run(messages)
class NeMoGuardrails:
"""
A class that encapsulates NVIDIA's guardrails safety logic.
Sends messages to the guardrails service and interprets the response to determine
if a safety violation has occurred.
"""
def __init__(
self,
config: NVIDIASafetyConfig,
model: str,
threshold: float = 0.9,
temperature: float = 1.0,
):
"""
Initialize a NeMoGuardrails instance with the provided parameters.
Args:
config (NVIDIASafetyConfig): The safety configuration containing the config ID and guardrails URL.
model (str): The identifier or name of the model to be used for safety checks.
threshold (float, optional): The threshold for flagging violations. Defaults to 0.9.
temperature (float, optional): The temperature setting for the underlying model. Must be greater than 0. Defaults to 1.0.
Raises:
ValueError: If temperature is less than or equal to 0.
AssertionError: If config_id is not provided in the configuration.
"""
self.config_id = config.config_id
self.model = model
assert self.config_id is not None, "Must provide config id"
if temperature <= 0:
raise ValueError("Temperature must be greater than 0")
self.temperature = temperature
self.threshold = threshold
self.guardrails_service_url = config.guardrails_service_url
async def run(self, messages: List[Message]) -> RunShieldResponse:
"""
Queries the /v1/guardrail/checks endpoint of the deployed NeMo Guardrails API.
Args:
messages (List[Message]): A list of Message objects to be checked for safety violations.
Returns:
RunShieldResponse: If the response indicates a violation ("blocked" status), returns a
RunShieldResponse with a SafetyViolation; otherwise, returns a RunShieldResponse with violation set to None.
Raises:
requests.HTTPError: If the POST request fails.
"""
headers = {
"Accept": "application/json",
}
request_data = {
"model": self.model,
"messages": convert_pydantic_to_json_value(messages),
"temperature": self.temperature,
"top_p": 1,
"frequency_penalty": 0,
"presence_penalty": 0,
"max_tokens": 160,
"stream": False,
"guardrails": {
"config_id": self.config_id,
},
}
response = requests.post(
url=f"{self.guardrails_service_url}/v1/guardrail/checks", headers=headers, json=request_data
)
response.raise_for_status()
if "Content-Type" in response.headers and response.headers["Content-Type"].startswith("application/json"):
response_json = response.json()
if response_json["status"] == "blocked":
user_message = "Sorry I cannot do this."
metadata = response_json["rails_status"]
return RunShieldResponse(
violation=SafetyViolation(
user_message=user_message,
violation_level=ViolationLevel.ERROR,
metadata=metadata,
)
)
return RunShieldResponse(violation=None)
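The run method above is a thin wrapper over a single POST. An equivalent raw call, with the URL, headers, and payload mirroring the code (the message content and model id are illustrative):

import requests

resp = requests.post(
    url="http://localhost:7331/v1/guardrail/checks",
    headers={"Accept": "application/json"},
    json={
        "model": "meta/llama-3.1-8b-instruct",
        "messages": [{"role": "user", "content": "How do I pick a lock?"}],
        "temperature": 1.0,
        "top_p": 1,
        "frequency_penalty": 0,
        "presence_penalty": 0,
        "max_tokens": 160,
        "stream": False,
        "guardrails": {"config_id": "self-check"},
    },
)
resp.raise_for_status()
print(resp.json().get("status"))  # "blocked" maps to a SafetyViolation above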

View file

@@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@@ -1,15 +0,0 @@
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
THIS_DIR="$(cd "$(dirname "$(readlink -f "${BASH_SOURCE[0]}")")" && pwd)"
set -euo pipefail
set -x
stack_dir=$(dirname $(dirname $THIS_DIR))
PYTHONPATH=$stack_dir pytest -p no:warnings --asyncio-mode auto --tb=short

View file

@@ -1,13 +1,13 @@
version: '2' version: '2'
distribution_spec: distribution_spec:
description: Use NVIDIA NIM for running LLM inference description: Use NVIDIA NIM for running LLM inference and safety
providers: providers:
inference: inference:
- remote::nvidia - remote::nvidia
vector_io: vector_io:
- inline::faiss - inline::faiss
safety: safety:
- inline::llama-guard - remote::nvidia
agents: agents:
- inline::meta-reference - inline::meta-reference
telemetry: telemetry:
@@ -15,16 +15,9 @@ distribution_spec:
eval: eval:
- inline::meta-reference - inline::meta-reference
datasetio: datasetio:
- remote::huggingface
- inline::localfs - inline::localfs
scoring: scoring:
- inline::basic - inline::basic
- inline::llm-as-judge
- inline::braintrust
tool_runtime: tool_runtime:
- remote::brave-search
- remote::tavily-search
- inline::code-interpreter
- inline::rag-runtime - inline::rag-runtime
- remote::model-context-protocol
image_type: conda image_type: conda

View file

@@ -6,9 +6,10 @@
from pathlib import Path from pathlib import Path
from llama_stack.distribution.datatypes import Provider, ToolGroupInput from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput, ToolGroupInput
from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
from llama_stack.providers.remote.inference.nvidia.models import MODEL_ENTRIES from llama_stack.providers.remote.inference.nvidia.models import MODEL_ENTRIES
from llama_stack.providers.remote.safety.nvidia import NVIDIASafetyConfig
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings, get_model_registry from llama_stack.templates.template import DistributionTemplate, RunConfigSettings, get_model_registry
@@ -16,19 +17,13 @@ def get_distribution_template() -> DistributionTemplate:
providers = { providers = {
"inference": ["remote::nvidia"], "inference": ["remote::nvidia"],
"vector_io": ["inline::faiss"], "vector_io": ["inline::faiss"],
"safety": ["inline::llama-guard"], "safety": ["remote::nvidia"],
"agents": ["inline::meta-reference"], "agents": ["inline::meta-reference"],
"telemetry": ["inline::meta-reference"], "telemetry": ["inline::meta-reference"],
"eval": ["inline::meta-reference"], "eval": ["inline::meta-reference"],
"datasetio": ["remote::huggingface", "inline::localfs"], "datasetio": ["inline::localfs"],
"scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"], "scoring": ["inline::basic"],
"tool_runtime": [ "tool_runtime": ["inline::rag-runtime"],
"remote::brave-search",
"remote::tavily-search",
"inline::code-interpreter",
"inline::rag-runtime",
"remote::model-context-protocol",
],
} }
inference_provider = Provider( inference_provider = Provider(
@@ -36,30 +31,35 @@ def get_distribution_template() -> DistributionTemplate:
provider_type="remote::nvidia", provider_type="remote::nvidia",
config=NVIDIAConfig.sample_run_config(), config=NVIDIAConfig.sample_run_config(),
) )
safety_provider = Provider(
provider_id="nvidia",
provider_type="remote::nvidia",
config=NVIDIASafetyConfig.sample_run_config(),
)
inference_model = ModelInput(
model_id="${env.INFERENCE_MODEL}",
provider_id="nvidia",
)
safety_model = ModelInput(
model_id="${env.SAFETY_MODEL}",
provider_id="nvidia",
)
available_models = { available_models = {
"nvidia": MODEL_ENTRIES, "nvidia": MODEL_ENTRIES,
} }
default_tool_groups = [ default_tool_groups = [
ToolGroupInput(
toolgroup_id="builtin::websearch",
provider_id="tavily-search",
),
ToolGroupInput( ToolGroupInput(
toolgroup_id="builtin::rag", toolgroup_id="builtin::rag",
provider_id="rag-runtime", provider_id="rag-runtime",
), ),
ToolGroupInput(
toolgroup_id="builtin::code_interpreter",
provider_id="code-interpreter",
),
] ]
default_models = get_model_registry(available_models) default_models = get_model_registry(available_models)
return DistributionTemplate( return DistributionTemplate(
name="nvidia", name="nvidia",
distro_type="remote_hosted", distro_type="remote_hosted",
description="Use NVIDIA NIM for running LLM inference", description="Use NVIDIA NIM for running LLM inference and safety",
container_image=None, container_image=None,
template_path=Path(__file__).parent / "doc_template.md", template_path=Path(__file__).parent / "doc_template.md",
providers=providers, providers=providers,
@@ -72,15 +72,34 @@ def get_distribution_template() -> DistributionTemplate:
default_models=default_models, default_models=default_models,
default_tool_groups=default_tool_groups, default_tool_groups=default_tool_groups,
), ),
"run-with-safety.yaml": RunConfigSettings(
provider_overrides={
"inference": [
inference_provider,
safety_provider,
]
},
default_models=[inference_model, safety_model],
default_shields=[ShieldInput(shield_id="${env.SAFETY_MODEL}", provider_id="nvidia")],
default_tool_groups=default_tool_groups,
),
}, },
run_config_env_vars={ run_config_env_vars={
"LLAMASTACK_PORT": (
"5001",
"Port for the Llama Stack distribution server",
),
"NVIDIA_API_KEY": ( "NVIDIA_API_KEY": (
"", "",
"NVIDIA API Key", "NVIDIA API Key",
), ),
"GUARDRAILS_SERVICE_URL": (
"http://0.0.0.0:7331",
"URL for the NeMo Guardrails Service",
),
"INFERENCE_MODEL": (
"Llama3.1-8B-Instruct",
"Inference model",
),
"SAFETY_MODEL": (
"meta/llama-3.1-8b-instruct",
"Name of the model to use for safety",
),
}, },
) )
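Every default registered in run_config_env_vars surfaces in the generated YAML as a ${env.VAR:default} placeholder. An illustrative resolver for that syntax (not the stack's actual implementation):

import os
import re

def substitute_env(value: str) -> str:
    def repl(match):
        name, _, default = match.group(1).partition(":")
        return os.environ.get(name, default)
    return re.sub(r"\$\{env\.([^}]+)\}", repl, value)

print(substitute_env("${env.GUARDRAILS_SERVICE_URL:http://0.0.0.0:7331}"))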

View file

@@ -0,0 +1,101 @@
version: '2'
image_name: nvidia
apis:
- agents
- datasetio
- eval
- inference
- safety
- scoring
- telemetry
- tool_runtime
- vector_io
providers:
inference:
- provider_id: nvidia
provider_type: remote::nvidia
config:
url: ${env.NVIDIA_BASE_URL:https://integrate.api.nvidia.com}
api_key: ${env.NVIDIA_API_KEY:}
- provider_id: nvidia
provider_type: remote::nvidia
config:
guardrails_service_url: ${env.GUARDRAILS_SERVICE_URL:http://localhost:7331}
config_id: self-check
vector_io:
- provider_id: faiss
provider_type: inline::faiss
config:
kvstore:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/faiss_store.db
safety:
- provider_id: nvidia
provider_type: remote::nvidia
config:
guardrails_service_url: ${env.GUARDRAILS_SERVICE_URL:http://localhost:7331}
config_id: self-check
agents:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
persistence_store:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/agents_store.db
telemetry:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/nvidia/trace_store.db}
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
kvstore:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/meta_reference_eval.db
datasetio:
- provider_id: localfs
provider_type: inline::localfs
config:
kvstore:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/localfs_datasetio.db
scoring:
- provider_id: basic
provider_type: inline::basic
config: {}
tool_runtime:
- provider_id: rag-runtime
provider_type: inline::rag-runtime
config: {}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/registry.db
models:
- metadata: {}
model_id: ${env.INFERENCE_MODEL}
provider_id: nvidia
model_type: llm
- metadata: {}
model_id: ${env.SAFETY_MODEL}
provider_id: nvidia
model_type: llm
shields:
- shield_id: ${env.SAFETY_MODEL}
provider_id: nvidia
vector_dbs: []
datasets: []
scoring_fns: []
benchmarks: []
tool_groups:
- toolgroup_id: builtin::rag
provider_id: rag-runtime
server:
port: 8321
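A hedged launch sketch for this generated config, assuming the file lands at llama_stack/templates/nvidia/run-with-safety.yaml; the env values mirror the defaults the template registers, and the API key is a placeholder:

import os
import subprocess

env = {
    **os.environ,
    "NVIDIA_API_KEY": "nvapi-...",  # placeholder
    "GUARDRAILS_SERVICE_URL": "http://localhost:7331",
    "INFERENCE_MODEL": "Llama3.1-8B-Instruct",
    "SAFETY_MODEL": "meta/llama-3.1-8b-instruct",
}
subprocess.run(
    ["llama", "stack", "run", "llama_stack/templates/nvidia/run-with-safety.yaml"],
    env=env,
    check=True,
)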

View file

@@ -26,10 +26,11 @@ providers:
namespace: null namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/faiss_store.db db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/faiss_store.db
safety: safety:
- provider_id: llama-guard - provider_id: nvidia
provider_type: inline::llama-guard provider_type: remote::nvidia
config: config:
excluded_categories: [] guardrails_service_url: ${env.GUARDRAILS_SERVICE_URL:http://localhost:7331}
config_id: self-check
agents: agents:
- provider_id: meta-reference - provider_id: meta-reference
provider_type: inline::meta-reference provider_type: inline::meta-reference
@@ -54,13 +55,6 @@ providers:
namespace: null namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/meta_reference_eval.db db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/meta_reference_eval.db
datasetio: datasetio:
- provider_id: huggingface
provider_type: remote::huggingface
config:
kvstore:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/huggingface_datasetio.db
- provider_id: localfs - provider_id: localfs
provider_type: inline::localfs provider_type: inline::localfs
config: config:
@@ -72,33 +66,10 @@ providers:
- provider_id: basic - provider_id: basic
provider_type: inline::basic provider_type: inline::basic
config: {} config: {}
- provider_id: llm-as-judge
provider_type: inline::llm-as-judge
config: {}
- provider_id: braintrust
provider_type: inline::braintrust
config:
openai_api_key: ${env.OPENAI_API_KEY:}
tool_runtime: tool_runtime:
- provider_id: brave-search
provider_type: remote::brave-search
config:
api_key: ${env.BRAVE_SEARCH_API_KEY:}
max_results: 3
- provider_id: tavily-search
provider_type: remote::tavily-search
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:}
max_results: 3
- provider_id: code-interpreter
provider_type: inline::code-interpreter
config: {}
- provider_id: rag-runtime - provider_id: rag-runtime
provider_type: inline::rag-runtime provider_type: inline::rag-runtime
config: {} config: {}
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
config: {}
metadata_store: metadata_store:
type: sqlite type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/registry.db db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/registry.db
@@ -227,11 +198,7 @@ datasets: []
scoring_fns: [] scoring_fns: []
benchmarks: [] benchmarks: []
tool_groups: tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::rag - toolgroup_id: builtin::rag
provider_id: rag-runtime provider_id: rag-runtime
- toolgroup_id: builtin::code_interpreter
provider_id: code-interpreter
server: server:
port: 8321 port: 8321

View file

@@ -269,6 +269,7 @@ exclude = [
"^llama_stack/providers/remote/inference/together/", "^llama_stack/providers/remote/inference/together/",
"^llama_stack/providers/remote/inference/vllm/", "^llama_stack/providers/remote/inference/vllm/",
"^llama_stack/providers/remote/safety/bedrock/", "^llama_stack/providers/remote/safety/bedrock/",
"^llama_stack/providers/remote/safety/nvidia/",
"^llama_stack/providers/remote/safety/sample/", "^llama_stack/providers/remote/safety/sample/",
"^llama_stack/providers/remote/tool_runtime/bing_search/", "^llama_stack/providers/remote/tool_runtime/bing_search/",
"^llama_stack/providers/remote/tool_runtime/brave_search/", "^llama_stack/providers/remote/tool_runtime/brave_search/",

View file

@@ -1,3 +1,4 @@
#!/usr/bin/env python
# Copyright (c) Meta Platforms, Inc. and affiliates. # Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved. # All rights reserved.
# #
@@ -20,7 +21,7 @@ from llama_stack.distribution.build import (
get_provider_dependencies, get_provider_dependencies,
) )
REPO_ROOT = Path(__file__).parent.parent.parent REPO_ROOT = Path(__file__).parent.parent
class ChangedPathTracker: class ChangedPathTracker:

scripts/gen-changelog.py Normal file → Executable file
View file

@@ -1,3 +1,4 @@
#!/usr/bin/env python
# Copyright (c) Meta Platforms, Inc. and affiliates. # Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved. # All rights reserved.
# #

View file

@@ -1,3 +1,4 @@
#!/usr/bin/env python
# Copyright (c) Meta Platforms, Inc. and affiliates. # Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved. # All rights reserved.
# #
@@ -18,7 +19,7 @@ import fire
from llama_stack.models.llama.sku_list import resolve_model from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.providers.inline.inference.meta_reference.config import MetaReferenceInferenceConfig from llama_stack.providers.inline.inference.meta_reference.config import MetaReferenceInferenceConfig
from llama_stack.providers.inline.inference.meta_reference.generation import Llama from llama_stack.providers.inline.inference.meta_reference.llama3.generation import Llama3
THIS_DIR = Path(__file__).parent.resolve() THIS_DIR = Path(__file__).parent.resolve()
@@ -41,7 +42,7 @@ def run_main(
llama_model = resolve_model(model_id) llama_model = resolve_model(model_id)
if not llama_model: if not llama_model:
raise ValueError(f"Model {model_id} not found") raise ValueError(f"Model {model_id} not found")
generator = Llama.build( generator = Llama3.build(
config=config, config=config,
model_id=model_id, model_id=model_id,
llama_model=llama_model, llama_model=llama_model,

View file

@@ -1,3 +1,4 @@
#!/usr/bin/env python
# Copyright (c) Meta Platforms, Inc. and affiliates. # Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved. # All rights reserved.
# #
@@ -15,8 +16,7 @@ Script for running api on AsyncLlamaStackAsLibraryClient with templates
Assuming directory structure: Assuming directory structure:
- llama-stack - llama-stack
- llama_stack - scripts
- scripts
- tests - tests
- api - api
@@ -25,10 +25,10 @@ Example command:
cd llama-stack cd llama-stack
EXPORT TOGETHER_API_KEY=<..> EXPORT TOGETHER_API_KEY=<..>
EXPORT FIREWORKS_API_KEY=<..> EXPORT FIREWORKS_API_KEY=<..>
python llama_stack/scripts/run_client_sdk_tests.py --templates together fireworks --report ./scripts/run_client_sdk_tests.py --templates together fireworks --report
""" """
REPO_ROOT = Path(__file__).parent.parent.parent REPO_ROOT = Path(__file__).parent.parent
CLIENT_SDK_TESTS_RELATIVE_PATH = "tests/api/" CLIENT_SDK_TESTS_RELATIVE_PATH = "tests/api/"
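The REPO_ROOT change drops one .parent because the script moved up a directory level, from llama_stack/scripts/ to scripts/. A tiny check of the arithmetic (paths are placeholders):

from pathlib import Path

old = Path("/repo/llama_stack/scripts/run_client_sdk_tests.py")
new = Path("/repo/scripts/run_client_sdk_tests.py")
print(old.parent.parent.parent)  # /repo
print(new.parent.parent)         # /repo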