Merge branch 'meta-llama:main' into feat/litellm_sambanova_usage

commit e49bcd46fe
Jorge Piedrahita Ortiz, 2025-03-12 15:12:42 -05:00 (committed by GitHub)
90 changed files with 3142 additions and 586 deletions


@ -8,29 +8,37 @@ on:
jobs: jobs:
unit-tests: unit-tests:
runs-on: ubuntu-latest runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python:
- "3.10"
- "3.11"
- "3.12"
- "3.13"
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- name: Set up Python - name: Set up Python ${{ matrix.python }}
uses: actions/setup-python@v5 uses: actions/setup-python@v5
with: with:
python-version: '3.10.16' python-version: ${{ matrix.python }}
- uses: astral-sh/setup-uv@v5 - uses: astral-sh/setup-uv@v5
with: with:
python-version: '3.10.16' python-version: ${{ matrix.python }}
enable-cache: false enable-cache: false
- name: Run unit tests - name: Run unit tests
run: | run: |
uv run -p 3.10.16 --with . --with ".[dev]" --with ".[test]" pytest -s -v tests/unit/ --junitxml=pytest-report.xml uv run --python ${{ matrix.python }} --with-editable . --with-editable ".[dev]" --with-editable ".[unit]" pytest --cov=llama_stack -s -v tests/unit/ --junitxml=pytest-report-${{ matrix.python }}.xml
- name: Upload test results - name: Upload test results
if: always() if: always()
uses: actions/upload-artifact@v4 uses: actions/upload-artifact@v4
with: with:
name: test-results name: test-results-${{ matrix.python }}
path: | path: |
.pytest_cache/ .pytest_cache/
pytest-report.xml pytest-report-${{ matrix.python }}.xml
retention-days: 7 retention-days: 7

.gitignore

@ -21,3 +21,4 @@ docs/src
pyrightconfig.json pyrightconfig.json
venv/ venv/
pytest-report.xml pytest-report.xml
.coverage


@ -159,8 +159,7 @@ uv run sphinx-autobuild source build/html --write-all
If you modify or add new API endpoints, update the API documentation accordingly. You can do this by running the following command: If you modify or add new API endpoints, update the API documentation accordingly. You can do this by running the following command:
```bash ```bash
uv sync --extra dev uv run --with ".[dev]" ./docs/openapi_generator/run_openapi_generator.sh
uv run ./docs/openapi_generator/run_openapi_generator.sh
``` ```
The generated API documentation will be available in `docs/_static/`. Make sure to review the changes before committing. The generated API documentation will be available in `docs/_static/`. Make sure to review the changes before committing.


@ -427,6 +427,7 @@
"chardet", "chardet",
"chromadb-client", "chromadb-client",
"datasets", "datasets",
"faiss-cpu",
"fastapi", "fastapi",
"fire", "fire",
"httpx", "httpx",
@ -448,7 +449,40 @@
"scikit-learn", "scikit-learn",
"scipy", "scipy",
"sentencepiece", "sentencepiece",
"tqdm",
"transformers",
"uvicorn"
],
"open-benchmark": [
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"fastapi",
"fire",
"httpx",
"litellm",
"matplotlib",
"mcp",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"sqlite-vec", "sqlite-vec",
"together",
"tqdm", "tqdm",
"transformers", "transformers",
"uvicorn" "uvicorn"


@ -71,7 +71,6 @@ services:
condition: service_healthy condition: service_healthy
- vllm-${VLLM_SAFETY_MODEL:+safety}: - vllm-${VLLM_SAFETY_MODEL:+safety}:
condition: service_healthy condition: service_healthy
# image: llamastack/distribution-remote-vllm
image: llamastack/distribution-remote-vllm:test-0.0.52rc3 image: llamastack/distribution-remote-vllm:test-0.0.52rc3
volumes: volumes:
- ~/.llama:/root/.llama - ~/.llama:/root/.llama


@ -363,6 +363,37 @@
} }
}, },
"/v1/agents": { "/v1/agents": {
"get": {
"responses": {
"200": {
"description": "A ListAgentsResponse.",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ListAgentsResponse"
}
}
}
},
"400": {
"$ref": "#/components/responses/BadRequest400"
},
"429": {
"$ref": "#/components/responses/TooManyRequests429"
},
"500": {
"$ref": "#/components/responses/InternalServerError500"
},
"default": {
"$ref": "#/components/responses/DefaultError"
}
},
"tags": [
"Agents"
],
"description": "List all agents.",
"parameters": []
},
"post": { "post": {
"responses": { "responses": {
"200": { "200": {
@ -609,6 +640,47 @@
} }
}, },
"/v1/agents/{agent_id}": { "/v1/agents/{agent_id}": {
"get": {
"responses": {
"200": {
"description": "An Agent of the agent.",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/Agent"
}
}
}
},
"400": {
"$ref": "#/components/responses/BadRequest400"
},
"429": {
"$ref": "#/components/responses/TooManyRequests429"
},
"500": {
"$ref": "#/components/responses/InternalServerError500"
},
"default": {
"$ref": "#/components/responses/DefaultError"
}
},
"tags": [
"Agents"
],
"description": "Describe an agent by its ID.",
"parameters": [
{
"name": "agent_id",
"in": "path",
"description": "ID of the agent.",
"required": true,
"schema": {
"type": "string"
}
}
]
},
"delete": { "delete": {
"responses": { "responses": {
"200": { "200": {
@ -2276,6 +2348,49 @@
] ]
} }
}, },
"/v1/agents/{agent_id}/sessions": {
"get": {
"responses": {
"200": {
"description": "A ListAgentSessionsResponse.",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ListAgentSessionsResponse"
}
}
}
},
"400": {
"$ref": "#/components/responses/BadRequest400"
},
"429": {
"$ref": "#/components/responses/TooManyRequests429"
},
"500": {
"$ref": "#/components/responses/InternalServerError500"
},
"default": {
"$ref": "#/components/responses/DefaultError"
}
},
"tags": [
"Agents"
],
"description": "List all session(s) of a given agent.",
"parameters": [
{
"name": "agent_id",
"in": "path",
"description": "The ID of the agent to list sessions for.",
"required": true,
"schema": {
"type": "string"
}
}
]
}
},
"/v1/eval/benchmarks": { "/v1/eval/benchmarks": {
"get": { "get": {
"responses": { "responses": {
@ -6565,6 +6680,28 @@
"title": "ScoringResult", "title": "ScoringResult",
"description": "A scoring result for a single row." "description": "A scoring result for a single row."
}, },
"Agent": {
"type": "object",
"properties": {
"agent_id": {
"type": "string"
},
"agent_config": {
"$ref": "#/components/schemas/AgentConfig"
},
"created_at": {
"type": "string",
"format": "date-time"
}
},
"additionalProperties": false,
"required": [
"agent_id",
"agent_config",
"created_at"
],
"title": "Agent"
},
"Session": { "Session": {
"type": "object", "type": "object",
"properties": { "properties": {
@ -7907,6 +8044,38 @@
], ],
"title": "ToolInvocationResult" "title": "ToolInvocationResult"
}, },
"ListAgentSessionsResponse": {
"type": "object",
"properties": {
"data": {
"type": "array",
"items": {
"$ref": "#/components/schemas/Session"
}
}
},
"additionalProperties": false,
"required": [
"data"
],
"title": "ListAgentSessionsResponse"
},
"ListAgentsResponse": {
"type": "object",
"properties": {
"data": {
"type": "array",
"items": {
"$ref": "#/components/schemas/Agent"
}
}
},
"additionalProperties": false,
"required": [
"data"
],
"title": "ListAgentsResponse"
},
"BucketResponse": { "BucketResponse": {
"type": "object", "type": "object",
"properties": { "properties": {
@ -9321,21 +9490,11 @@
"type": "object", "type": "object",
"properties": { "properties": {
"tool_responses": { "tool_responses": {
"oneOf": [ "type": "array",
{ "items": {
"type": "array", "$ref": "#/components/schemas/ToolResponse"
"items": { },
"$ref": "#/components/schemas/ToolResponse" "description": "The tool call responses to resume the turn with."
}
},
{
"type": "array",
"items": {
"$ref": "#/components/schemas/ToolResponseMessage"
}
}
],
"description": "The tool call responses to resume the turn with. NOTE: ToolResponseMessage will be deprecated. Use ToolResponse."
}, },
"stream": { "stream": {
"type": "boolean", "type": "boolean",


@ -238,6 +238,28 @@ paths:
$ref: '#/components/schemas/CompletionRequest' $ref: '#/components/schemas/CompletionRequest'
required: true required: true
/v1/agents: /v1/agents:
get:
responses:
'200':
description: A ListAgentsResponse.
content:
application/json:
schema:
$ref: '#/components/schemas/ListAgentsResponse'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- Agents
description: List all agents.
parameters: []
post: post:
responses: responses:
'200': '200':
@ -410,6 +432,34 @@ paths:
$ref: '#/components/schemas/CreateUploadSessionRequest' $ref: '#/components/schemas/CreateUploadSessionRequest'
required: true required: true
/v1/agents/{agent_id}: /v1/agents/{agent_id}:
get:
responses:
'200':
description: An Agent of the agent.
content:
application/json:
schema:
$ref: '#/components/schemas/Agent'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- Agents
description: Describe an agent by its ID.
parameters:
- name: agent_id
in: path
description: ID of the agent.
required: true
schema:
type: string
delete: delete:
responses: responses:
'200': '200':
@ -1528,6 +1578,36 @@ paths:
required: true required: true
schema: schema:
type: string type: string
/v1/agents/{agent_id}/sessions:
get:
responses:
'200':
description: A ListAgentSessionsResponse.
content:
application/json:
schema:
$ref: '#/components/schemas/ListAgentSessionsResponse'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- Agents
description: List all session(s) of a given agent.
parameters:
- name: agent_id
in: path
description: >-
The ID of the agent to list sessions for.
required: true
schema:
type: string
/v1/eval/benchmarks: /v1/eval/benchmarks:
get: get:
responses: responses:
@ -4549,6 +4629,22 @@ components:
- aggregated_results - aggregated_results
title: ScoringResult title: ScoringResult
description: A scoring result for a single row. description: A scoring result for a single row.
Agent:
type: object
properties:
agent_id:
type: string
agent_config:
$ref: '#/components/schemas/AgentConfig'
created_at:
type: string
format: date-time
additionalProperties: false
required:
- agent_id
- agent_config
- created_at
title: Agent
Session: Session:
type: object type: object
properties: properties:
@ -5385,6 +5481,28 @@ components:
required: required:
- content - content
title: ToolInvocationResult title: ToolInvocationResult
ListAgentSessionsResponse:
type: object
properties:
data:
type: array
items:
$ref: '#/components/schemas/Session'
additionalProperties: false
required:
- data
title: ListAgentSessionsResponse
ListAgentsResponse:
type: object
properties:
data:
type: array
items:
$ref: '#/components/schemas/Agent'
additionalProperties: false
required:
- data
title: ListAgentsResponse
BucketResponse: BucketResponse:
type: object type: object
properties: properties:
@ -6287,16 +6405,11 @@ components:
type: object type: object
properties: properties:
tool_responses: tool_responses:
oneOf: type: array
- type: array items:
items: $ref: '#/components/schemas/ToolResponse'
$ref: '#/components/schemas/ToolResponse'
- type: array
items:
$ref: '#/components/schemas/ToolResponseMessage'
description: >- description: >-
The tool call responses to resume the turn with. NOTE: ToolResponseMessage The tool call responses to resume the turn with.
will be deprecated. Use ToolResponse.
stream: stream:
type: boolean type: boolean
description: Whether to stream the response. description: Whether to stream the response.
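The OpenAPI additions above (in both the JSON and YAML specs) describe three new read-only Agents endpoints: `GET /v1/agents`, `GET /v1/agents/{agent_id}`, and `GET /v1/agents/{agent_id}/sessions`. A minimal sketch of calling them directly over HTTP with `httpx`; the server address and port are assumptions for illustration, not part of this commit:

```python
# Minimal sketch of the new read-only Agents endpoints described above.
# The server address and port are assumptions for illustration only.
import httpx

BASE_URL = "http://localhost:8321/v1"  # assumed Llama Stack server address

with httpx.Client(base_url=BASE_URL) as client:
    # GET /v1/agents -> ListAgentsResponse {"data": [Agent, ...]}
    agents = client.get("/agents").json()["data"]

    for agent in agents:
        agent_id = agent["agent_id"]

        # GET /v1/agents/{agent_id} -> Agent {"agent_id", "agent_config", "created_at"}
        detail = client.get(f"/agents/{agent_id}").json()

        # GET /v1/agents/{agent_id}/sessions -> ListAgentSessionsResponse {"data": [Session, ...]}
        sessions = client.get(f"/agents/{agent_id}/sessions").json()["data"]

        print(agent_id, detail["created_at"], len(sessions))
```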


@ -1267,7 +1267,6 @@
} }
], ],
"source": [ "source": [
"# NBVAL_SKIP\n",
"from pydantic import BaseModel\n", "from pydantic import BaseModel\n",
"\n", "\n",
"\n", "\n",
@ -1279,7 +1278,7 @@
"\n", "\n",
"user_input = \"Michael Jordan was born in 1963. He played basketball for the Chicago Bulls. He retired in 2003. Extract this information into JSON for me. \"\n", "user_input = \"Michael Jordan was born in 1963. He played basketball for the Chicago Bulls. He retired in 2003. Extract this information into JSON for me. \"\n",
"response = client.inference.completion(\n", "response = client.inference.completion(\n",
" model_id=model_id,\n", " model_id=\"meta-llama/Llama-3.1-8B-Instruct\",\n",
" content=user_input,\n", " content=user_input,\n",
" stream=False,\n", " stream=False,\n",
" sampling_params={\n", " sampling_params={\n",
@ -1640,7 +1639,7 @@
"agent = Agent(\n", "agent = Agent(\n",
" client, \n", " client, \n",
" model=model_id,\n", " model=model_id,\n",
" instructions=\"You are a helpful assistant\",\n", " instructions=\"You are a helpful assistant. Use websearch tool to help answer questions.\",\n",
" tools=[\"builtin::websearch\"],\n", " tools=[\"builtin::websearch\"],\n",
")\n", ")\n",
"user_prompts = [\n", "user_prompts = [\n",


@ -1,9 +1 @@
The RFC Specification (OpenAPI format) is generated from the set of API endpoints located in `llama_stack/distribution/server/endpoints.py` using the `generate.py` utility. The RFC Specification (OpenAPI format) is generated from the set of API endpoints located in `llama_stack/distribution/server/endpoints.py` using the `generate.py` utility.
Please install the following packages before running the script:
```
pip install fire PyYAML
```
Then simply run `sh run_openapi_generator.sh`


@ -23,9 +23,12 @@ In this example, we will show you how to:
##### Building a Search Agent ##### Building a Search Agent
```python ```python
from llama_stack_client import LlamaStackClient
from llama_stack_client.lib.agents.agent import Agent from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.event_logger import EventLogger from llama_stack_client.lib.agents.event_logger import EventLogger
client = LlamaStackClient(base_url=f"http://{HOST}:{PORT}")
agent = Agent( agent = Agent(
client, client,
model="meta-llama/Llama-3.3-70B-Instruct", model="meta-llama/Llama-3.3-70B-Instruct",
@ -33,7 +36,7 @@ agent = Agent(
tools=["builtin::websearch"], tools=["builtin::websearch"],
) )
user_prompts = [ user_prompts = [
"Which teams played in the NBA western conference finals of 2024. Search the web for the answer.", "Which teams played in the NBA Western Conference Finals of 2024. Search the web for the answer.",
"In which episode and season of South Park does Bill Cosby (BSM-471) first appear? Give me the number and title. Search the web for the answer.", "In which episode and season of South Park does Bill Cosby (BSM-471) first appear? Give me the number and title. Search the web for the answer.",
"What is the British-American kickboxer Andrew Tate's kickboxing name? Search the web for the answer.", "What is the British-American kickboxer Andrew Tate's kickboxing name? Search the web for the answer.",
] ]


@ -33,6 +33,8 @@ Can be set to any of the following log levels:
The default global log level is `info`. `all` sets the log level for all components. The default global log level is `info`. `all` sets the log level for all components.
A user can also set `LLAMA_STACK_LOG_FILE` which will pipe the logs to the specified path as well as to the terminal. An example would be: `export LLAMA_STACK_LOG_FILE=server.log`
### Llama Stack Build ### Llama Stack Build
In order to build your own distribution, we recommend you clone the `llama-stack` repository. In order to build your own distribution, we recommend you clone the `llama-stack` repository.


@ -40,7 +40,6 @@ The following models are available by default:
- `accounts/fireworks/models/llama-v3p1-8b-instruct (aliases: meta-llama/Llama-3.1-8B-Instruct)` - `accounts/fireworks/models/llama-v3p1-8b-instruct (aliases: meta-llama/Llama-3.1-8B-Instruct)`
- `accounts/fireworks/models/llama-v3p1-70b-instruct (aliases: meta-llama/Llama-3.1-70B-Instruct)` - `accounts/fireworks/models/llama-v3p1-70b-instruct (aliases: meta-llama/Llama-3.1-70B-Instruct)`
- `accounts/fireworks/models/llama-v3p1-405b-instruct (aliases: meta-llama/Llama-3.1-405B-Instruct-FP8)` - `accounts/fireworks/models/llama-v3p1-405b-instruct (aliases: meta-llama/Llama-3.1-405B-Instruct-FP8)`
- `accounts/fireworks/models/llama-v3p2-1b-instruct (aliases: meta-llama/Llama-3.2-1B-Instruct)`
- `accounts/fireworks/models/llama-v3p2-3b-instruct (aliases: meta-llama/Llama-3.2-3B-Instruct)` - `accounts/fireworks/models/llama-v3p2-3b-instruct (aliases: meta-llama/Llama-3.2-3B-Instruct)`
- `accounts/fireworks/models/llama-v3p2-11b-vision-instruct (aliases: meta-llama/Llama-3.2-11B-Vision-Instruct)` - `accounts/fireworks/models/llama-v3p2-11b-vision-instruct (aliases: meta-llama/Llama-3.2-11B-Vision-Instruct)`
- `accounts/fireworks/models/llama-v3p2-90b-vision-instruct (aliases: meta-llama/Llama-3.2-90B-Vision-Instruct)` - `accounts/fireworks/models/llama-v3p2-90b-vision-instruct (aliases: meta-llama/Llama-3.2-90B-Vision-Instruct)`


@ -23,7 +23,7 @@ The `llamastack/distribution-ollama` distribution consists of the following prov
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` | | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| telemetry | `inline::meta-reference` | | telemetry | `inline::meta-reference` |
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol`, `remote::wolfram-alpha` | | tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol`, `remote::wolfram-alpha` |
| vector_io | `inline::sqlite-vec`, `remote::chromadb`, `remote::pgvector` | | vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
You should use this distribution if you have a regular desktop machine without very powerful GPUs. Of course, if you have powerful GPUs, you can still continue using this distribution since Ollama supports GPU acceleration. You should use this distribution if you have a regular desktop machine without very powerful GPUs. Of course, if you have powerful GPUs, you can still continue using this distribution since Ollama supports GPU acceleration.
@ -130,7 +130,7 @@ llama stack run ./run-with-safety.yaml \
### (Optional) Update Model Serving Configuration ### (Optional) Update Model Serving Configuration
```{note} ```{note}
Please check the [model_entries](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/inference/ollama/ollama.py#L45) for the supported Ollama models. Please check the [model_entries](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/inference/ollama/models.py) for the supported Ollama models.
``` ```
To serve a new model with `ollama` To serve a new model with `ollama`


@ -1,6 +1,6 @@
# llama (server-side) CLI Reference # llama (server-side) CLI Reference
The `llama` CLI tool helps you setup and use the Llama Stack. It should be available on your path after installing the `llama-stack` package. The `llama` CLI tool helps you set up and use the Llama Stack. The CLI is available on your path after installing the `llama-stack` package.
## Installation ## Installation
@ -27,9 +27,9 @@ You have two ways to install Llama Stack:
## `llama` subcommands ## `llama` subcommands
1. `download`: `llama` cli tools supports downloading the model from Meta or Hugging Face. 1. `download`: Supports downloading models from Meta or Hugging Face. [Downloading models](#downloading-models)
2. `model`: Lists available models and their properties. 2. `model`: Lists available models and their properties. [Understanding models](#understand-the-models)
3. `stack`: Allows you to build and run a Llama Stack server. You can read more about this [here](../../distributions/building_distro). 3. `stack`: Allows you to build a stack using the `llama stack` distribution and run a Llama Stack server. You can read more about how to build a Llama Stack distribution in the [Build your own Distribution](../../distributions/building_distro) documentation.
### Sample Usage ### Sample Usage
@ -117,7 +117,7 @@ You should see a table like this:
+----------------------------------+------------------------------------------+----------------+ +----------------------------------+------------------------------------------+----------------+
``` ```
To download models, you can use the llama download command. To download models, you can use the `llama download` command.
### Downloading from [Meta](https://llama.meta.com/llama-downloads/) ### Downloading from [Meta](https://llama.meta.com/llama-downloads/)
@ -191,7 +191,7 @@ You should see a table like this:
The `llama model` command helps you explore the models interface. The `llama model` command helps you explore the models interface.
1. `download`: Download the model from different sources. (meta, huggingface) 1. `download`: Download the model from different sources. (meta, huggingface)
2. `list`: Lists all the models available for download with hardware requirements to deploy the models. 2. `list`: Lists all the models available for download with hardware requirements for deploying the models.
3. `prompt-format`: Show llama model message formats. 3. `prompt-format`: Show llama model message formats.
4. `describe`: Describes all the properties of the model. 4. `describe`: Describes all the properties of the model.
@ -262,13 +262,12 @@ llama model prompt-format -m Llama3.2-3B-Instruct
![alt text](../../../resources/prompt-format.png) ![alt text](../../../resources/prompt-format.png)
You will be shown a Markdown formatted description of the model interface and how prompts / messages are formatted for various scenarios. You will be shown a Markdown formatted description of the model interface and how prompts / messages are formatted for various scenarios.
**NOTE**: Outputs in terminal are color printed to show special tokens. **NOTE**: Outputs in terminal are color printed to show special tokens.
### Remove model ### Remove model
You can run `llama model remove` to remove unecessary model: You can run `llama model remove` to remove an unnecessary model:
``` ```
llama model remove -m Llama-Guard-3-8B-int8 llama model remove -m Llama-Guard-3-8B-int8


@ -40,7 +40,7 @@ If you're looking for more specific topics, we have a [Zero to Hero Guide](#next
ollama run llama3.2:3b-instruct-fp16 --keepalive -1m ollama run llama3.2:3b-instruct-fp16 --keepalive -1m
``` ```
**Note**: **Note**:
- The supported models for llama stack for now is listed in [here](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/inference/ollama/ollama.py#L43) - The supported models for llama stack for now is listed in [here](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/inference/ollama/models.py)
- `keepalive -1m` is used so that ollama continues to keep the model in memory indefinitely. Otherwise, ollama frees up memory and you would have to run `ollama run` again. - `keepalive -1m` is used so that ollama continues to keep the model in memory indefinitely. Otherwise, ollama frees up memory and you would have to run `ollama run` again.
--- ---


@ -234,6 +234,23 @@ class AgentConfig(AgentConfigCommon):
response_format: Optional[ResponseFormat] = None response_format: Optional[ResponseFormat] = None
@json_schema_type
class Agent(BaseModel):
agent_id: str
agent_config: AgentConfig
created_at: datetime
@json_schema_type
class ListAgentsResponse(BaseModel):
data: List[Agent]
@json_schema_type
class ListAgentSessionsResponse(BaseModel):
data: List[Session]
class AgentConfigOverridablePerTurn(AgentConfigCommon): class AgentConfigOverridablePerTurn(AgentConfigCommon):
instructions: Optional[str] = None instructions: Optional[str] = None
@ -353,7 +370,7 @@ class AgentTurnResumeRequest(BaseModel):
agent_id: str agent_id: str
session_id: str session_id: str
turn_id: str turn_id: str
tool_responses: Union[List[ToolResponse], List[ToolResponseMessage]] tool_responses: List[ToolResponse]
stream: Optional[bool] = False stream: Optional[bool] = False
@ -432,7 +449,7 @@ class Agents(Protocol):
agent_id: str, agent_id: str,
session_id: str, session_id: str,
turn_id: str, turn_id: str,
tool_responses: Union[List[ToolResponse], List[ToolResponseMessage]], tool_responses: List[ToolResponse],
stream: Optional[bool] = False, stream: Optional[bool] = False,
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]:
"""Resume an agent turn with executed tool call responses. """Resume an agent turn with executed tool call responses.
@ -443,7 +460,6 @@ class Agents(Protocol):
:param session_id: The ID of the session to resume. :param session_id: The ID of the session to resume.
:param turn_id: The ID of the turn to resume. :param turn_id: The ID of the turn to resume.
:param tool_responses: The tool call responses to resume the turn with. :param tool_responses: The tool call responses to resume the turn with.
NOTE: ToolResponseMessage will be deprecated. Use ToolResponse.
:param stream: Whether to stream the response. :param stream: Whether to stream the response.
:returns: A Turn object if stream is False, otherwise an AsyncIterator of AgentTurnResponseStreamChunk objects. :returns: A Turn object if stream is False, otherwise an AsyncIterator of AgentTurnResponseStreamChunk objects.
""" """
@ -541,3 +557,32 @@ class Agents(Protocol):
:param agent_id: The ID of the agent to delete. :param agent_id: The ID of the agent to delete.
""" """
... ...
@webmethod(route="/agents", method="GET")
async def list_agents(self) -> ListAgentsResponse:
"""List all agents.
:returns: A ListAgentsResponse.
"""
...
@webmethod(route="/agents/{agent_id}", method="GET")
async def get_agent(self, agent_id: str) -> Agent:
"""Describe an agent by its ID.
:param agent_id: ID of the agent.
:returns: An Agent of the agent.
"""
...
@webmethod(route="/agents/{agent_id}/sessions", method="GET")
async def list_agent_sessions(
self,
agent_id: str,
) -> ListAgentSessionsResponse:
"""List all session(s) of a given agent.
:param agent_id: The ID of the agent to list sessions for.
:returns: A ListAgentSessionsResponse.
"""
...
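The protocol additions above are read-only, so a provider mainly needs a way to look agents and their sessions back up. A hypothetical in-memory sketch of the three new methods, assuming `Agent`, `Session`, and the list-response models are importable from `llama_stack.apis.agents`; the storage dicts are illustrative only, not the actual meta-reference implementation:

```python
# Hypothetical in-memory sketch of the three new read-only protocol methods.
# Assumes Agent / Session and the list-response models are importable from
# llama_stack.apis.agents; the storage dicts are illustrative only.
from llama_stack.apis.agents import (
    Agent,
    ListAgentSessionsResponse,
    ListAgentsResponse,
    Session,
)


class InMemoryAgentCatalog:
    """Toy backing store that satisfies list_agents / get_agent / list_agent_sessions."""

    def __init__(self) -> None:
        self._agents: dict[str, Agent] = {}
        self._sessions: dict[str, list[Session]] = {}

    async def list_agents(self) -> ListAgentsResponse:
        # GET /v1/agents
        return ListAgentsResponse(data=list(self._agents.values()))

    async def get_agent(self, agent_id: str) -> Agent:
        # GET /v1/agents/{agent_id}
        return self._agents[agent_id]

    async def list_agent_sessions(self, agent_id: str) -> ListAgentSessionsResponse:
        # GET /v1/agents/{agent_id}/sessions
        return ListAgentSessionsResponse(data=self._sessions.get(agent_id, []))
```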


@ -285,7 +285,7 @@ class CompletionRequest(BaseModel):
@json_schema_type @json_schema_type
class CompletionResponse(BaseModel): class CompletionResponse(MetricResponseMixin):
"""Response from a completion request. """Response from a completion request.
:param content: The generated completion text :param content: The generated completion text
@ -299,7 +299,7 @@ class CompletionResponse(BaseModel):
@json_schema_type @json_schema_type
class CompletionResponseStreamChunk(BaseModel): class CompletionResponseStreamChunk(MetricResponseMixin):
"""A chunk of a streamed completion response. """A chunk of a streamed completion response.
:param delta: New content generated since last chunk. This can be one or more tokens. :param delta: New content generated since last chunk. This can be one or more tokens.
@ -368,7 +368,7 @@ class ChatCompletionRequest(BaseModel):
@json_schema_type @json_schema_type
class ChatCompletionResponseStreamChunk(MetricResponseMixin, BaseModel): class ChatCompletionResponseStreamChunk(MetricResponseMixin):
"""A chunk of a streamed chat completion response. """A chunk of a streamed chat completion response.
:param event: The event containing the new content :param event: The event containing the new content
@ -378,7 +378,7 @@ class ChatCompletionResponseStreamChunk(MetricResponseMixin, BaseModel):
@json_schema_type @json_schema_type
class ChatCompletionResponse(MetricResponseMixin, BaseModel): class ChatCompletionResponse(MetricResponseMixin):
"""Response from a chat completion request. """Response from a chat completion request.
:param completion_message: The complete response message :param completion_message: The complete response message
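With `CompletionResponse`, `CompletionResponseStreamChunk`, and the chat-completion responses now built on `MetricResponseMixin`, the `InferenceRouter` change later in this commit can attach token-usage `MetricEvent`s to every response. A small hedged helper for reading them; the `metric`, `value`, and `unit` field names follow the `MetricEvent` construction shown in that router diff:

```python
# Sketch: collecting the token-usage metrics that InferenceRouter attaches to
# responses built on MetricResponseMixin. Field names (metric, value, unit)
# mirror the MetricEvent construction in the router diff later in this commit.
from typing import Any, Dict


def summarize_token_usage(response: Any) -> Dict[str, float]:
    """Return e.g. {"prompt_tokens": 42, "completion_tokens": 128, "total_tokens": 170}."""
    usage: Dict[str, float] = {}
    for event in getattr(response, "metrics", None) or []:
        usage[event.metric] = event.value
    return usage
```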


@ -39,7 +39,7 @@ from llama_stack.distribution.resolver import InvalidProviderError
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.distribution.utils.dynamic import instantiate_class_type from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.distribution.utils.exec import formulate_run_args, run_with_pty from llama_stack.distribution.utils.exec import formulate_run_args, run_with_pty
from llama_stack.distribution.utils.image_types import ImageType from llama_stack.distribution.utils.image_types import LlamaStackImageType
from llama_stack.providers.datatypes import Api from llama_stack.providers.datatypes import Api
TEMPLATES_PATH = Path(__file__).parent.parent.parent / "templates" TEMPLATES_PATH = Path(__file__).parent.parent.parent / "templates"
@ -170,7 +170,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
) )
sys.exit(1) sys.exit(1)
if build_config.image_type == ImageType.container.value and not args.image_name: if build_config.image_type == LlamaStackImageType.CONTAINER.value and not args.image_name:
cprint( cprint(
"Please specify --image-name when building a container from a config file", "Please specify --image-name when building a container from a config file",
color="red", color="red",
@ -226,7 +226,7 @@ def _generate_run_config(
""" """
apis = list(build_config.distribution_spec.providers.keys()) apis = list(build_config.distribution_spec.providers.keys())
run_config = StackRunConfig( run_config = StackRunConfig(
container_image=(image_name if build_config.image_type == ImageType.container.value else None), container_image=(image_name if build_config.image_type == LlamaStackImageType.CONTAINER.value else None),
image_name=image_name, image_name=image_name,
apis=apis, apis=apis,
providers={}, providers={},
@ -279,16 +279,16 @@ def _run_stack_build_command_from_build_config(
template_name: Optional[str] = None, template_name: Optional[str] = None,
config_path: Optional[str] = None, config_path: Optional[str] = None,
) -> str: ) -> str:
if build_config.image_type == ImageType.container.value: if build_config.image_type == LlamaStackImageType.CONTAINER.value:
if template_name: if template_name:
image_name = f"distribution-{template_name}" image_name = f"distribution-{template_name}"
else: else:
if not image_name: if not image_name:
raise ValueError("Please specify an image name when building a container image without a template") raise ValueError("Please specify an image name when building a container image without a template")
elif build_config.image_type == ImageType.conda.value: elif build_config.image_type == LlamaStackImageType.CONDA.value:
if not image_name: if not image_name:
raise ValueError("Please specify an image name when building a conda image") raise ValueError("Please specify an image name when building a conda image")
elif build_config.image_type == ImageType.venv.value: elif build_config.image_type == LlamaStackImageType.VENV.value:
if not image_name and os.environ.get("UV_SYSTEM_PYTHON"): if not image_name and os.environ.get("UV_SYSTEM_PYTHON"):
image_name = "__system__" image_name = "__system__"
if not image_name: if not image_name:


@ -16,7 +16,7 @@ from termcolor import cprint
from llama_stack.distribution.datatypes import BuildConfig, Provider from llama_stack.distribution.datatypes import BuildConfig, Provider
from llama_stack.distribution.distribution import get_provider_registry from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.utils.exec import run_command, run_with_pty from llama_stack.distribution.utils.exec import run_command, run_with_pty
from llama_stack.distribution.utils.image_types import ImageType from llama_stack.distribution.utils.image_types import LlamaStackImageType
from llama_stack.providers.datatypes import Api from llama_stack.providers.datatypes import Api
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -95,7 +95,7 @@ def build_image(
normal_deps, special_deps = get_provider_dependencies(build_config.distribution_spec.providers) normal_deps, special_deps = get_provider_dependencies(build_config.distribution_spec.providers)
normal_deps += SERVER_DEPENDENCIES normal_deps += SERVER_DEPENDENCIES
if build_config.image_type == ImageType.container.value: if build_config.image_type == LlamaStackImageType.CONTAINER.value:
script = str(importlib.resources.files("llama_stack") / "distribution/build_container.sh") script = str(importlib.resources.files("llama_stack") / "distribution/build_container.sh")
args = [ args = [
script, script,
@ -104,7 +104,7 @@ def build_image(
container_base, container_base,
" ".join(normal_deps), " ".join(normal_deps),
] ]
elif build_config.image_type == ImageType.conda.value: elif build_config.image_type == LlamaStackImageType.CONDA.value:
script = str(importlib.resources.files("llama_stack") / "distribution/build_conda_env.sh") script = str(importlib.resources.files("llama_stack") / "distribution/build_conda_env.sh")
args = [ args = [
script, script,
@ -112,7 +112,7 @@ def build_image(
str(build_file_path), str(build_file_path),
" ".join(normal_deps), " ".join(normal_deps),
] ]
elif build_config.image_type == ImageType.venv.value: elif build_config.image_type == LlamaStackImageType.VENV.value:
script = str(importlib.resources.files("llama_stack") / "distribution/build_venv.sh") script = str(importlib.resources.files("llama_stack") / "distribution/build_venv.sh")
args = [ args = [
script, script,


@ -33,7 +33,7 @@ from llama_stack.distribution.build import print_pip_install_help
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.datatypes import Api from llama_stack.distribution.datatypes import Api
from llama_stack.distribution.request_headers import ( from llama_stack.distribution.request_headers import (
preserve_headers_context_async_generator, PROVIDER_DATA_VAR,
request_provider_data_context, request_provider_data_context,
) )
from llama_stack.distribution.resolver import ProviderRegistry from llama_stack.distribution.resolver import ProviderRegistry
@ -44,8 +44,10 @@ from llama_stack.distribution.stack import (
redact_sensitive_fields, redact_sensitive_fields,
replace_env_vars, replace_env_vars,
) )
from llama_stack.distribution.utils.context import preserve_contexts_async_generator
from llama_stack.distribution.utils.exec import in_notebook from llama_stack.distribution.utils.exec import in_notebook
from llama_stack.providers.utils.telemetry.tracing import ( from llama_stack.providers.utils.telemetry.tracing import (
CURRENT_TRACE_CONTEXT,
end_trace, end_trace,
setup_logger, setup_logger,
start_trace, start_trace,
@ -384,8 +386,8 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
finally: finally:
await end_trace() await end_trace()
# Wrap the generator to preserve context across iterations wrapped_gen = preserve_contexts_async_generator(gen(), [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR])
wrapped_gen = preserve_headers_context_async_generator(gen())
mock_response = httpx.Response( mock_response = httpx.Response(
status_code=httpx.codes.OK, status_code=httpx.codes.OK,
content=wrapped_gen, content=wrapped_gen,


@ -7,14 +7,14 @@
import contextvars import contextvars
import json import json
import logging import logging
from typing import Any, AsyncGenerator, ContextManager, Dict, Optional, TypeVar from typing import Any, ContextManager, Dict, Optional
from .utils.dynamic import instantiate_class_type from .utils.dynamic import instantiate_class_type
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
# Context variable for request provider data # Context variable for request provider data
_provider_data_var = contextvars.ContextVar("provider_data", default=None) PROVIDER_DATA_VAR = contextvars.ContextVar("provider_data", default=None)
class RequestProviderDataContext(ContextManager): class RequestProviderDataContext(ContextManager):
@ -26,40 +26,13 @@ class RequestProviderDataContext(ContextManager):
def __enter__(self): def __enter__(self):
# Save the current value and set the new one # Save the current value and set the new one
self.token = _provider_data_var.set(self.provider_data) self.token = PROVIDER_DATA_VAR.set(self.provider_data)
return self return self
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):
# Restore the previous value # Restore the previous value
if self.token is not None: if self.token is not None:
_provider_data_var.reset(self.token) PROVIDER_DATA_VAR.reset(self.token)
T = TypeVar("T")
def preserve_headers_context_async_generator(gen: AsyncGenerator[T, None]) -> AsyncGenerator[T, None]:
"""
Wraps an async generator to preserve request headers context variables across iterations.
This ensures that context variables set during generator creation are
available during each iteration of the generator, even if the original
context manager has exited.
"""
# Capture the current context value right now
context_value = _provider_data_var.get()
async def wrapper():
while True:
# Set context before each anext() call
_ = _provider_data_var.set(context_value)
try:
item = await gen.__anext__()
yield item
except StopAsyncIteration:
break
return wrapper()
class NeedsRequestProviderData: class NeedsRequestProviderData:
@ -72,7 +45,7 @@ class NeedsRequestProviderData:
if not validator_class: if not validator_class:
raise ValueError(f"Provider {provider_type} does not have a validator") raise ValueError(f"Provider {provider_type} does not have a validator")
val = _provider_data_var.get() val = PROVIDER_DATA_VAR.get()
if not val: if not val:
return None return None
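The change above promotes the per-request provider-data holder from the private `_provider_data_var` to the module-level `PROVIDER_DATA_VAR`, so it can be handed to the generic context-preserving wrapper introduced later in this commit. A minimal sketch of the set/reset pattern that `RequestProviderDataContext` implements; the payload dict is a placeholder, not a real provider-data schema:

```python
# Minimal sketch of the set/reset pattern that RequestProviderDataContext wraps
# around the now-public PROVIDER_DATA_VAR. The payload dict is a placeholder.
from llama_stack.distribution.request_headers import PROVIDER_DATA_VAR

token = PROVIDER_DATA_VAR.set({"example_api_key": "..."})  # placeholder provider data
try:
    assert PROVIDER_DATA_VAR.get() == {"example_api_key": "..."}
finally:
    PROVIDER_DATA_VAR.reset(token)  # restore the previous value, as __exit__ does
```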


@ -165,7 +165,9 @@ def specs_for_autorouted_apis(apis_to_serve: List[str] | Set[str]) -> Dict[str,
module="llama_stack.distribution.routers", module="llama_stack.distribution.routers",
routing_table_api=info.routing_table_api, routing_table_api=info.routing_table_api,
api_dependencies=[info.routing_table_api], api_dependencies=[info.routing_table_api],
deps__=[info.routing_table_api.value], # Add telemetry as an optional dependency to all auto-routed providers
optional_api_dependencies=[Api.telemetry],
deps__=([info.routing_table_api.value, Api.telemetry.value]),
), ),
) )
} }


@ -45,7 +45,7 @@ async def get_routing_table_impl(
return impl return impl
async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) -> Any: async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: Dict[str, Any]) -> Any:
from .routers import ( from .routers import (
DatasetIORouter, DatasetIORouter,
EvalRouter, EvalRouter,
@ -65,9 +65,17 @@ async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) ->
"eval": EvalRouter, "eval": EvalRouter,
"tool_runtime": ToolRuntimeRouter, "tool_runtime": ToolRuntimeRouter,
} }
api_to_deps = {
"inference": {"telemetry": Api.telemetry},
}
if api.value not in api_to_routers: if api.value not in api_to_routers:
raise ValueError(f"API {api.value} not found in router map") raise ValueError(f"API {api.value} not found in router map")
impl = api_to_routers[api.value](routing_table) api_to_dep_impl = {}
for dep_name, dep_api in api_to_deps.get(api.value, {}).items():
if dep_api in deps:
api_to_dep_impl[dep_name] = deps[dep_api]
impl = api_to_routers[api.value](routing_table, **api_to_dep_impl)
await impl.initialize() await impl.initialize()
return impl return impl


@ -4,7 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Any, AsyncGenerator, Dict, List, Optional import time
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
from llama_stack.apis.common.content_types import ( from llama_stack.apis.common.content_types import (
URL, URL,
@ -20,6 +21,10 @@ from llama_stack.apis.eval import (
JobStatus, JobStatus,
) )
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
ChatCompletionResponse,
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
CompletionMessage,
EmbeddingsResponse, EmbeddingsResponse,
EmbeddingTaskType, EmbeddingTaskType,
Inference, Inference,
@ -27,13 +32,14 @@ from llama_stack.apis.inference import (
Message, Message,
ResponseFormat, ResponseFormat,
SamplingParams, SamplingParams,
StopReason,
TextTruncation, TextTruncation,
ToolChoice, ToolChoice,
ToolConfig, ToolConfig,
ToolDefinition, ToolDefinition,
ToolPromptFormat, ToolPromptFormat,
) )
from llama_stack.apis.models import ModelType from llama_stack.apis.models import Model, ModelType
from llama_stack.apis.safety import RunShieldResponse, Safety from llama_stack.apis.safety import RunShieldResponse, Safety
from llama_stack.apis.scoring import ( from llama_stack.apis.scoring import (
ScoreBatchResponse, ScoreBatchResponse,
@ -42,6 +48,7 @@ from llama_stack.apis.scoring import (
ScoringFnParams, ScoringFnParams,
) )
from llama_stack.apis.shields import Shield from llama_stack.apis.shields import Shield
from llama_stack.apis.telemetry import MetricEvent, Telemetry
from llama_stack.apis.tools import ( from llama_stack.apis.tools import (
RAGDocument, RAGDocument,
RAGQueryConfig, RAGQueryConfig,
@ -52,7 +59,10 @@ from llama_stack.apis.tools import (
) )
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.log import get_logger from llama_stack.log import get_logger
from llama_stack.models.llama.llama3.chat_format import ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.providers.datatypes import RoutingTable from llama_stack.providers.datatypes import RoutingTable
from llama_stack.providers.utils.telemetry.tracing import get_current_span
logger = get_logger(name=__name__, category="core") logger = get_logger(name=__name__, category="core")
@ -119,9 +129,14 @@ class InferenceRouter(Inference):
def __init__( def __init__(
self, self,
routing_table: RoutingTable, routing_table: RoutingTable,
telemetry: Optional[Telemetry] = None,
) -> None: ) -> None:
logger.debug("Initializing InferenceRouter") logger.debug("Initializing InferenceRouter")
self.routing_table = routing_table self.routing_table = routing_table
self.telemetry = telemetry
if self.telemetry:
self.tokenizer = Tokenizer.get_instance()
self.formatter = ChatFormat(self.tokenizer)
async def initialize(self) -> None: async def initialize(self) -> None:
logger.debug("InferenceRouter.initialize") logger.debug("InferenceRouter.initialize")
@ -144,6 +159,71 @@ class InferenceRouter(Inference):
) )
await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type) await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
def _construct_metrics(
self, prompt_tokens: int, completion_tokens: int, total_tokens: int, model: Model
) -> List[MetricEvent]:
"""Constructs a list of MetricEvent objects containing token usage metrics.
Args:
prompt_tokens: Number of tokens in the prompt
completion_tokens: Number of tokens in the completion
total_tokens: Total number of tokens used
model: Model object containing model_id and provider_id
Returns:
List of MetricEvent objects with token usage metrics
"""
span = get_current_span()
if span is None:
logger.warning("No span found for token usage metrics")
return []
metrics = [
("prompt_tokens", prompt_tokens),
("completion_tokens", completion_tokens),
("total_tokens", total_tokens),
]
metric_events = []
for metric_name, value in metrics:
metric_events.append(
MetricEvent(
trace_id=span.trace_id,
span_id=span.span_id,
metric=metric_name,
value=value,
timestamp=time.time(),
unit="tokens",
attributes={
"model_id": model.model_id,
"provider_id": model.provider_id,
},
)
)
return metric_events
async def _compute_and_log_token_usage(
self,
prompt_tokens: int,
completion_tokens: int,
total_tokens: int,
model: Model,
) -> List[MetricEvent]:
metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
if self.telemetry:
for metric in metrics:
await self.telemetry.log_event(metric)
return metrics
async def _count_tokens(
self,
messages: List[Message] | InterleavedContent,
tool_prompt_format: Optional[ToolPromptFormat] = None,
) -> Optional[int]:
if isinstance(messages, list):
encoded = self.formatter.encode_dialog_prompt(messages, tool_prompt_format)
else:
encoded = self.formatter.encode_content(messages)
return len(encoded.tokens) if encoded and encoded.tokens else 0
async def chat_completion( async def chat_completion(
self, self,
model_id: str, model_id: str,
@ -156,8 +236,9 @@ class InferenceRouter(Inference):
stream: Optional[bool] = False, stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None, logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None, tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator: ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
logger.debug( logger.debug(
"core",
f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}", f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
) )
if sampling_params is None: if sampling_params is None:
@ -206,10 +287,47 @@ class InferenceRouter(Inference):
tool_config=tool_config, tool_config=tool_config,
) )
provider = self.routing_table.get_provider_impl(model_id) provider = self.routing_table.get_provider_impl(model_id)
prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format)
if stream: if stream:
return (chunk async for chunk in await provider.chat_completion(**params))
async def stream_generator():
completion_text = ""
async for chunk in await provider.chat_completion(**params):
if chunk.event.event_type == ChatCompletionResponseEventType.progress:
if chunk.event.delta.type == "text":
completion_text += chunk.event.delta.text
if chunk.event.event_type == ChatCompletionResponseEventType.complete:
completion_tokens = await self._count_tokens(
[CompletionMessage(content=completion_text, stop_reason=StopReason.end_of_turn)],
tool_config.tool_prompt_format,
)
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
metrics = await self._compute_and_log_token_usage(
prompt_tokens or 0,
completion_tokens or 0,
total_tokens,
model,
)
chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
yield chunk
return stream_generator()
else: else:
return await provider.chat_completion(**params) response = await provider.chat_completion(**params)
completion_tokens = await self._count_tokens(
[response.completion_message],
tool_config.tool_prompt_format,
)
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
metrics = await self._compute_and_log_token_usage(
prompt_tokens or 0,
completion_tokens or 0,
total_tokens,
model,
)
response.metrics = metrics if response.metrics is None else response.metrics + metrics
return response
async def completion( async def completion(
self, self,
@ -239,10 +357,41 @@ class InferenceRouter(Inference):
stream=stream, stream=stream,
logprobs=logprobs, logprobs=logprobs,
) )
prompt_tokens = await self._count_tokens(content)
if stream: if stream:
return (chunk async for chunk in await provider.completion(**params))
async def stream_generator():
completion_text = ""
async for chunk in await provider.completion(**params):
if hasattr(chunk, "delta"):
completion_text += chunk.delta
if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry:
completion_tokens = await self._count_tokens(completion_text)
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
metrics = await self._compute_and_log_token_usage(
prompt_tokens or 0,
completion_tokens or 0,
total_tokens,
model,
)
chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
yield chunk
return stream_generator()
else: else:
return await provider.completion(**params) response = await provider.completion(**params)
completion_tokens = await self._count_tokens(response.content)
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
metrics = await self._compute_and_log_token_usage(
prompt_tokens or 0,
completion_tokens or 0,
total_tokens,
model,
)
response.metrics = metrics if response.metrics is None else response.metrics + metrics
return response
async def embeddings( async def embeddings(
self, self,


@ -6,11 +6,9 @@
import argparse import argparse
import asyncio import asyncio
import functools
import inspect import inspect
import json import json
import os import os
import signal
import sys import sys
import traceback import traceback
import warnings import warnings
@ -30,7 +28,7 @@ from typing_extensions import Annotated
from llama_stack.distribution.datatypes import StackRunConfig from llama_stack.distribution.datatypes import StackRunConfig
from llama_stack.distribution.distribution import builtin_automatically_routed_apis from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.distribution.request_headers import ( from llama_stack.distribution.request_headers import (
preserve_headers_context_async_generator, PROVIDER_DATA_VAR,
request_provider_data_context, request_provider_data_context,
) )
from llama_stack.distribution.resolver import InvalidProviderError from llama_stack.distribution.resolver import InvalidProviderError
@ -40,6 +38,7 @@ from llama_stack.distribution.stack import (
replace_env_vars, replace_env_vars,
validate_env_pair, validate_env_pair,
) )
from llama_stack.distribution.utils.context import preserve_contexts_async_generator
from llama_stack.log import get_logger from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api from llama_stack.providers.datatypes import Api
from llama_stack.providers.inline.telemetry.meta_reference.config import TelemetryConfig from llama_stack.providers.inline.telemetry.meta_reference.config import TelemetryConfig
@ -47,6 +46,7 @@ from llama_stack.providers.inline.telemetry.meta_reference.telemetry import (
TelemetryAdapter, TelemetryAdapter,
) )
from llama_stack.providers.utils.telemetry.tracing import ( from llama_stack.providers.utils.telemetry.tracing import (
CURRENT_TRACE_CONTEXT,
end_trace, end_trace,
setup_logger, setup_logger,
start_trace, start_trace,
@ -118,69 +118,24 @@ def translate_exception(exc: Exception) -> Union[HTTPException, RequestValidatio
) )
def handle_signal(app, signum, _) -> None: async def shutdown(app):
"""Initiate a graceful shutdown of the application.
Handled by the lifespan context manager. The shutdown process involves
shutting down all implementations registered in the application.
""" """
Handle incoming signals and initiate a graceful shutdown of the application. for impl in app.__llama_stack_impls__.values():
impl_name = impl.__class__.__name__
This function is intended to be used as a signal handler for various signals logger.info("Shutting down %s", impl_name)
(e.g., SIGINT, SIGTERM). Upon receiving a signal, it will print a message
indicating the received signal and initiate a shutdown process.
Args:
app: The application instance containing implementations to be shut down.
signum (int): The signal number received.
frame: The current stack frame (not used in this function).
The shutdown process involves:
- Shutting down all implementations registered in the application.
- Gathering all running asyncio tasks.
- Cancelling all gathered tasks.
- Waiting for all tasks to finish.
- Stopping the event loop.
Note:
This function schedules the shutdown process as an asyncio task and does
not block the current execution.
"""
signame = signal.Signals(signum).name
logger.info(f"Received signal {signame} ({signum}). Exiting gracefully...")
async def shutdown():
try: try:
# Gracefully shut down implementations if hasattr(impl, "shutdown"):
for impl in app.__llama_stack_impls__.values(): await asyncio.wait_for(impl.shutdown(), timeout=5)
impl_name = impl.__class__.__name__ else:
logger.info("Shutting down %s", impl_name) logger.warning("No shutdown method for %s", impl_name)
try: except asyncio.TimeoutError:
if hasattr(impl, "shutdown"): logger.exception("Shutdown timeout for %s ", impl_name, exc_info=True)
await asyncio.wait_for(impl.shutdown(), timeout=5) except (Exception, asyncio.CancelledError) as e:
else: logger.exception("Failed to shutdown %s: %s", impl_name, {e})
logger.warning("No shutdown method for %s", impl_name)
except asyncio.TimeoutError:
logger.exception("Shutdown timeout for %s ", impl_name, exc_info=True)
except Exception as e:
logger.exception("Failed to shutdown %s: %s", impl_name, {e})
# Gather all running tasks
loop = asyncio.get_running_loop()
tasks = [task for task in asyncio.all_tasks(loop) if task is not asyncio.current_task()]
# Cancel all tasks
for task in tasks:
task.cancel()
# Wait for all tasks to finish
try:
await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), timeout=10)
except asyncio.TimeoutError:
logger.exception("Timeout while waiting for tasks to finish")
except asyncio.CancelledError:
pass
finally:
loop.stop()
loop = asyncio.get_running_loop()
loop.create_task(shutdown())
@asynccontextmanager @asynccontextmanager
@ -188,8 +143,7 @@ async def lifespan(app: FastAPI):
logger.info("Starting up") logger.info("Starting up")
yield yield
logger.info("Shutting down") logger.info("Shutting down")
for impl in app.__llama_stack_impls__.values(): await shutdown(app)
await impl.shutdown()
def is_streaming_request(func_name: str, request: Request, **kwargs): def is_streaming_request(func_name: str, request: Request, **kwargs):
@ -230,13 +184,15 @@ def create_dynamic_typed_route(func: Any, method: str, route: str):
try: try:
if is_streaming: if is_streaming:
gen = preserve_headers_context_async_generator(sse_generator(func(**kwargs))) gen = preserve_contexts_async_generator(
sse_generator(func(**kwargs)), [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR]
)
return StreamingResponse(gen, media_type="text/event-stream") return StreamingResponse(gen, media_type="text/event-stream")
else: else:
value = func(**kwargs) value = func(**kwargs)
return await maybe_await(value) return await maybe_await(value)
except Exception as e: except Exception as e:
logger.exception("Error executing endpoint %s", method, route) logger.exception(f"Error executing endpoint {route=} {method=}")
raise translate_exception(e) from e raise translate_exception(e) from e
sig = inspect.signature(func) sig = inspect.signature(func)
@ -266,7 +222,7 @@ class TracingMiddleware:
self.app = app self.app = app
async def __call__(self, scope, receive, send): async def __call__(self, scope, receive, send):
path = scope["path"] path = scope.get("path", "")
await start_trace(path, {"__location__": "server"}) await start_trace(path, {"__location__": "server"})
try: try:
return await self.app(scope, receive, send) return await self.app(scope, receive, send)
@ -439,8 +395,6 @@ def main():
app.exception_handler(RequestValidationError)(global_exception_handler) app.exception_handler(RequestValidationError)(global_exception_handler)
app.exception_handler(Exception)(global_exception_handler) app.exception_handler(Exception)(global_exception_handler)
signal.signal(signal.SIGINT, functools.partial(handle_signal, app))
signal.signal(signal.SIGTERM, functools.partial(handle_signal, app))
app.__llama_stack_impls__ = impls app.__llama_stack_impls__ = impls
@ -471,6 +425,8 @@ def main():
"app": app, "app": app,
"host": listen_host, "host": listen_host,
"port": port, "port": port,
"lifespan": "on",
"log_level": logger.getEffectiveLevel(),
} }
if ssl_config: if ssl_config:
uvicorn_config.update(ssl_config) uvicorn_config.update(ssl_config)


@ -0,0 +1,33 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from contextvars import ContextVar
from typing import AsyncGenerator, List, TypeVar
T = TypeVar("T")
def preserve_contexts_async_generator(
gen: AsyncGenerator[T, None], context_vars: List[ContextVar]
) -> AsyncGenerator[T, None]:
"""
Wraps an async generator to preserve context variables across iterations.
This is needed because we start a new asyncio event loop for each streaming request,
and we need to preserve the context across the event loop boundary.
"""
async def wrapper():
while True:
try:
item = await gen.__anext__()
context_values = {context_var.name: context_var.get() for context_var in context_vars}
yield item
for context_var in context_vars:
_ = context_var.set(context_values[context_var.name])
except StopAsyncIteration:
break
return wrapper()
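
For orientation, a minimal usage sketch of the helper above; the context variable and generator are illustrative, and the import path is taken from the unit tests added later in this change:

```python
import asyncio
from contextvars import ContextVar

from llama_stack.distribution.utils.context import preserve_contexts_async_generator

request_id: ContextVar[str] = ContextVar("request_id", default="unset")


async def token_stream():
    # Stand-in for an SSE generator produced by a streaming endpoint.
    for tok in ("hello", "world"):
        yield tok


async def main():
    request_id.set("req-123")
    # The wrapper captures request_id after each step of the underlying
    # generator and re-applies it before the next one.
    wrapped = preserve_contexts_async_generator(token_stream(), [request_id])
    async for tok in wrapped:
        print(tok, request_id.get())  # hello req-123 / world req-123


asyncio.run(main())
```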

View file

@ -20,14 +20,14 @@ import importlib
import json import json
from pathlib import Path from pathlib import Path
from llama_stack.distribution.utils.image_types import ImageType from llama_stack.distribution.utils.image_types import LlamaStackImageType
def formulate_run_args(image_type, image_name, config, template_name) -> list: def formulate_run_args(image_type, image_name, config, template_name) -> list:
env_name = "" env_name = ""
if image_type == ImageType.container.value or config.container_image: if image_type == LlamaStackImageType.CONTAINER.value or config.container_image:
env_name = f"distribution-{template_name}" if template_name else config.container_image env_name = f"distribution-{template_name}" if template_name else config.container_image
elif image_type == ImageType.conda.value: elif image_type == LlamaStackImageType.CONDA.value:
current_conda_env = os.environ.get("CONDA_DEFAULT_ENV") current_conda_env = os.environ.get("CONDA_DEFAULT_ENV")
env_name = image_name or current_conda_env env_name = image_name or current_conda_env
if not env_name: if not env_name:

View file

@ -4,10 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from enum import Enum import enum
class ImageType(Enum): class LlamaStackImageType(enum.Enum):
container = "container" CONTAINER = "container"
conda = "conda" CONDA = "conda"
venv = "venv" VENV = "venv"

View file

@ -0,0 +1,155 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
from concurrent.futures import ThreadPoolExecutor
from contextvars import ContextVar
import pytest
from llama_stack.distribution.utils.context import preserve_contexts_async_generator
@pytest.mark.asyncio
async def test_preserve_contexts_with_exception():
# Create context variable
context_var = ContextVar("exception_var", default="initial")
token = context_var.set("start_value")
# Create an async generator that raises an exception
async def exception_generator():
yield context_var.get()
context_var.set("modified")
raise ValueError("Test exception")
yield None # This will never be reached
# Wrap the generator
wrapped_gen = preserve_contexts_async_generator(exception_generator(), [context_var])
# First iteration should work
value = await wrapped_gen.__anext__()
assert value == "start_value"
# Second iteration should raise the exception
with pytest.raises(ValueError, match="Test exception"):
await wrapped_gen.__anext__()
# Clean up
context_var.reset(token)
@pytest.mark.asyncio
async def test_preserve_contexts_empty_generator():
# Create context variable
context_var = ContextVar("empty_var", default="initial")
token = context_var.set("value")
# Create an empty async generator
async def empty_generator():
if False: # This condition ensures the generator yields nothing
yield None
# Wrap the generator
wrapped_gen = preserve_contexts_async_generator(empty_generator(), [context_var])
# The generator should raise StopAsyncIteration immediately
with pytest.raises(StopAsyncIteration):
await wrapped_gen.__anext__()
# Context variable should remain unchanged
assert context_var.get() == "value"
# Clean up
context_var.reset(token)
@pytest.mark.asyncio
async def test_preserve_contexts_across_event_loops():
"""
Test that context variables are preserved across event loop boundaries with nested generators.
This simulates the real-world scenario where:
1. A new event loop is created for each streaming request
2. The async generator runs inside that loop
3. There are multiple levels of nested generators
4. Context needs to be preserved across these boundaries
"""
# Create context variables
request_id = ContextVar("request_id", default=None)
user_id = ContextVar("user_id", default=None)
# Set initial values
# Results container to verify values across thread boundaries
results = []
# Inner-most generator (level 2)
async def inner_generator():
# Should have the context from the outer scope
yield (1, request_id.get(), user_id.get())
# Modify one context variable
user_id.set("user-modified")
# Should reflect the modification
yield (2, request_id.get(), user_id.get())
# Middle generator (level 1)
async def middle_generator():
inner_gen = inner_generator()
# Forward the first yield from inner
item = await inner_gen.__anext__()
yield item
# Forward the second yield from inner
item = await inner_gen.__anext__()
yield item
request_id.set("req-modified")
# Add our own yield with both modified variables
yield (3, request_id.get(), user_id.get())
# Function to run in a separate thread with a new event loop
def run_in_new_loop():
# Create a new event loop for this thread
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
# Outer generator (runs in the new loop)
async def outer_generator():
request_id.set("req-12345")
user_id.set("user-6789")
# Wrap the middle generator
wrapped_gen = preserve_contexts_async_generator(middle_generator(), [request_id, user_id])
# Process all items from the middle generator
async for item in wrapped_gen:
# Store results for verification
results.append(item)
# Run the outer generator in the new loop
loop.run_until_complete(outer_generator())
finally:
loop.close()
# Run the generator chain in a separate thread with a new event loop
with ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(run_in_new_loop)
future.result() # Wait for completion
# Verify the results
assert len(results) == 3
# First yield should have original values
assert results[0] == (1, "req-12345", "user-6789")
# Second yield should have modified user_id
assert results[1] == (2, "req-12345", "user-modified")
# Third yield should have both modified values
assert results[2] == (3, "req-modified", "user-modified")

View file

@ -12,6 +12,7 @@ from typing import Dict
from rich.console import Console from rich.console import Console
from rich.errors import MarkupError from rich.errors import MarkupError
from rich.logging import RichHandler from rich.logging import RichHandler
from termcolor import cprint
# Default log level # Default log level
DEFAULT_LOG_LEVEL = logging.INFO DEFAULT_LOG_LEVEL = logging.INFO
@ -96,12 +97,13 @@ class CustomRichHandler(RichHandler):
self.markup = original_markup self.markup = original_markup
def setup_logging(category_levels: Dict[str, int]) -> None: def setup_logging(category_levels: Dict[str, int], log_file: str | None) -> None:
""" """
Configure logging based on the provided category log levels. Configure logging based on the provided category log levels and an optional log file.
Parameters: Parameters:
category_levels (Dict[str, int]): A dictionary mapping categories to their log levels. category_levels (Dict[str, int]): A dictionary mapping categories to their log levels.
log_file (str | None): Optional path to a file that logs are additionally written to
""" """
log_format = "[dim]%(asctime)s %(name)s:%(lineno)d[/] [yellow dim]%(category)s[/]: %(message)s" log_format = "[dim]%(asctime)s %(name)s:%(lineno)d[/] [yellow dim]%(category)s[/]: %(message)s"
@ -116,6 +118,28 @@ def setup_logging(category_levels: Dict[str, int]) -> None:
# Determine the root logger's level (default to WARNING if not specified) # Determine the root logger's level (default to WARNING if not specified)
root_level = category_levels.get("root", logging.WARNING) root_level = category_levels.get("root", logging.WARNING)
handlers = {
"console": {
"()": CustomRichHandler, # Use custom console handler
"formatter": "rich",
"rich_tracebacks": True,
"show_time": False,
"show_path": False,
"markup": True,
"filters": ["category_filter"],
}
}
# Add a file handler if log_file is set
if log_file:
handlers["file"] = {
"class": "logging.FileHandler",
"formatter": "rich",
"filename": log_file,
"mode": "a",
"encoding": "utf-8",
}
logging_config = { logging_config = {
"version": 1, "version": 1,
"disable_existing_loggers": False, "disable_existing_loggers": False,
@ -125,17 +149,7 @@ def setup_logging(category_levels: Dict[str, int]) -> None:
"format": log_format, "format": log_format,
} }
}, },
"handlers": { "handlers": handlers,
"console": {
"()": CustomRichHandler, # Use our custom handler class
"formatter": "rich",
"rich_tracebacks": True,
"show_time": False,
"show_path": False,
"markup": True,
"filters": ["category_filter"],
}
},
"filters": { "filters": {
"category_filter": { "category_filter": {
"()": CategoryFilter, "()": CategoryFilter,
@ -143,19 +157,24 @@ def setup_logging(category_levels: Dict[str, int]) -> None:
}, },
"loggers": { "loggers": {
category: { category: {
"handlers": ["console"], "handlers": list(handlers.keys()), # Apply all handlers
"level": category_levels.get(category, DEFAULT_LOG_LEVEL), "level": category_levels.get(category, DEFAULT_LOG_LEVEL),
"propagate": False, # Disable propagation to root logger "propagate": False, # Disable propagation to root logger
} }
for category in CATEGORIES for category in CATEGORIES
}, },
"root": { "root": {
"handlers": ["console"], "handlers": list(handlers.keys()),
"level": root_level, # Set root logger's level dynamically "level": root_level, # Set root logger's level dynamically
}, },
} }
dictConfig(logging_config) dictConfig(logging_config)
# Ensure third-party libraries follow the root log level
for _, logger in logging.root.manager.loggerDict.items():
if isinstance(logger, logging.Logger):
logger.setLevel(root_level)
def get_logger(name: str, category: str = "uncategorized") -> logging.LoggerAdapter: def get_logger(name: str, category: str = "uncategorized") -> logging.LoggerAdapter:
""" """
@ -176,7 +195,9 @@ def get_logger(name: str, category: str = "uncategorized") -> logging.LoggerAdap
env_config = os.environ.get("LLAMA_STACK_LOGGING", "") env_config = os.environ.get("LLAMA_STACK_LOGGING", "")
if env_config: if env_config:
print(f"Environment variable LLAMA_STACK_LOGGING found: {env_config}") cprint(f"Environment variable LLAMA_STACK_LOGGING found: {env_config}", "yellow")
_category_levels.update(parse_environment_config(env_config)) _category_levels.update(parse_environment_config(env_config))
setup_logging(_category_levels) log_file = os.environ.get("LLAMA_STACK_LOG_FILE")
setup_logging(_category_levels, log_file)
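
A short sketch of opting into the new file handler. The `llama_stack.log` module path, the category name, and the `LLAMA_STACK_LOGGING` value format are assumptions based on how this logger is typically imported and parsed; the file path is just an example:

```python
import os

# Both variables are read when the logging module is imported, so set them first.
os.environ["LLAMA_STACK_LOG_FILE"] = "/tmp/llama-stack.log"  # example path
os.environ["LLAMA_STACK_LOGGING"] = "server=debug"           # example category=level spec

from llama_stack.log import get_logger  # module path assumed, not shown in this diff

logger = get_logger(__name__, category="server")
logger.info("hello")  # goes to the rich console handler and is appended to the log file
```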

View file

@ -4,14 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Dict from typing import Any, Dict
from llama_stack.distribution.datatypes import Api, ProviderSpec from llama_stack.distribution.datatypes import Api
from .config import MetaReferenceAgentsImplConfig from .config import MetaReferenceAgentsImplConfig
async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: Dict[Api, ProviderSpec]): async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: Dict[Api, Any]):
from .agents import MetaReferenceAgentsImpl from .agents import MetaReferenceAgentsImpl
impl = MetaReferenceAgentsImpl( impl = MetaReferenceAgentsImpl(

View file

@ -181,7 +181,7 @@ class ChatAgent(ShieldRunnerMixin):
return messages return messages
async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator: async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
with tracing.span("create_and_execute_turn") as span: async with tracing.span("create_and_execute_turn") as span:
span.set_attribute("session_id", request.session_id) span.set_attribute("session_id", request.session_id)
span.set_attribute("agent_id", self.agent_id) span.set_attribute("agent_id", self.agent_id)
span.set_attribute("request", request.model_dump_json()) span.set_attribute("request", request.model_dump_json())
@ -191,7 +191,7 @@ class ChatAgent(ShieldRunnerMixin):
yield chunk yield chunk
async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator: async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator:
with tracing.span("resume_turn") as span: async with tracing.span("resume_turn") as span:
span.set_attribute("agent_id", self.agent_id) span.set_attribute("agent_id", self.agent_id)
span.set_attribute("session_id", request.session_id) span.set_attribute("session_id", request.session_id)
span.set_attribute("turn_id", request.turn_id) span.set_attribute("turn_id", request.turn_id)
@ -218,18 +218,10 @@ class ChatAgent(ShieldRunnerMixin):
steps = [] steps = []
messages = await self.get_messages_from_turns(turns) messages = await self.get_messages_from_turns(turns)
if is_resume: if is_resume:
if isinstance(request.tool_responses[0], ToolResponseMessage): tool_response_messages = [
tool_response_messages = request.tool_responses ToolResponseMessage(call_id=x.call_id, tool_name=x.tool_name, content=x.content)
tool_responses = [ for x in request.tool_responses
ToolResponse(call_id=x.call_id, tool_name=x.tool_name, content=x.content) ]
for x in request.tool_responses
]
else:
tool_response_messages = [
ToolResponseMessage(call_id=x.call_id, tool_name=x.tool_name, content=x.content)
for x in request.tool_responses
]
tool_responses = request.tool_responses
messages.extend(tool_response_messages) messages.extend(tool_response_messages)
last_turn = turns[-1] last_turn = turns[-1]
last_turn_messages = self.turn_to_messages(last_turn) last_turn_messages = self.turn_to_messages(last_turn)
@ -252,7 +244,7 @@ class ChatAgent(ShieldRunnerMixin):
step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())), step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())),
turn_id=request.turn_id, turn_id=request.turn_id,
tool_calls=(in_progress_tool_call_step.tool_calls if in_progress_tool_call_step else []), tool_calls=(in_progress_tool_call_step.tool_calls if in_progress_tool_call_step else []),
tool_responses=tool_responses, tool_responses=request.tool_responses,
completed_at=now, completed_at=now,
started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now), started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now),
) )
@ -390,7 +382,7 @@ class ChatAgent(ShieldRunnerMixin):
shields: List[str], shields: List[str],
touchpoint: str, touchpoint: str,
) -> AsyncGenerator: ) -> AsyncGenerator:
with tracing.span("run_shields") as span: async with tracing.span("run_shields") as span:
span.set_attribute("input", [m.model_dump_json() for m in messages]) span.set_attribute("input", [m.model_dump_json() for m in messages])
if len(shields) == 0: if len(shields) == 0:
span.set_attribute("output", "no shields") span.set_attribute("output", "no shields")
@ -508,7 +500,7 @@ class ChatAgent(ShieldRunnerMixin):
content = "" content = ""
stop_reason = None stop_reason = None
with tracing.span("inference") as span: async with tracing.span("inference") as span:
async for chunk in await self.inference_api.chat_completion( async for chunk in await self.inference_api.chat_completion(
self.agent_config.model, self.agent_config.model,
input_messages, input_messages,
@ -685,7 +677,7 @@ class ChatAgent(ShieldRunnerMixin):
tool_name = tool_call.tool_name tool_name = tool_call.tool_name
if isinstance(tool_name, BuiltinTool): if isinstance(tool_name, BuiltinTool):
tool_name = tool_name.value tool_name = tool_name.value
with tracing.span( async with tracing.span(
"tool_execution", "tool_execution",
{ {
"tool_name": tool_name, "tool_name": tool_name,

View file

@ -12,6 +12,7 @@ import uuid
from typing import AsyncGenerator, List, Optional, Union from typing import AsyncGenerator, List, Optional, Union
from llama_stack.apis.agents import ( from llama_stack.apis.agents import (
Agent,
AgentConfig, AgentConfig,
AgentCreateResponse, AgentCreateResponse,
Agents, Agents,
@ -21,6 +22,8 @@ from llama_stack.apis.agents import (
AgentTurnCreateRequest, AgentTurnCreateRequest,
AgentTurnResumeRequest, AgentTurnResumeRequest,
Document, Document,
ListAgentSessionsResponse,
ListAgentsResponse,
Session, Session,
Turn, Turn,
) )
@ -84,7 +87,7 @@ class MetaReferenceAgentsImpl(Agents):
agent_id=agent_id, agent_id=agent_id,
) )
async def get_agent(self, agent_id: str) -> ChatAgent: async def _get_agent_impl(self, agent_id: str) -> ChatAgent:
agent_config = await self.persistence_store.get( agent_config = await self.persistence_store.get(
key=f"agent:{agent_id}", key=f"agent:{agent_id}",
) )
@ -120,7 +123,7 @@ class MetaReferenceAgentsImpl(Agents):
agent_id: str, agent_id: str,
session_name: str, session_name: str,
) -> AgentSessionCreateResponse: ) -> AgentSessionCreateResponse:
agent = await self.get_agent(agent_id) agent = await self._get_agent_impl(agent_id)
session_id = await agent.create_session(session_name) session_id = await agent.create_session(session_name)
return AgentSessionCreateResponse( return AgentSessionCreateResponse(
@ -160,7 +163,7 @@ class MetaReferenceAgentsImpl(Agents):
self, self,
request: AgentTurnCreateRequest, request: AgentTurnCreateRequest,
) -> AsyncGenerator: ) -> AsyncGenerator:
agent = await self.get_agent(request.agent_id) agent = await self._get_agent_impl(request.agent_id)
async for event in agent.create_and_execute_turn(request): async for event in agent.create_and_execute_turn(request):
yield event yield event
@ -169,7 +172,7 @@ class MetaReferenceAgentsImpl(Agents):
agent_id: str, agent_id: str,
session_id: str, session_id: str,
turn_id: str, turn_id: str,
tool_responses: Union[List[ToolResponse], List[ToolResponseMessage]], tool_responses: List[ToolResponse],
stream: Optional[bool] = False, stream: Optional[bool] = False,
) -> AsyncGenerator: ) -> AsyncGenerator:
request = AgentTurnResumeRequest( request = AgentTurnResumeRequest(
@ -188,12 +191,12 @@ class MetaReferenceAgentsImpl(Agents):
self, self,
request: AgentTurnResumeRequest, request: AgentTurnResumeRequest,
) -> AsyncGenerator: ) -> AsyncGenerator:
agent = await self.get_agent(request.agent_id) agent = await self._get_agent_impl(request.agent_id)
async for event in agent.resume_turn(request): async for event in agent.resume_turn(request):
yield event yield event
async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn: async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn:
agent = await self.get_agent(agent_id) agent = await self._get_agent_impl(agent_id)
turn = await agent.storage.get_session_turn(session_id, turn_id) turn = await agent.storage.get_session_turn(session_id, turn_id)
return turn return turn
@ -210,7 +213,7 @@ class MetaReferenceAgentsImpl(Agents):
session_id: str, session_id: str,
turn_ids: Optional[List[str]] = None, turn_ids: Optional[List[str]] = None,
) -> Session: ) -> Session:
agent = await self.get_agent(agent_id) agent = await self._get_agent_impl(agent_id)
session_info = await agent.storage.get_session_info(session_id) session_info = await agent.storage.get_session_info(session_id)
if session_info is None: if session_info is None:
raise ValueError(f"Session {session_id} not found") raise ValueError(f"Session {session_id} not found")
@ -232,3 +235,15 @@ class MetaReferenceAgentsImpl(Agents):
async def shutdown(self) -> None: async def shutdown(self) -> None:
pass pass
async def list_agents(self) -> ListAgentsResponse:
pass
async def get_agent(self, agent_id: str) -> Agent:
pass
async def list_agent_sessions(
self,
agent_id: str,
) -> ListAgentSessionsResponse:
pass

View file

@ -10,6 +10,7 @@ from typing import List
from llama_stack.apis.inference import Message from llama_stack.apis.inference import Message
from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
from llama_stack.providers.utils.telemetry import tracing
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -32,15 +33,14 @@ class ShieldRunnerMixin:
self.output_shields = output_shields self.output_shields = output_shields
async def run_multiple_shields(self, messages: List[Message], identifiers: List[str]) -> None: async def run_multiple_shields(self, messages: List[Message], identifiers: List[str]) -> None:
responses = await asyncio.gather( async def run_shield_with_span(identifier: str):
*[ async with tracing.span(f"run_shield_{identifier}"):
self.safety_api.run_shield( return await self.safety_api.run_shield(
shield_id=identifier, shield_id=identifier,
messages=messages, messages=messages,
) )
for identifier in identifiers
] responses = await asyncio.gather(*[run_shield_with_span(identifier) for identifier in identifiers])
)
for identifier, response in zip(identifiers, responses, strict=False): for identifier, response in zip(identifiers, responses, strict=False):
if not response.violation: if not response.violation:
continue continue

View file

@ -4,12 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Any, Dict
from .config import LocalFSDatasetIOConfig from .config import LocalFSDatasetIOConfig
async def get_provider_impl( async def get_provider_impl(
config: LocalFSDatasetIOConfig, config: LocalFSDatasetIOConfig,
_deps, _deps: Dict[str, Any],
): ):
from .datasetio import LocalFSDatasetIOImpl from .datasetio import LocalFSDatasetIOImpl

View file

@ -172,7 +172,7 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
new_rows_df = dataset_impl._validate_dataset_schema(new_rows_df) new_rows_df = dataset_impl._validate_dataset_schema(new_rows_df)
dataset_impl.df = pandas.concat([dataset_impl.df, new_rows_df], ignore_index=True) dataset_impl.df = pandas.concat([dataset_impl.df, new_rows_df], ignore_index=True)
url = str(dataset_info.dataset_def.url) url = str(dataset_info.dataset_def.url.uri)
parsed_url = urlparse(url) parsed_url = urlparse(url)
if parsed_url.scheme == "file" or not parsed_url.scheme: if parsed_url.scheme == "file" or not parsed_url.scheme:

View file

@ -3,16 +3,16 @@
# #
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Dict from typing import Any, Dict
from llama_stack.distribution.datatypes import Api, ProviderSpec from llama_stack.distribution.datatypes import Api
from .config import MetaReferenceEvalConfig from .config import MetaReferenceEvalConfig
async def get_provider_impl( async def get_provider_impl(
config: MetaReferenceEvalConfig, config: MetaReferenceEvalConfig,
deps: Dict[Api, ProviderSpec], deps: Dict[Api, Any],
): ):
from .eval import MetaReferenceEvalImpl from .eval import MetaReferenceEvalImpl

View file

@ -4,14 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Union from typing import Any, Dict, Union
from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
async def get_provider_impl( async def get_provider_impl(
config: Union[MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig], config: Union[MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig],
_deps, _deps: Dict[str, Any],
): ):
from .inference import MetaReferenceInferenceImpl from .inference import MetaReferenceInferenceImpl

View file

@ -4,6 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Any, Dict
from llama_stack.providers.inline.inference.sentence_transformers.config import ( from llama_stack.providers.inline.inference.sentence_transformers.config import (
SentenceTransformersInferenceConfig, SentenceTransformersInferenceConfig,
) )
@ -11,7 +13,7 @@ from llama_stack.providers.inline.inference.sentence_transformers.config import
async def get_provider_impl( async def get_provider_impl(
config: SentenceTransformersInferenceConfig, config: SentenceTransformersInferenceConfig,
_deps, _deps: Dict[str, Any],
): ):
from .sentence_transformers import SentenceTransformersInferenceImpl from .sentence_transformers import SentenceTransformersInferenceImpl

View file

@ -4,12 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Any from typing import Any, Dict
from .config import VLLMConfig from .config import VLLMConfig
async def get_provider_impl(config: VLLMConfig, _deps) -> Any: async def get_provider_impl(config: VLLMConfig, _deps: Dict[str, Any]):
from .vllm import VLLMInferenceImpl from .vllm import VLLMInferenceImpl
impl = VLLMInferenceImpl(config) impl = VLLMInferenceImpl(config)

View file

@ -4,9 +4,9 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Dict from typing import Any, Dict
from llama_stack.distribution.datatypes import Api, ProviderSpec from llama_stack.distribution.datatypes import Api
from .config import TorchtunePostTrainingConfig from .config import TorchtunePostTrainingConfig
@ -15,7 +15,7 @@ from .config import TorchtunePostTrainingConfig
async def get_provider_impl( async def get_provider_impl(
config: TorchtunePostTrainingConfig, config: TorchtunePostTrainingConfig,
deps: Dict[Api, ProviderSpec], deps: Dict[Api, Any],
): ):
from .post_training import TorchtunePostTrainingImpl from .post_training import TorchtunePostTrainingImpl

View file

@ -43,6 +43,9 @@ class TorchtunePostTrainingImpl:
self.jobs = {} self.jobs = {}
self.checkpoints_dict = {} self.checkpoints_dict = {}
async def shutdown(self):
pass
async def supervised_fine_tune( async def supervised_fine_tune(
self, self,
job_uuid: str, job_uuid: str,

View file

@ -4,10 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Any, Dict
from .config import CodeScannerConfig from .config import CodeScannerConfig
async def get_provider_impl(config: CodeScannerConfig, deps): async def get_provider_impl(config: CodeScannerConfig, deps: Dict[str, Any]):
from .code_scanner import MetaReferenceCodeScannerSafetyImpl from .code_scanner import MetaReferenceCodeScannerSafetyImpl
impl = MetaReferenceCodeScannerSafetyImpl(config, deps) impl = MetaReferenceCodeScannerSafetyImpl(config, deps)

View file

@ -4,10 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Any, Dict
from .config import LlamaGuardConfig from .config import LlamaGuardConfig
async def get_provider_impl(config: LlamaGuardConfig, deps): async def get_provider_impl(config: LlamaGuardConfig, deps: Dict[str, Any]):
from .llama_guard import LlamaGuardSafetyImpl from .llama_guard import LlamaGuardSafetyImpl
assert isinstance(config, LlamaGuardConfig), f"Unexpected config type: {type(config)}" assert isinstance(config, LlamaGuardConfig), f"Unexpected config type: {type(config)}"

View file

@ -4,10 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Any, Dict
from .config import PromptGuardConfig # noqa: F401 from .config import PromptGuardConfig # noqa: F401
async def get_provider_impl(config: PromptGuardConfig, deps): async def get_provider_impl(config: PromptGuardConfig, deps: Dict[str, Any]):
from .prompt_guard import PromptGuardSafetyImpl from .prompt_guard import PromptGuardSafetyImpl
impl = PromptGuardSafetyImpl(config, deps) impl = PromptGuardSafetyImpl(config, deps)

View file

@ -3,16 +3,16 @@
# #
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Dict from typing import Any, Dict
from llama_stack.distribution.datatypes import Api, ProviderSpec from llama_stack.distribution.datatypes import Api
from .config import BasicScoringConfig from .config import BasicScoringConfig
async def get_provider_impl( async def get_provider_impl(
config: BasicScoringConfig, config: BasicScoringConfig,
deps: Dict[Api, ProviderSpec], deps: Dict[Api, Any],
): ):
from .scoring import BasicScoringImpl from .scoring import BasicScoringImpl

View file

@ -23,10 +23,11 @@ from llama_stack.providers.utils.common.data_schema_validator import (
from .config import BasicScoringConfig from .config import BasicScoringConfig
from .scoring_fn.equality_scoring_fn import EqualityScoringFn from .scoring_fn.equality_scoring_fn import EqualityScoringFn
from .scoring_fn.regex_parser_math_response_scoring_fn import RegexParserMathResponseScoringFn
from .scoring_fn.regex_parser_scoring_fn import RegexParserScoringFn from .scoring_fn.regex_parser_scoring_fn import RegexParserScoringFn
from .scoring_fn.subset_of_scoring_fn import SubsetOfScoringFn from .scoring_fn.subset_of_scoring_fn import SubsetOfScoringFn
FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn, RegexParserScoringFn] FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn, RegexParserScoringFn, RegexParserMathResponseScoringFn]
class BasicScoringImpl( class BasicScoringImpl(

View file

@ -0,0 +1,27 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
RegexParserScoringFnParams,
ScoringFn,
)
MATH_ANSWER_REGEXES = [r".*final answer is:?\s*\$\\boxed{(?P<X>.*)}\$"]
regex_parser_math_response = ScoringFn(
identifier="basic::regex_parser_math_response",
description="For math related benchmarks, extract answer from the generated response and expected_answer and see if they match",
return_type=NumberType(),
provider_id="basic",
provider_resource_id="regex-parser-math-response",
params=RegexParserScoringFnParams(
parsing_regexes=MATH_ANSWER_REGEXES,
aggregation_functions=[AggregationFunctionType.accuracy],
),
)
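
To make the new scoring function's regex concrete, here is a tiny standalone check of what `MATH_ANSWER_REGEXES[0]` extracts from a made-up model response:

```python
import re

MATH_ANSWER_REGEX = r".*final answer is:?\s*\$\\boxed{(?P<X>.*)}\$"

response = "Expanding and simplifying, the final answer is: $\\boxed{\\frac{3}{4}}$"
print(re.findall(MATH_ANSWER_REGEX, response))  # ['\\frac{3}{4}'] -- the raw boxed payload
```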

View file

@ -0,0 +1,66 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, Optional
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams, ScoringFnParamsType
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
from ..utils.math_utils import first_answer, normalize_final_answer, try_evaluate_frac, try_evaluate_latex
from .fn_defs.regex_parser_math_response import (
regex_parser_math_response,
)
class RegexParserMathResponseScoringFn(RegisteredBaseScoringFn):
"""
    A scoring_fn for math benchmarks that parses the answer from the generated response according to the configured regexes and checks whether it matches the expected_answer.
"""
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.supported_fn_defs_registry = {
regex_parser_math_response.identifier: regex_parser_math_response,
}
async def score_row(
self,
input_row: Dict[str, Any],
scoring_fn_identifier: Optional[str] = None,
scoring_params: Optional[ScoringFnParams] = None,
) -> ScoringResultRow:
assert scoring_fn_identifier is not None, "Scoring function identifier not found."
fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
if scoring_params is not None:
fn_def.params = scoring_params
assert fn_def.params is not None and fn_def.params.type == ScoringFnParamsType.regex_parser.value, (
f"RegexParserScoringFnParams not found for {fn_def}."
)
expected_answer = input_row["expected_answer"]
generated_answer = input_row["generated_answer"]
parsing_regexes = fn_def.params.parsing_regexes
assert len(parsing_regexes) == 1, (
"Only one parsing regex is supported for regex_parser_math_response scoring function."
)
parsing_regexes = fn_def.params.parsing_regexes[0]
normalized_generated_answer = normalize_final_answer(
first_answer(generated_answer),
parsing_regexes,
match_first=True,
)
normalized_generated_answer = try_evaluate_frac(try_evaluate_latex(normalized_generated_answer))
normalized_expected_answer = normalize_final_answer(expected_answer, r".*")
normalized_expected_answer = try_evaluate_frac(try_evaluate_latex(normalized_expected_answer))
score = 1.0 if normalized_generated_answer == normalized_expected_answer else 0.0
return {
"score": score,
}
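
A hypothetical row of the kind `score_row` expects, to show the intended behaviour (the field names follow the code above; the answers are invented):

```python
row = {
    "generated_answer": "So the final answer is: $\\boxed{\\frac{3}{4}}$",
    "expected_answer": "3/4",
}
# Both sides normalize to the same representation (a \frac{3}{4} form, or its
# evaluated float form when sympy's LaTeX parser is available), so score_row
# would return {"score": 1.0} for this row.
```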

View file

@ -0,0 +1,330 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import re
from typing import Sequence
from llama_stack.providers.utils.scoring.basic_scoring_utils import time_limit
# from minerva
SUBSTITUTIONS = [
("an ", ""),
("a ", ""),
(".$", "$"),
("\\$", ""),
(r"\ ", ""),
(" ", ""),
("mbox", "text"),
(",\\text{and}", ","),
("\\text{and}", ","),
("\\text{m}", "\\text{}"),
]
REMOVED_EXPRESSIONS = [
"square",
"ways",
"integers",
"dollars",
"mph",
"inches",
"ft",
"hours",
"km",
"units",
"\\ldots",
"sue",
"points",
"feet",
"minutes",
"digits",
"cents",
"degrees",
"cm",
"gm",
"pounds",
"meters",
"meals",
"edges",
"students",
"childrentickets",
"multiples",
"\\text{s}",
"\\text{.}",
"\\text{\ns}",
"\\text{}^2",
"\\text{}^3",
"\\text{\n}",
"\\text{}",
r"\mathrm{th}",
r"^\circ",
r"^{\circ}",
r"\;",
r",\!",
"{,}",
'"',
"\\dots",
]
def try_evaluate_frac(expression: str, fmt: str = "0.2e") -> str:
if isinstance(expression, float):
return expression
new_expression = f"{expression}"
regex = re.compile(r"\\frac{([^}]+)}{([^}]+)}")
for match in re.finditer(regex, expression):
try:
value = float(match.group(1)) / float(match.group(2))
new_expression = new_expression.replace(
match.group(),
f"{{value:{fmt}}}".format(value=value),
1,
)
except Exception:
continue
return new_expression
def try_evaluate_latex(expression: str, fmt: str = ".2e") -> str:
try:
with time_limit(seconds=5):
from sympy.parsing.latex import parse_latex
value = parse_latex(expression).evalf() # type: ignore
return f"{{value:{fmt}}}".format(value=value)
except Exception:
return expression
def first_answer(text: str, markers: Sequence[str] = ("Q:", "A:")) -> str:
for marker in markers:
text = text.split(marker)[0]
return text
def extract_result_from_boxed(answer: str) -> str:
box_start = "\\boxed"
# format is `\\boxed <value>$` or `\\boxed{<value>}`, with potential white spaces framing `<value>`
start = answer.rfind(box_start)
if start < 0:
return ""
answer = answer[start + len(box_start) :].strip()
ends_with_curly = answer.startswith("{")
i = 0
open_braces = 0
while i < len(answer):
if answer[i] == "{":
open_braces += 1
elif answer[i] == "}":
open_braces -= 1
if open_braces == 0:
if ends_with_curly:
answer = answer[: i + 1].strip()
break
elif answer[i] == "$":
answer = answer[:i].strip()
break
i += 1
else:
return ""
# remove extra curly braces
while True:
if answer.startswith("{") and answer.endswith("}"):
answer = answer[1:-1].strip()
else:
break
return answer
# from minerva paper + _normalise_result from xavierm
def normalize_final_answer(final_answer: str, regex_pattern: str, match_first: bool = True) -> str:
"""Extract and normalize a final answer to a quantitative reasoning question."""
match = re.findall(regex_pattern, final_answer)
extraction: str
if len(match) > 0:
if match_first:
extraction = match[0]
else:
extraction = match[-1]
else:
extraction = extract_result_from_boxed(final_answer)
if len(extraction) == 0:
return final_answer
else:
final_answer = extraction
final_answer = final_answer.split("=")[-1]
for before, after in SUBSTITUTIONS:
final_answer = final_answer.replace(before, after)
for expr in REMOVED_EXPRESSIONS:
final_answer = final_answer.replace(expr, "")
# Extract answer that is in LaTeX math, is bold,
# is surrounded by a box, etc.
final_answer = re.sub(r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$", final_answer)
final_answer = re.sub(r"(\\text\{)(.*?)(\})", "\\2", final_answer)
final_answer = re.sub(r"(\\textbf\{)(.*?)(\})", "\\2", final_answer)
final_answer = re.sub(r"(\\overline\{)(.*?)(\})", "\\2", final_answer)
final_answer = re.sub(r"(\\boxed\{)(.*)(\})", "\\2", final_answer)
# Normalize shorthand TeX:
# \fracab -> \frac{a}{b}
# \frac{abc}{bef} -> \frac{abc}{bef}
# \fracabc -> \frac{a}{b}c
# \sqrta -> \sqrt{a}
# \sqrtab -> sqrt{a}b
final_answer = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", final_answer)
final_answer = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", final_answer)
final_answer = final_answer.replace("$", "")
# Normalize 100,000 -> 100000
if final_answer.replace(",", "").isdigit():
final_answer = final_answer.replace(",", "")
# If the final answer is a single letter in parentheses, remove the parentheses
# Example: (a) -> a (but not (ab) -> ab)
if re.match(r"\([a-zA-Z]\)", final_answer):
final_answer = final_answer[1]
return _normalise_result(final_answer)
def _normalise_result(string: str) -> str:
# linebreaks
string = string.replace("\n", "")
# remove inverse spaces
string = string.replace("\\!", "")
# replace \\ with \
string = string.replace("\\\\", "\\")
# replace tfrac and dfrac with frac
string = string.replace("cfrac", "frac")
string = string.replace("tfrac", "frac")
string = string.replace("dfrac", "frac")
# remove \left and \right
string = string.replace("\\left", "")
string = string.replace("\\le", "")
string = string.replace("\\right", "")
# Remove circ (degrees)
string = string.replace("^{\\circ}", "")
string = string.replace("^\\circ", "")
# remove dollar signs
string = string.replace("\\$", "")
# remove units (on the right)
string = _remove_right_units(string)
# remove percentage
string = string.replace("\\%", "")
string = string.replace(r"\%", "")
# " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
string = string.replace(" .", " 0.")
string = string.replace("{.", "{0.")
# if empty, return empty string
if len(string) == 0:
return string
if string[0] == ".":
string = "0" + string
# to consider: get rid of e.g. "k = " or "q = " at beginning
string = string.split("=")[-1]
# fix sqrt3 --> sqrt{3}
string = _fix_sqrt(string)
# remove spaces
string = string.replace(" ", "")
# \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
string = _fix_fracs(string)
# manually change 0.5 --> \frac{1}{2}
if string == "0.5":
string = "\\frac{1}{2}"
# NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
string = _fix_a_slash_b(string)
return string
def _remove_right_units(string: str) -> str:
# "\\text{ " only ever occurs (at least in the val set) when describing units
try:
if "\\text{ " in string:
splits = string.split("\\text{ ")
assert len(splits) == 2
return splits[0]
else:
return string
except AssertionError:
return string
def _fix_sqrt(string: str) -> str:
if "\\sqrt" not in string:
return string
splits = string.split("\\sqrt")
new_string = splits[0]
for split in splits[1:]:
if len(split) == 0:
return string
if split[0] != "{":
a = split[0]
new_substr = "\\sqrt{" + a + "}" + split[1:]
else:
new_substr = "\\sqrt" + split
new_string += new_substr
return new_string
def _fix_fracs(string: str) -> str:
substrs = string.split("\\frac")
new_str = substrs[0]
if len(substrs) > 1:
substrs = substrs[1:]
for substr in substrs:
new_str += "\\frac"
if len(substr) == 0:
return string
if substr[0] == "{":
new_str += substr
else:
try:
assert len(substr) >= 2
except AssertionError:
return string
a = substr[0]
b = substr[1]
if b != "{":
if len(substr) > 2:
post_substr = substr[2:]
new_str += "{" + a + "}{" + b + "}" + post_substr
else:
new_str += "{" + a + "}{" + b + "}"
else:
if len(substr) > 2:
post_substr = substr[2:]
new_str += "{" + a + "}" + b + post_substr
else:
new_str += "{" + a + "}" + b
string = new_str
return string
def _fix_a_slash_b(string: str) -> str:
if len(string.split("/")) != 2:
return string
a = string.split("/")[0]
b = string.split("/")[1]
try:
ia = int(a)
ib = int(b)
assert string == "{}/{}".format(ia, ib)
new_string = "\\frac{" + str(ia) + "}{" + str(ib) + "}"
return new_string
except (ValueError, AssertionError):
return string
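
A few illustrative inputs and outputs for the helpers above, derived by tracing the code (assuming the functions are in scope, e.g. imported from this `math_utils` module):

```python
print(try_evaluate_frac(r"\frac{1}{4}"))                   # 2.50e-01
print(extract_result_from_boxed(r"answer: $\boxed{42}$"))  # 42
print(normalize_final_answer("The answer is 1,000 dollars", r".*answer is (.*)"))  # 1000
```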

View file

@ -3,11 +3,11 @@
# #
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Dict from typing import Any, Dict
from pydantic import BaseModel from pydantic import BaseModel
from llama_stack.distribution.datatypes import Api, ProviderSpec from llama_stack.distribution.datatypes import Api
from .config import BraintrustScoringConfig from .config import BraintrustScoringConfig
@ -18,7 +18,7 @@ class BraintrustProviderDataValidator(BaseModel):
async def get_provider_impl( async def get_provider_impl(
config: BraintrustScoringConfig, config: BraintrustScoringConfig,
deps: Dict[Api, ProviderSpec], deps: Dict[Api, Any],
): ):
from .braintrust import BraintrustScoringImpl from .braintrust import BraintrustScoringImpl

View file

@ -3,16 +3,16 @@
# #
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Dict from typing import Any, Dict
from llama_stack.distribution.datatypes import Api, ProviderSpec from llama_stack.distribution.datatypes import Api
from .config import LlmAsJudgeScoringConfig from .config import LlmAsJudgeScoringConfig
async def get_provider_impl( async def get_provider_impl(
config: LlmAsJudgeScoringConfig, config: LlmAsJudgeScoringConfig,
deps: Dict[Api, ProviderSpec], deps: Dict[Api, Any],
): ):
from .scoring import LlmAsJudgeScoringImpl from .scoring import LlmAsJudgeScoringImpl

View file

@ -73,6 +73,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
def __init__(self, config: TelemetryConfig, deps: Dict[str, Any]) -> None: def __init__(self, config: TelemetryConfig, deps: Dict[str, Any]) -> None:
self.config = config self.config = config
self.datasetio_api = deps.get(Api.datasetio) self.datasetio_api = deps.get(Api.datasetio)
self.meter = None
resource = Resource.create( resource = Resource.create(
{ {
@ -171,6 +172,8 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
return _GLOBAL_STORAGE["gauges"][name] return _GLOBAL_STORAGE["gauges"][name]
def _log_metric(self, event: MetricEvent) -> None: def _log_metric(self, event: MetricEvent) -> None:
if self.meter is None:
return
if isinstance(event.value, int): if isinstance(event.value, int):
counter = self._get_or_create_counter(event.metric, event.unit) counter = self._get_or_create_counter(event.metric, event.unit)
counter.add(event.value, attributes=event.attributes) counter.add(event.value, attributes=event.attributes)

View file

@ -4,12 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Any, Dict
from .config import CodeInterpreterToolConfig from .config import CodeInterpreterToolConfig
__all__ = ["CodeInterpreterToolConfig", "CodeInterpreterToolRuntimeImpl"] __all__ = ["CodeInterpreterToolConfig", "CodeInterpreterToolRuntimeImpl"]
async def get_provider_impl(config: CodeInterpreterToolConfig, _deps): async def get_provider_impl(config: CodeInterpreterToolConfig, _deps: Dict[str, Any]):
from .code_interpreter import CodeInterpreterToolRuntimeImpl from .code_interpreter import CodeInterpreterToolRuntimeImpl
impl = CodeInterpreterToolRuntimeImpl(config) impl = CodeInterpreterToolRuntimeImpl(config)

View file

@ -4,14 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Dict from typing import Any, Dict
from llama_stack.providers.datatypes import Api, ProviderSpec from llama_stack.providers.datatypes import Api
from .config import ChromaVectorIOConfig from .config import ChromaVectorIOConfig
async def get_provider_impl(config: ChromaVectorIOConfig, deps: Dict[Api, ProviderSpec]): async def get_provider_impl(config: ChromaVectorIOConfig, deps: Dict[Api, Any]):
from llama_stack.providers.remote.vector_io.chroma.chroma import ( from llama_stack.providers.remote.vector_io.chroma.chroma import (
ChromaVectorIOAdapter, ChromaVectorIOAdapter,
) )

View file

@ -4,14 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Dict from typing import Any, Dict
from llama_stack.providers.datatypes import Api, ProviderSpec from llama_stack.providers.datatypes import Api
from .config import FaissVectorIOConfig from .config import FaissVectorIOConfig
async def get_provider_impl(config: FaissVectorIOConfig, deps: Dict[Api, ProviderSpec]): async def get_provider_impl(config: FaissVectorIOConfig, deps: Dict[Api, Any]):
from .faiss import FaissVectorIOAdapter from .faiss import FaissVectorIOAdapter
assert isinstance(config, FaissVectorIOConfig), f"Unexpected config type: {type(config)}" assert isinstance(config, FaissVectorIOConfig), f"Unexpected config type: {type(config)}"

View file

@ -4,14 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Dict from typing import Any, Dict
from llama_stack.providers.datatypes import Api, ProviderSpec from llama_stack.providers.datatypes import Api
from .config import MilvusVectorIOConfig from .config import MilvusVectorIOConfig
async def get_provider_impl(config: MilvusVectorIOConfig, deps: Dict[Api, ProviderSpec]): async def get_provider_impl(config: MilvusVectorIOConfig, deps: Dict[Api, Any]):
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusVectorIOAdapter from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusVectorIOAdapter
impl = MilvusVectorIOAdapter(config, deps[Api.inference]) impl = MilvusVectorIOAdapter(config, deps[Api.inference])

View file

@ -4,14 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Dict from typing import Any, Dict
from llama_stack.providers.datatypes import Api, ProviderSpec from llama_stack.providers.datatypes import Api
from .config import SQLiteVectorIOConfig from .config import SQLiteVectorIOConfig
async def get_provider_impl(config: SQLiteVectorIOConfig, deps: Dict[Api, ProviderSpec]): async def get_provider_impl(config: SQLiteVectorIOConfig, deps: Dict[Api, Any]):
from .sqlite_vec import SQLiteVecVectorIOAdapter from .sqlite_vec import SQLiteVecVectorIOAdapter
assert isinstance(config, SQLiteVectorIOConfig), f"Unexpected config type: {type(config)}" assert isinstance(config, SQLiteVectorIOConfig), f"Unexpected config type: {type(config)}"

View file

@ -34,6 +34,8 @@ def available_providers() -> List[ProviderSpec]:
config_class="llama_stack.providers.inline.vector_io.faiss.FaissVectorIOConfig", config_class="llama_stack.providers.inline.vector_io.faiss.FaissVectorIOConfig",
api_dependencies=[Api.inference], api_dependencies=[Api.inference],
), ),
# NOTE: sqlite-vec cannot be bundled into the container image because it does not have a
# source distribution and the wheels are not available for all platforms.
InlineProviderSpec( InlineProviderSpec(
api=Api.vector_io, api=Api.vector_io,
provider_type="inline::sqlite-vec", provider_type="inline::sqlite-vec",

View file

@ -24,10 +24,6 @@ MODEL_ENTRIES = [
"accounts/fireworks/models/llama-v3p1-405b-instruct", "accounts/fireworks/models/llama-v3p1-405b-instruct",
CoreModelId.llama3_1_405b_instruct.value, CoreModelId.llama3_1_405b_instruct.value,
), ),
build_hf_repo_model_entry(
"accounts/fireworks/models/llama-v3p2-1b-instruct",
CoreModelId.llama3_2_1b_instruct.value,
),
build_hf_repo_model_entry( build_hf_repo_model_entry(
"accounts/fireworks/models/llama-v3p2-3b-instruct", "accounts/fireworks/models/llama-v3p2-3b-instruct",
CoreModelId.llama3_2_3b_instruct.value, CoreModelId.llama3_2_3b_instruct.value,

View file

@ -4,12 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import AsyncGenerator, List, Optional from typing import Any, AsyncGenerator, Dict, List, Optional
from llama_stack_client import LlamaStackClient from llama_stack_client import AsyncLlamaStackClient
from llama_stack.apis.common.content_types import InterleavedContent from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
ChatCompletionResponse,
ChatCompletionResponseStreamChunk,
EmbeddingsResponse, EmbeddingsResponse,
EmbeddingTaskType, EmbeddingTaskType,
Inference, Inference,
@ -24,6 +26,7 @@ from llama_stack.apis.inference import (
ToolPromptFormat, ToolPromptFormat,
) )
from llama_stack.apis.models import Model from llama_stack.apis.models import Model
from llama_stack.distribution.library_client import convert_pydantic_to_json_value, convert_to_pydantic
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from .config import PassthroughImplConfig from .config import PassthroughImplConfig
@ -46,7 +49,7 @@ class PassthroughInferenceAdapter(Inference):
async def register_model(self, model: Model) -> Model: async def register_model(self, model: Model) -> Model:
return model return model
def _get_client(self) -> LlamaStackClient: def _get_client(self) -> AsyncLlamaStackClient:
passthrough_url = None passthrough_url = None
passthrough_api_key = None passthrough_api_key = None
provider_data = None provider_data = None
@ -71,7 +74,7 @@ class PassthroughInferenceAdapter(Inference):
) )
passthrough_api_key = provider_data.passthrough_api_key passthrough_api_key = provider_data.passthrough_api_key
return LlamaStackClient( return AsyncLlamaStackClient(
base_url=passthrough_url, base_url=passthrough_url,
api_key=passthrough_api_key, api_key=passthrough_api_key,
provider_data=provider_data, provider_data=provider_data,
@ -91,7 +94,7 @@ class PassthroughInferenceAdapter(Inference):
client = self._get_client() client = self._get_client()
model = await self.model_store.get_model(model_id) model = await self.model_store.get_model(model_id)
params = { request_params = {
"model_id": model.provider_resource_id, "model_id": model.provider_resource_id,
"content": content, "content": content,
"sampling_params": sampling_params, "sampling_params": sampling_params,
@ -100,10 +103,13 @@ class PassthroughInferenceAdapter(Inference):
"logprobs": logprobs, "logprobs": logprobs,
} }
params = {key: value for key, value in params.items() if value is not None} request_params = {key: value for key, value in request_params.items() if value is not None}
# cast everything to json dict
json_params = self.cast_value_to_json_dict(request_params)
# only pass through the not None params # only pass through the not None params
return client.inference.completion(**params) return await client.inference.completion(**json_params)
async def chat_completion( async def chat_completion(
self, self,
@ -120,10 +126,14 @@ class PassthroughInferenceAdapter(Inference):
) -> AsyncGenerator: ) -> AsyncGenerator:
if sampling_params is None: if sampling_params is None:
sampling_params = SamplingParams() sampling_params = SamplingParams()
client = self._get_client()
model = await self.model_store.get_model(model_id) model = await self.model_store.get_model(model_id)
params = { # TODO: revisit this remove tool_calls from messages logic
for message in messages:
if hasattr(message, "tool_calls"):
message.tool_calls = None
request_params = {
"model_id": model.provider_resource_id, "model_id": model.provider_resource_id,
"messages": messages, "messages": messages,
"sampling_params": sampling_params, "sampling_params": sampling_params,
@ -135,10 +145,39 @@ class PassthroughInferenceAdapter(Inference):
"logprobs": logprobs, "logprobs": logprobs,
} }
params = {key: value for key, value in params.items() if value is not None}
# only pass through the not None params # only pass through the not None params
return client.inference.chat_completion(**params) request_params = {key: value for key, value in request_params.items() if value is not None}
# cast everything to json dict
json_params = self.cast_value_to_json_dict(request_params)
if stream:
return self._stream_chat_completion(json_params)
else:
return await self._nonstream_chat_completion(json_params)
async def _nonstream_chat_completion(self, json_params: Dict[str, Any]) -> ChatCompletionResponse:
client = self._get_client()
response = await client.inference.chat_completion(**json_params)
response = response.to_dict()
# temporary hack to remove the metrics from the response
response["metrics"] = []
return convert_to_pydantic(ChatCompletionResponse, response)
async def _stream_chat_completion(self, json_params: Dict[str, Any]) -> AsyncGenerator:
client = self._get_client()
stream_response = await client.inference.chat_completion(**json_params)
async for chunk in stream_response:
chunk = chunk.to_dict()
# temporary hack to remove the metrics from the response
chunk["metrics"] = []
chunk = convert_to_pydantic(ChatCompletionResponseStreamChunk, chunk)
yield chunk
async def embeddings( async def embeddings(
self, self,
@ -151,10 +190,29 @@ class PassthroughInferenceAdapter(Inference):
client = self._get_client() client = self._get_client()
model = await self.model_store.get_model(model_id) model = await self.model_store.get_model(model_id)
return client.inference.embeddings( return await client.inference.embeddings(
model_id=model.provider_resource_id, model_id=model.provider_resource_id,
contents=contents, contents=contents,
text_truncation=text_truncation, text_truncation=text_truncation,
output_dimension=output_dimension, output_dimension=output_dimension,
task_type=task_type, task_type=task_type,
) )
def cast_value_to_json_dict(self, request_params: Dict[str, Any]) -> Dict[str, Any]:
json_params = {}
for key, value in request_params.items():
json_input = convert_pydantic_to_json_value(value)
if isinstance(json_input, dict):
json_input = {k: v for k, v in json_input.items() if v is not None}
elif isinstance(json_input, list):
json_input = [x for x in json_input if x is not None]
new_input = []
for x in json_input:
if isinstance(x, dict):
x = {k: v for k, v in x.items() if v is not None}
new_input.append(x)
json_input = new_input
json_params[key] = json_input
return json_params
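The helper above only prunes None values one level inside nested dicts and lists after the Pydantic-to-JSON conversion. A minimal standalone sketch of that behaviour (plain JSON-like data, not the adapter's real request objects; `convert_pydantic_to_json_value` is assumed to have already run):

```python
# Sketch only: mirrors the None-pruning that cast_value_to_json_dict performs.
# The input dict is illustrative, not a real inference request.
from typing import Any, Dict


def strip_nested_nones(request_params: Dict[str, Any]) -> Dict[str, Any]:
    json_params: Dict[str, Any] = {}
    for key, value in request_params.items():
        if isinstance(value, dict):
            value = {k: v for k, v in value.items() if v is not None}
        elif isinstance(value, list):
            value = [
                ({k: v for k, v in x.items() if v is not None} if isinstance(x, dict) else x)
                for x in value
                if x is not None
            ]
        json_params[key] = value
    return json_params


print(strip_nested_nones({
    "sampling_params": {"strategy": "greedy", "max_tokens": None},
    "messages": [{"role": "user", "content": "hi", "context": None}],
}))
# {'sampling_params': {'strategy': 'greedy'}, 'messages': [{'role': 'user', 'content': 'hi'}]}
```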
View file
@ -26,5 +26,5 @@ class TogetherImplConfig(BaseModel):
def sample_run_config(cls, **kwargs) -> Dict[str, Any]: def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
return { return {
"url": "https://api.together.xyz/v1", "url": "https://api.together.xyz/v1",
"api_key": "${env.TOGETHER_API_KEY}", "api_key": "${env.TOGETHER_API_KEY:}",
} }
View file
@ -6,7 +6,7 @@
from typing import AsyncGenerator, List, Optional, Union from typing import AsyncGenerator, List, Optional, Union
from together import Together from together import AsyncTogether
from llama_stack.apis.common.content_types import ( from llama_stack.apis.common.content_types import (
InterleavedContent, InterleavedContent,
@ -59,12 +59,15 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
def __init__(self, config: TogetherImplConfig) -> None: def __init__(self, config: TogetherImplConfig) -> None:
ModelRegistryHelper.__init__(self, MODEL_ENTRIES) ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
self.config = config self.config = config
self._client = None
async def initialize(self) -> None: async def initialize(self) -> None:
pass pass
async def shutdown(self) -> None: async def shutdown(self) -> None:
pass if self._client:
await self._client.close()
self._client = None
async def completion( async def completion(
self, self,
@ -91,35 +94,32 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
else: else:
return await self._nonstream_completion(request) return await self._nonstream_completion(request)
def _get_client(self) -> Together: def _get_client(self) -> AsyncTogether:
together_api_key = None if not self._client:
config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None together_api_key = None
if config_api_key: config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
together_api_key = config_api_key if config_api_key:
else: together_api_key = config_api_key
provider_data = self.get_request_provider_data() else:
if provider_data is None or not provider_data.together_api_key: provider_data = self.get_request_provider_data()
raise ValueError( if provider_data is None or not provider_data.together_api_key:
'Pass Together API Key in the header X-LlamaStack-Provider-Data as { "together_api_key": <your api key>}' raise ValueError(
) 'Pass Together API Key in the header X-LlamaStack-Provider-Data as { "together_api_key": <your api key>}'
together_api_key = provider_data.together_api_key )
return Together(api_key=together_api_key) together_api_key = provider_data.together_api_key
self._client = AsyncTogether(api_key=together_api_key)
return self._client
async def _nonstream_completion(self, request: CompletionRequest) -> ChatCompletionResponse: async def _nonstream_completion(self, request: CompletionRequest) -> ChatCompletionResponse:
params = await self._get_params(request) params = await self._get_params(request)
r = self._get_client().completions.create(**params) client = self._get_client()
r = await client.completions.create(**params)
return process_completion_response(r) return process_completion_response(r)
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator: async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
params = await self._get_params(request) params = await self._get_params(request)
client = await self._get_client()
# if we shift to TogetherAsyncClient, we won't need this wrapper stream = await client.completions.create(**params)
async def _to_async_generator():
s = self._get_client().completions.create(**params)
for chunk in s:
yield chunk
stream = _to_async_generator()
async for chunk in process_completion_stream_response(stream): async for chunk in process_completion_stream_response(stream):
yield chunk yield chunk
@ -184,25 +184,21 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse: async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
params = await self._get_params(request) params = await self._get_params(request)
client = self._get_client()
if "messages" in params: if "messages" in params:
r = self._get_client().chat.completions.create(**params) r = await client.chat.completions.create(**params)
else: else:
r = self._get_client().completions.create(**params) r = await client.completions.create(**params)
return process_chat_completion_response(r, request) return process_chat_completion_response(r, request)
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator: async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
params = await self._get_params(request) params = await self._get_params(request)
client = self._get_client()
if "messages" in params:
stream = await client.chat.completions.create(**params)
else:
stream = await client.completions.create(**params)
# if we shift to TogetherAsyncClient, we won't need this wrapper
async def _to_async_generator():
if "messages" in params:
s = self._get_client().chat.completions.create(**params)
else:
s = self._get_client().completions.create(**params)
for chunk in s:
yield chunk
stream = _to_async_generator()
async for chunk in process_chat_completion_stream_response(stream, request): async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk yield chunk
@ -240,7 +236,8 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
assert all(not content_has_media(content) for content in contents), ( assert all(not content_has_media(content) for content in contents), (
"Together does not support media for embeddings" "Together does not support media for embeddings"
) )
r = self._get_client().embeddings.create( client = self._get_client()
r = await client.embeddings.create(
model=model.provider_resource_id, model=model.provider_resource_id,
input=[interleaved_content_as_str(content) for content in contents], input=[interleaved_content_as_str(content) for content in contents],
) )
View file
@ -615,6 +615,14 @@ def convert_tool_call(
return valid_tool_call return valid_tool_call
PYTHON_TYPE_TO_LITELLM_TYPE = {
"int": "integer",
"float": "number",
"bool": "boolean",
"str": "string",
}
def convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict: def convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict:
""" """
Convert a ToolDefinition to an OpenAI API-compatible dictionary. Convert a ToolDefinition to an OpenAI API-compatible dictionary.
@ -675,7 +683,7 @@ def convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict:
properties = parameters["properties"] properties = parameters["properties"]
required = [] required = []
for param_name, param in tool.parameters.items(): for param_name, param in tool.parameters.items():
properties[param_name] = {"type": param.param_type} properties[param_name] = {"type": PYTHON_TYPE_TO_LITELLM_TYPE.get(param.param_type, param.param_type)}
if param.description: if param.description:
properties[param_name].update(description=param.description) properties[param_name].update(description=param.description)
if param.default: if param.default:
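A standalone illustration of what the PYTHON_TYPE_TO_LITELLM_TYPE mapping above achieves when tool parameters are converted for an OpenAI-compatible endpoint; the example parameter names and types are made up:

```python
# Sketch only: Llama Stack tool parameters declared with Python-style type names
# are normalized to JSON-Schema type names; unknown names pass through unchanged.
PYTHON_TYPE_TO_LITELLM_TYPE = {
    "int": "integer",
    "float": "number",
    "bool": "boolean",
    "str": "string",
}

param_types = {"city": "str", "days": "int", "include_hourly": "bool", "units": "string"}

properties = {
    name: {"type": PYTHON_TYPE_TO_LITELLM_TYPE.get(py_type, py_type)}
    for name, py_type in param_types.items()
}
print(properties)
# {'city': {'type': 'string'}, 'days': {'type': 'integer'},
#  'include_hourly': {'type': 'boolean'}, 'units': {'type': 'string'}}
```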
View file
@ -0,0 +1,26 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import contextlib
import signal
from types import FrameType
from typing import Iterator, Optional
class TimeoutError(Exception):
pass
@contextlib.contextmanager
def time_limit(seconds: float) -> Iterator[None]:
def signal_handler(signum: int, frame: Optional[FrameType]) -> None:
raise TimeoutError("Timed out!")
signal.setitimer(signal.ITIMER_REAL, seconds)
signal.signal(signal.SIGALRM, signal_handler)
try:
yield
finally:
signal.setitimer(signal.ITIMER_REAL, 0)
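A self-contained usage sketch for the new time_limit() helper. The helper body is copied from the hunk above, while slow_call() and the one-second budget are illustrative. Note that signal.setitimer/SIGALRM is Unix-only and must be used from the main thread, so this guard is not usable from worker threads.

```python
# Usage sketch: the time_limit body is taken from the diff; slow_call() is made up.
import contextlib
import signal
import time
from types import FrameType
from typing import Iterator, Optional


class TimeoutError(Exception):
    pass


@contextlib.contextmanager
def time_limit(seconds: float) -> Iterator[None]:
    def signal_handler(signum: int, frame: Optional[FrameType]) -> None:
        raise TimeoutError("Timed out!")

    signal.setitimer(signal.ITIMER_REAL, seconds)
    signal.signal(signal.SIGALRM, signal_handler)
    try:
        yield
    finally:
        signal.setitimer(signal.ITIMER_REAL, 0)  # always cancel the pending alarm


def slow_call() -> None:
    time.sleep(5)


try:
    with time_limit(1.0):
        slow_call()
except TimeoutError:
    print("call exceeded the 1 second budget")
```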
View file
@ -6,6 +6,7 @@
import asyncio import asyncio
import base64 import base64
import contextvars
import logging import logging
import queue import queue
import threading import threading
@ -24,9 +25,10 @@ from llama_stack.apis.telemetry import (
Telemetry, Telemetry,
UnstructuredLogEvent, UnstructuredLogEvent,
) )
from llama_stack.log import get_logger
from llama_stack.providers.utils.telemetry.trace_protocol import serialize_value from llama_stack.providers.utils.telemetry.trace_protocol import serialize_value
log = logging.getLogger(__name__) logger = get_logger(__name__, category="core")
def generate_short_uuid(len: int = 8): def generate_short_uuid(len: int = 8):
@ -36,7 +38,7 @@ def generate_short_uuid(len: int = 8):
return encoded.rstrip(b"=").decode("ascii")[:len] return encoded.rstrip(b"=").decode("ascii")[:len]
CURRENT_TRACE_CONTEXT = None CURRENT_TRACE_CONTEXT = contextvars.ContextVar("trace_context", default=None)
BACKGROUND_LOGGER = None BACKGROUND_LOGGER = None
@ -51,7 +53,7 @@ class BackgroundLogger:
try: try:
self.log_queue.put_nowait(event) self.log_queue.put_nowait(event)
except queue.Full: except queue.Full:
log.error("Log queue is full, dropping event") logger.error("Log queue is full, dropping event")
def _process_logs(self): def _process_logs(self):
while True: while True:
@ -129,35 +131,36 @@ def setup_logger(api: Telemetry, level: int = logging.INFO):
if BACKGROUND_LOGGER is None: if BACKGROUND_LOGGER is None:
BACKGROUND_LOGGER = BackgroundLogger(api) BACKGROUND_LOGGER = BackgroundLogger(api)
logger = logging.getLogger() root_logger = logging.getLogger()
logger.setLevel(level) root_logger.setLevel(level)
logger.addHandler(TelemetryHandler()) root_logger.addHandler(TelemetryHandler())
async def start_trace(name: str, attributes: Dict[str, Any] = None) -> TraceContext: async def start_trace(name: str, attributes: Dict[str, Any] = None) -> TraceContext:
global CURRENT_TRACE_CONTEXT, BACKGROUND_LOGGER global CURRENT_TRACE_CONTEXT, BACKGROUND_LOGGER
if BACKGROUND_LOGGER is None: if BACKGROUND_LOGGER is None:
log.info("No Telemetry implementation set. Skipping trace initialization...") logger.debug("No Telemetry implementation set. Skipping trace initialization...")
return return
trace_id = generate_short_uuid(16) trace_id = generate_short_uuid(16)
context = TraceContext(BACKGROUND_LOGGER, trace_id) context = TraceContext(BACKGROUND_LOGGER, trace_id)
context.push_span(name, {"__root__": True, **(attributes or {})}) context.push_span(name, {"__root__": True, **(attributes or {})})
CURRENT_TRACE_CONTEXT = context CURRENT_TRACE_CONTEXT.set(context)
return context return context
async def end_trace(status: SpanStatus = SpanStatus.OK): async def end_trace(status: SpanStatus = SpanStatus.OK):
global CURRENT_TRACE_CONTEXT global CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT context = CURRENT_TRACE_CONTEXT.get()
if context is None: if context is None:
logger.debug("No trace context to end")
return return
context.pop_span(status) context.pop_span(status)
CURRENT_TRACE_CONTEXT = None CURRENT_TRACE_CONTEXT.set(None)
def severity(levelname: str) -> LogSeverity: def severity(levelname: str) -> LogSeverity:
@ -188,7 +191,7 @@ class TelemetryHandler(logging.Handler):
if BACKGROUND_LOGGER is None: if BACKGROUND_LOGGER is None:
raise RuntimeError("Telemetry API not initialized") raise RuntimeError("Telemetry API not initialized")
context = CURRENT_TRACE_CONTEXT context = CURRENT_TRACE_CONTEXT.get()
if context is None: if context is None:
return return
@ -218,16 +221,22 @@ class SpanContextManager:
def __enter__(self): def __enter__(self):
global CURRENT_TRACE_CONTEXT global CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT context = CURRENT_TRACE_CONTEXT.get()
if context: if not context:
self.span = context.push_span(self.name, self.attributes) logger.debug("No trace context to push span")
return self
self.span = context.push_span(self.name, self.attributes)
return self return self
def __exit__(self, exc_type, exc_value, traceback): def __exit__(self, exc_type, exc_value, traceback):
global CURRENT_TRACE_CONTEXT global CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT context = CURRENT_TRACE_CONTEXT.get()
if context: if not context:
context.pop_span() logger.debug("No trace context to pop span")
return
context.pop_span()
def set_attribute(self, key: str, value: Any): def set_attribute(self, key: str, value: Any):
if self.span: if self.span:
@ -237,16 +246,22 @@ class SpanContextManager:
async def __aenter__(self): async def __aenter__(self):
global CURRENT_TRACE_CONTEXT global CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT context = CURRENT_TRACE_CONTEXT.get()
if context: if not context:
self.span = context.push_span(self.name, self.attributes) logger.debug("No trace context to push span")
return self
self.span = context.push_span(self.name, self.attributes)
return self return self
async def __aexit__(self, exc_type, exc_value, traceback): async def __aexit__(self, exc_type, exc_value, traceback):
global CURRENT_TRACE_CONTEXT global CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT context = CURRENT_TRACE_CONTEXT.get()
if context: if not context:
context.pop_span() logger.debug("No trace context to pop span")
return
context.pop_span()
def __call__(self, func: Callable): def __call__(self, func: Callable):
@wraps(func) @wraps(func)
@ -275,7 +290,11 @@ def span(name: str, attributes: Dict[str, Any] = None):
def get_current_span() -> Optional[Span]: def get_current_span() -> Optional[Span]:
global CURRENT_TRACE_CONTEXT global CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT if CURRENT_TRACE_CONTEXT is None:
logger.debug("No trace context to get current span")
return None
context = CURRENT_TRACE_CONTEXT.get()
if context: if context:
return context.get_current_span() return context.get_current_span()
return None return None
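A minimal sketch (not from this diff) of why CURRENT_TRACE_CONTEXT moves from a module global to a contextvars.ContextVar: each asyncio task gets its own value, so concurrent requests stop clobbering each other's trace.

```python
# Illustration only: two concurrent "requests" set the context variable and each
# still reads back its own value afterwards, which a plain module global cannot guarantee.
import asyncio
import contextvars
from typing import Optional

CURRENT_TRACE_CONTEXT: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar(
    "trace_context", default=None
)


async def handle_request(trace_id: str) -> Optional[str]:
    CURRENT_TRACE_CONTEXT.set(trace_id)
    await asyncio.sleep(0.01)  # another task runs and sets its own value here
    return CURRENT_TRACE_CONTEXT.get()


async def main() -> None:
    print(await asyncio.gather(handle_request("trace-a"), handle_request("trace-b")))
    # ['trace-a', 'trace-b'] -- no cross-task leakage


asyncio.run(main())
```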
View file
@ -120,16 +120,6 @@ models:
provider_id: fireworks provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p1-405b-instruct provider_model_id: accounts/fireworks/models/llama-v3p1-405b-instruct
model_type: llm model_type: llm
- metadata: {}
model_id: accounts/fireworks/models/llama-v3p2-1b-instruct
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p2-1b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.2-1B-Instruct
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p2-1b-instruct
model_type: llm
- metadata: {} - metadata: {}
model_id: accounts/fireworks/models/llama-v3p2-3b-instruct model_id: accounts/fireworks/models/llama-v3p2-3b-instruct
provider_id: fireworks provider_id: fireworks
View file
@ -178,16 +178,6 @@ models:
provider_id: fireworks provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p1-405b-instruct provider_model_id: accounts/fireworks/models/llama-v3p1-405b-instruct
model_type: llm model_type: llm
- metadata: {}
model_id: accounts/fireworks/models/llama-v3p2-1b-instruct
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p2-1b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.2-1B-Instruct
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p2-1b-instruct
model_type: llm
- metadata: {} - metadata: {}
model_id: accounts/fireworks/models/llama-v3p2-3b-instruct model_id: accounts/fireworks/models/llama-v3p2-3b-instruct
provider_id: fireworks provider_id: fireworks
View file
@ -132,16 +132,6 @@ models:
provider_id: fireworks provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p1-405b-instruct provider_model_id: accounts/fireworks/models/llama-v3p1-405b-instruct
model_type: llm model_type: llm
- metadata: {}
model_id: accounts/fireworks/models/llama-v3p2-1b-instruct
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p2-1b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.2-1B-Instruct
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p2-1b-instruct
model_type: llm
- metadata: {} - metadata: {}
model_id: accounts/fireworks/models/llama-v3p2-3b-instruct model_id: accounts/fireworks/models/llama-v3p2-3b-instruct
provider_id: fireworks provider_id: fireworks
View file
@ -126,16 +126,6 @@ models:
provider_id: fireworks provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p1-405b-instruct provider_model_id: accounts/fireworks/models/llama-v3p1-405b-instruct
model_type: llm model_type: llm
- metadata: {}
model_id: accounts/fireworks/models/llama-v3p2-1b-instruct
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p2-1b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.2-1B-Instruct
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p2-1b-instruct
model_type: llm
- metadata: {} - metadata: {}
model_id: accounts/fireworks/models/llama-v3p2-3b-instruct model_id: accounts/fireworks/models/llama-v3p2-3b-instruct
provider_id: fireworks provider_id: fireworks
View file
@ -5,7 +5,7 @@ distribution_spec:
inference: inference:
- remote::ollama - remote::ollama
vector_io: vector_io:
- inline::sqlite-vec - inline::faiss
- remote::chromadb - remote::chromadb
- remote::pgvector - remote::pgvector
safety: safety:
View file
@ -119,7 +119,7 @@ llama stack run ./run-with-safety.yaml \
### (Optional) Update Model Serving Configuration ### (Optional) Update Model Serving Configuration
```{note} ```{note}
Please check the [model_entries](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/inference/ollama/ollama.py#L45) for the supported Ollama models. Please check the [model_entries](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/inference/ollama/models.py) for the supported Ollama models.
``` ```
To serve a new model with `ollama` To serve a new model with `ollama`
View file
@ -13,7 +13,7 @@ from llama_stack.distribution.datatypes import (
ShieldInput, ShieldInput,
ToolGroupInput, ToolGroupInput,
) )
from llama_stack.providers.inline.vector_io.sqlite_vec.config import SQLiteVectorIOConfig from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
from llama_stack.providers.remote.inference.ollama import OllamaImplConfig from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
@ -21,7 +21,7 @@ from llama_stack.templates.template import DistributionTemplate, RunConfigSettin
def get_distribution_template() -> DistributionTemplate: def get_distribution_template() -> DistributionTemplate:
providers = { providers = {
"inference": ["remote::ollama"], "inference": ["remote::ollama"],
"vector_io": ["inline::sqlite-vec", "remote::chromadb", "remote::pgvector"], "vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
"safety": ["inline::llama-guard"], "safety": ["inline::llama-guard"],
"agents": ["inline::meta-reference"], "agents": ["inline::meta-reference"],
"telemetry": ["inline::meta-reference"], "telemetry": ["inline::meta-reference"],
@ -43,10 +43,10 @@ def get_distribution_template() -> DistributionTemplate:
provider_type="remote::ollama", provider_type="remote::ollama",
config=OllamaImplConfig.sample_run_config(), config=OllamaImplConfig.sample_run_config(),
) )
vector_io_provider_sqlite = Provider( vector_io_provider_faiss = Provider(
provider_id="sqlite-vec", provider_id="faiss",
provider_type="inline::sqlite-vec", provider_type="inline::faiss",
config=SQLiteVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"), config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
) )
inference_model = ModelInput( inference_model = ModelInput(
@ -96,7 +96,7 @@ def get_distribution_template() -> DistributionTemplate:
"run.yaml": RunConfigSettings( "run.yaml": RunConfigSettings(
provider_overrides={ provider_overrides={
"inference": [inference_provider], "inference": [inference_provider],
"vector_io": [vector_io_provider_sqlite], "vector_io": [vector_io_provider_faiss],
}, },
default_models=[inference_model, embedding_model], default_models=[inference_model, embedding_model],
default_tool_groups=default_tool_groups, default_tool_groups=default_tool_groups,
@ -104,7 +104,7 @@ def get_distribution_template() -> DistributionTemplate:
"run-with-safety.yaml": RunConfigSettings( "run-with-safety.yaml": RunConfigSettings(
provider_overrides={ provider_overrides={
"inference": [inference_provider], "inference": [inference_provider],
"vector_io": [vector_io_provider_sqlite], "vector_io": [vector_io_provider_faiss],
"safety": [ "safety": [
Provider( Provider(
provider_id="llama-guard", provider_id="llama-guard",
View file
@ -17,10 +17,13 @@ providers:
config: config:
url: ${env.OLLAMA_URL:http://localhost:11434} url: ${env.OLLAMA_URL:http://localhost:11434}
vector_io: vector_io:
- provider_id: sqlite-vec - provider_id: faiss
provider_type: inline::sqlite-vec provider_type: inline::faiss
config: config:
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/sqlite_vec.db kvstore:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/faiss_store.db
safety: safety:
- provider_id: llama-guard - provider_id: llama-guard
provider_type: inline::llama-guard provider_type: inline::llama-guard
View file
@ -17,10 +17,13 @@ providers:
config: config:
url: ${env.OLLAMA_URL:http://localhost:11434} url: ${env.OLLAMA_URL:http://localhost:11434}
vector_io: vector_io:
- provider_id: sqlite-vec - provider_id: faiss
provider_type: inline::sqlite-vec provider_type: inline::faiss
config: config:
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/sqlite_vec.db kvstore:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/faiss_store.db
safety: safety:
- provider_id: llama-guard - provider_id: llama-guard
provider_type: inline::llama-guard provider_type: inline::llama-guard
View file
@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .open_benchmark import get_distribution_template # noqa: F401
View file
@ -0,0 +1,300 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict, List, Tuple
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.models.models import ModelType
from llama_stack.distribution.datatypes import (
BenchmarkInput,
DatasetInput,
ModelInput,
Provider,
ShieldInput,
ToolGroupInput,
)
from llama_stack.providers.inline.vector_io.sqlite_vec.config import (
SQLiteVectorIOConfig,
)
from llama_stack.providers.remote.inference.anthropic.config import AnthropicConfig
from llama_stack.providers.remote.inference.gemini.config import GeminiConfig
from llama_stack.providers.remote.inference.groq.config import GroqConfig
from llama_stack.providers.remote.inference.openai.config import OpenAIConfig
from llama_stack.providers.remote.inference.together.config import TogetherImplConfig
from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOConfig
from llama_stack.providers.remote.vector_io.pgvector.config import (
PGVectorVectorIOConfig,
)
from llama_stack.providers.utils.inference.model_registry import ProviderModelEntry
from llama_stack.templates.template import (
DistributionTemplate,
RunConfigSettings,
get_model_registry,
)
def get_inference_providers() -> Tuple[List[Provider], Dict[str, List[ProviderModelEntry]]]:
# in this template, we allow each API key to be optional
providers = [
(
"openai",
[
ProviderModelEntry(
provider_model_id="openai/gpt-4o",
model_type=ModelType.llm,
)
],
OpenAIConfig.sample_run_config(api_key="${env.OPENAI_API_KEY:}"),
),
(
"anthropic",
[
ProviderModelEntry(
provider_model_id="anthropic/claude-3-5-sonnet-latest",
model_type=ModelType.llm,
)
],
AnthropicConfig.sample_run_config(api_key="${env.ANTHROPIC_API_KEY:}"),
),
(
"gemini",
[
ProviderModelEntry(
provider_model_id="gemini/gemini-1.5-flash",
model_type=ModelType.llm,
)
],
GeminiConfig.sample_run_config(api_key="${env.GEMINI_API_KEY:}"),
),
(
"groq",
[],
GroqConfig.sample_run_config(api_key="${env.GROQ_API_KEY:}"),
),
(
"together",
[],
TogetherImplConfig.sample_run_config(api_key="${env.TOGETHER_API_KEY:}"),
),
]
inference_providers = []
available_models = {}
for provider_id, model_entries, config in providers:
inference_providers.append(
Provider(
provider_id=provider_id,
provider_type=f"remote::{provider_id}",
config=config,
)
)
available_models[provider_id] = model_entries
return inference_providers, available_models
def get_distribution_template() -> DistributionTemplate:
inference_providers, available_models = get_inference_providers()
providers = {
"inference": [p.provider_type for p in inference_providers],
"vector_io": ["inline::sqlite-vec", "remote::chromadb", "remote::pgvector"],
"safety": ["inline::llama-guard"],
"agents": ["inline::meta-reference"],
"telemetry": ["inline::meta-reference"],
"eval": ["inline::meta-reference"],
"datasetio": ["remote::huggingface", "inline::localfs"],
"scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"],
"tool_runtime": [
"remote::brave-search",
"remote::tavily-search",
"inline::code-interpreter",
"inline::rag-runtime",
"remote::model-context-protocol",
],
}
name = "open-benchmark"
vector_io_providers = [
Provider(
provider_id="sqlite-vec",
provider_type="inline::sqlite-vec",
config=SQLiteVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
),
Provider(
provider_id="${env.ENABLE_CHROMADB+chromadb}",
provider_type="remote::chromadb",
config=ChromaVectorIOConfig.sample_run_config(url="${env.CHROMADB_URL:}"),
),
Provider(
provider_id="${env.ENABLE_PGVECTOR+pgvector}",
provider_type="remote::pgvector",
config=PGVectorVectorIOConfig.sample_run_config(
db="${env.PGVECTOR_DB:}",
user="${env.PGVECTOR_USER:}",
password="${env.PGVECTOR_PASSWORD:}",
),
),
]
default_tool_groups = [
ToolGroupInput(
toolgroup_id="builtin::websearch",
provider_id="tavily-search",
),
ToolGroupInput(
toolgroup_id="builtin::rag",
provider_id="rag-runtime",
),
ToolGroupInput(
toolgroup_id="builtin::code_interpreter",
provider_id="code-interpreter",
),
]
default_models = get_model_registry(available_models) + [
ModelInput(
model_id="meta-llama/Llama-3.3-70B-Instruct",
provider_id="groq",
provider_model_id="groq/llama-3.3-70b-versatile",
model_type=ModelType.llm,
),
ModelInput(
model_id="meta-llama/Llama-3.1-405B-Instruct",
provider_id="together",
provider_model_id="meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
model_type=ModelType.llm,
),
]
default_datasets = [
DatasetInput(
dataset_id="simpleqa",
provider_id="huggingface",
url=URL(uri="https://huggingface.co/datasets/llamastack/simpleqa"),
metadata={
"path": "llamastack/simpleqa",
"split": "train",
},
dataset_schema={
"input_query": {"type": "string"},
"expected_answer": {"type": "string"},
"chat_completion_input": {"type": "string"},
},
),
DatasetInput(
dataset_id="mmlu_cot",
provider_id="huggingface",
url=URL(uri="https://huggingface.co/datasets/llamastack/mmlu_cot"),
metadata={
"path": "llamastack/mmlu_cot",
"name": "all",
"split": "test",
},
dataset_schema={
"input_query": {"type": "string"},
"expected_answer": {"type": "string"},
"chat_completion_input": {"type": "string"},
},
),
DatasetInput(
dataset_id="gpqa_cot",
provider_id="huggingface",
url=URL(uri="https://huggingface.co/datasets/llamastack/gpqa_0shot_cot"),
metadata={
"path": "llamastack/gpqa_0shot_cot",
"name": "gpqa_main",
"split": "train",
},
dataset_schema={
"input_query": {"type": "string"},
"expected_answer": {"type": "string"},
"chat_completion_input": {"type": "string"},
},
),
DatasetInput(
dataset_id="math_500",
provider_id="huggingface",
url=URL(uri="https://huggingface.co/datasets/llamastack/math_500"),
metadata={
"path": "llamastack/math_500",
"split": "test",
},
dataset_schema={
"input_query": {"type": "string"},
"expected_answer": {"type": "string"},
"chat_completion_input": {"type": "string"},
},
),
]
default_benchmarks = [
BenchmarkInput(
benchmark_id="meta-reference-simpleqa",
dataset_id="simpleqa",
scoring_functions=["llm-as-judge::405b-simpleqa"],
),
BenchmarkInput(
benchmark_id="meta-reference-mmlu-cot",
dataset_id="mmlu_cot",
scoring_functions=["basic::regex_parser_multiple_choice_answer"],
),
BenchmarkInput(
benchmark_id="meta-reference-gpqa-cot",
dataset_id="gpqa_cot",
scoring_functions=["basic::regex_parser_multiple_choice_answer"],
),
BenchmarkInput(
benchmark_id="meta-reference-math-500",
dataset_id="math_500",
scoring_functions=["basic::regex_parser_math_response"],
),
]
return DistributionTemplate(
name=name,
distro_type="self_hosted",
description="Distribution for running open benchmarks",
container_image=None,
template_path=None,
providers=providers,
available_models_by_provider=available_models,
run_configs={
"run.yaml": RunConfigSettings(
provider_overrides={
"inference": inference_providers,
"vector_io": vector_io_providers,
},
default_models=default_models,
default_tool_groups=default_tool_groups,
default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")],
default_datasets=default_datasets,
default_benchmarks=default_benchmarks,
),
},
run_config_env_vars={
"LLAMA_STACK_PORT": (
"5001",
"Port for the Llama Stack distribution server",
),
"TOGETHER_API_KEY": (
"",
"Together API Key",
),
"OPENAI_API_KEY": (
"",
"OpenAI API Key",
),
"GEMINI_API_KEY": (
"",
"Gemini API Key",
),
"ANTHROPIC_API_KEY": (
"",
"Anthropic API Key",
),
"GROQ_API_KEY": (
"",
"Groq API Key",
),
},
)
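A hedged sketch of how a template like this is normally exercised; these CLI invocations follow the usual llama stack workflow and are assumptions, not something added by this diff.

```bash
# Assumed workflow, not part of the diff: build a distribution from the new
# template, export whichever provider keys you have (unset keys fall back to the
# optional ${env.*:} defaults), then start the server.
llama stack build --template open-benchmark
export TOGETHER_API_KEY=<together_api_key>
export OPENAI_API_KEY=<openai_api_key>
llama stack run open-benchmark --port 5001
```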
View file
@ -33,12 +33,12 @@ providers:
provider_type: remote::together provider_type: remote::together
config: config:
url: https://api.together.xyz/v1 url: https://api.together.xyz/v1
api_key: ${env.TOGETHER_API_KEY} api_key: ${env.TOGETHER_API_KEY:}
vector_io: vector_io:
- provider_id: sqlite-vec - provider_id: sqlite-vec
provider_type: inline::sqlite-vec provider_type: inline::sqlite-vec
config: config:
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dev}/sqlite_vec.db db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/open-benchmark}/sqlite_vec.db
- provider_id: ${env.ENABLE_CHROMADB+chromadb} - provider_id: ${env.ENABLE_CHROMADB+chromadb}
provider_type: remote::chromadb provider_type: remote::chromadb
config: config:
@ -62,14 +62,14 @@ providers:
persistence_store: persistence_store:
type: sqlite type: sqlite
namespace: null namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dev}/agents_store.db db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/open-benchmark}/agents_store.db
telemetry: telemetry:
- provider_id: meta-reference - provider_id: meta-reference
provider_type: inline::meta-reference provider_type: inline::meta-reference
config: config:
service_name: ${env.OTEL_SERVICE_NAME:llama-stack} service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
sinks: ${env.TELEMETRY_SINKS:console,sqlite} sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/dev/trace_store.db} sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/open-benchmark/trace_store.db}
eval: eval:
- provider_id: meta-reference - provider_id: meta-reference
provider_type: inline::meta-reference provider_type: inline::meta-reference
@ -114,18 +114,13 @@ providers:
config: {} config: {}
metadata_store: metadata_store:
type: sqlite type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dev}/registry.db db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/open-benchmark}/registry.db
models: models:
- metadata: {} - metadata: {}
model_id: openai/gpt-4o model_id: openai/gpt-4o
provider_id: openai provider_id: openai
provider_model_id: openai/gpt-4o provider_model_id: openai/gpt-4o
model_type: llm model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.1-405B-Instruct
provider_id: together
provider_model_id: meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo
model_type: llm
- metadata: {} - metadata: {}
model_id: anthropic/claude-3-5-sonnet-latest model_id: anthropic/claude-3-5-sonnet-latest
provider_id: anthropic provider_id: anthropic
@ -141,66 +136,95 @@ models:
provider_id: groq provider_id: groq
provider_model_id: groq/llama-3.3-70b-versatile provider_model_id: groq/llama-3.3-70b-versatile
model_type: llm model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.1-405B-Instruct
provider_id: together
provider_model_id: meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo
model_type: llm
shields: shields:
- shield_id: meta-llama/Llama-Guard-3-8B - shield_id: meta-llama/Llama-Guard-3-8B
vector_dbs: [] vector_dbs: []
datasets: datasets:
- dataset_id: simpleqa - dataset_schema:
provider_id: huggingface input_query:
url: type: string
uri: https://huggingface.co/datasets/llamastack/simpleqa expected_answer:
metadata: type: string
path: llamastack/simpleqa chat_completion_input:
name: type: string
split: train url:
dataset_schema: uri: https://huggingface.co/datasets/llamastack/simpleqa
input_query: metadata:
type: string path: llamastack/simpleqa
expected_answer: split: train
type: string dataset_id: simpleqa
chat_completion_input: provider_id: huggingface
type: string - dataset_schema:
- dataset_id: mmlu_cot input_query:
provider_id: huggingface type: string
url: expected_answer:
uri: https://huggingface.co/datasets/llamastack/mmlu_cot type: string
metadata: chat_completion_input:
path: llamastack/mmlu_cot type: string
name: all url:
split: test uri: https://huggingface.co/datasets/llamastack/mmlu_cot
dataset_schema: metadata:
input_query: path: llamastack/mmlu_cot
type: string name: all
expected_answer: split: test
type: string dataset_id: mmlu_cot
chat_completion_input: provider_id: huggingface
type: string - dataset_schema:
- dataset_id: gpqa_cot input_query:
provider_id: huggingface type: string
url: expected_answer:
uri: https://huggingface.co/datasets/llamastack/gpqa_0shot_cot type: string
metadata: chat_completion_input:
path: llamastack/gpqa_0shot_cot type: string
name: gpqa_main url:
split: train uri: https://huggingface.co/datasets/llamastack/gpqa_0shot_cot
dataset_schema: metadata:
input_query: path: llamastack/gpqa_0shot_cot
type: string name: gpqa_main
expected_answer: split: train
type: string dataset_id: gpqa_cot
chat_completion_input: provider_id: huggingface
type: string - dataset_schema:
input_query:
type: string
expected_answer:
type: string
chat_completion_input:
type: string
url:
uri: https://huggingface.co/datasets/llamastack/math_500
metadata:
path: llamastack/math_500
split: test
dataset_id: math_500
provider_id: huggingface
scoring_fns: [] scoring_fns: []
benchmarks: benchmarks:
- benchmark_id: meta-reference-simpleqa - dataset_id: simpleqa
dataset_id: simpleqa scoring_functions:
scoring_functions: ["llm-as-judge::405b-simpleqa"] - llm-as-judge::405b-simpleqa
- benchmark_id: meta-reference-mmlu-cot metadata: {}
dataset_id: mmlu_cot benchmark_id: meta-reference-simpleqa
scoring_functions: ["basic::regex_parser_multiple_choice_answer"] - dataset_id: mmlu_cot
- benchmark_id: meta-reference-gpqa-cot scoring_functions:
dataset_id: gpqa_cot - basic::regex_parser_multiple_choice_answer
scoring_functions: ["basic::regex_parser_multiple_choice_answer"] metadata: {}
benchmark_id: meta-reference-mmlu-cot
- dataset_id: gpqa_cot
scoring_functions:
- basic::regex_parser_multiple_choice_answer
metadata: {}
benchmark_id: meta-reference-gpqa-cot
- dataset_id: math_500
scoring_functions:
- basic::regex_parser_math_response
metadata: {}
benchmark_id: meta-reference-math-500
tool_groups: tool_groups:
- toolgroup_id: builtin::websearch - toolgroup_id: builtin::websearch
provider_id: tavily-search provider_id: tavily-search
View file
@ -14,7 +14,9 @@ from pydantic import BaseModel, Field
from llama_stack.apis.models.models import ModelType from llama_stack.apis.models.models import ModelType
from llama_stack.distribution.datatypes import ( from llama_stack.distribution.datatypes import (
Api, Api,
BenchmarkInput,
BuildConfig, BuildConfig,
DatasetInput,
DistributionSpec, DistributionSpec,
ModelInput, ModelInput,
Provider, Provider,
@ -28,7 +30,9 @@ from llama_stack.providers.utils.inference.model_registry import ProviderModelEn
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
def get_model_registry(available_models: Dict[str, List[ProviderModelEntry]]) -> List[ModelInput]: def get_model_registry(
available_models: Dict[str, List[ProviderModelEntry]],
) -> List[ModelInput]:
models = [] models = []
for provider_id, entries in available_models.items(): for provider_id, entries in available_models.items():
for entry in entries: for entry in entries:
@ -56,6 +60,8 @@ class RunConfigSettings(BaseModel):
default_models: Optional[List[ModelInput]] = None default_models: Optional[List[ModelInput]] = None
default_shields: Optional[List[ShieldInput]] = None default_shields: Optional[List[ShieldInput]] = None
default_tool_groups: Optional[List[ToolGroupInput]] = None default_tool_groups: Optional[List[ToolGroupInput]] = None
default_datasets: Optional[List[DatasetInput]] = None
default_benchmarks: Optional[List[BenchmarkInput]] = None
def run_config( def run_config(
self, self,
@ -113,6 +119,8 @@ class RunConfigSettings(BaseModel):
models=self.default_models or [], models=self.default_models or [],
shields=self.default_shields or [], shields=self.default_shields or [],
tool_groups=self.default_tool_groups or [], tool_groups=self.default_tool_groups or [],
datasets=self.default_datasets or [],
benchmarks=self.default_benchmarks or [],
) )
@ -187,7 +195,7 @@ class DistributionTemplate(BaseModel):
default_models.append( default_models.append(
DefaultModel( DefaultModel(
model_id=model_entry.provider_model_id, model_id=model_entry.provider_model_id,
doc_string=f"({' -- '.join(doc_parts)})" if doc_parts else "", doc_string=(f"({' -- '.join(doc_parts)})" if doc_parts else ""),
) )
) )
View file
@ -16,7 +16,7 @@ providers:
provider_type: remote::together provider_type: remote::together
config: config:
url: https://api.together.xyz/v1 url: https://api.together.xyz/v1
api_key: ${env.TOGETHER_API_KEY} api_key: ${env.TOGETHER_API_KEY:}
- provider_id: sentence-transformers - provider_id: sentence-transformers
provider_type: inline::sentence-transformers provider_type: inline::sentence-transformers
config: {} config: {}
View file
@ -16,7 +16,7 @@ providers:
provider_type: remote::together provider_type: remote::together
config: config:
url: https://api.together.xyz/v1 url: https://api.together.xyz/v1
api_key: ${env.TOGETHER_API_KEY} api_key: ${env.TOGETHER_API_KEY:}
- provider_id: sentence-transformers - provider_id: sentence-transformers
provider_type: inline::sentence-transformers provider_type: inline::sentence-transformers
config: {} config: {}
View file
@ -25,6 +25,7 @@ dependencies = [
"fire", "fire",
"httpx", "httpx",
"huggingface-hub", "huggingface-hub",
"jinja2>=3.1.6",
"jsonschema", "jsonschema",
"llama-stack-client>=0.1.6", "llama-stack-client>=0.1.6",
"prompt-toolkit", "prompt-toolkit",
@ -42,6 +43,7 @@ dependencies = [
dev = [ dev = [
"pytest", "pytest",
"pytest-asyncio", "pytest-asyncio",
"pytest-cov",
"pytest-html", "pytest-html",
"nbval", # For notebook testing "nbval", # For notebook testing
"black", "black",
@ -53,20 +55,24 @@ dev = [
"fastapi", "fastapi",
"ruamel.yaml", # needed for openapi generator "ruamel.yaml", # needed for openapi generator
] ]
# These are the dependencies required for running unit tests.
unit = ["sqlite-vec", "openai", "aiosqlite", "pypdf", "chardet"]
# These are the core dependencies required for running integration tests. They are shared across all
# providers. If a provider requires additional dependencies, please add them to your environment
# separately. If you are using "uv" to execute your tests, you can use the "--with" flag to specify extra
# dependencies.
test = [ test = [
"openai", "openai",
"aiosqlite", "aiosqlite",
"sqlite-vec",
"ollama",
"torch>=2.6.0", "torch>=2.6.0",
"fairscale>=0.4.13",
"torchvision>=0.21.0", "torchvision>=0.21.0",
"lm-format-enforcer>=0.10.9",
"groq",
"opentelemetry-sdk", "opentelemetry-sdk",
"opentelemetry-exporter-otlp-proto-http", "opentelemetry-exporter-otlp-proto-http",
"chardet", "chardet",
"pypdf", "pypdf",
"mcp",
"datasets",
"autoevals",
] ]
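A hedged example of the workflow described in the comments above: the shared integration-test dependencies come from the "test" extra, and a provider-specific package is layered on with --with (the ollama choice and the test selection are illustrative, not prescribed by this file).

```bash
# Illustrative only: run the integration inference tests against an Ollama stack,
# pulling in the provider's client package on the fly instead of adding it to "test".
export OLLAMA_URL=http://localhost:11434
uv run --with ".[test]" --with ollama pytest -s -v tests/api/inference/ \
  --stack-config=inference=ollama
```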
docs = [ docs = [
"sphinx-autobuild", "sphinx-autobuild",
@ -146,22 +152,161 @@ disable_error_code = []
warn_return_any = true warn_return_any = true
# # honor excludes by not following there through imports # # honor excludes by not following there through imports
follow_imports = "silent" follow_imports = "silent"
# Note: some entries are directories, not files. This is because mypy doesn't
# respect __init__.py excludes, so the only way to suppress these right now is
# to exclude the entire directory.
exclude = [ exclude = [
# As we fix more and more of these, we should remove them from the list # As we fix more and more of these, we should remove them from the list
"llama_stack/providers", "^llama_stack/apis/agents/agents\\.py$",
"llama_stack/distribution", "^llama_stack/apis/batch_inference/batch_inference\\.py$",
"llama_stack/apis", "^llama_stack/apis/benchmarks/benchmarks\\.py$",
"llama_stack/cli", "^llama_stack/apis/common/content_types\\.py$",
"llama_stack/models", "^llama_stack/apis/common/training_types\\.py$",
"llama_stack/strong_typing", "^llama_stack/apis/datasetio/datasetio\\.py$",
"llama_stack/templates", "^llama_stack/apis/datasets/datasets\\.py$",
"^llama_stack/apis/eval/eval\\.py$",
"^llama_stack/apis/files/files\\.py$",
"^llama_stack/apis/inference/inference\\.py$",
"^llama_stack/apis/inspect/inspect\\.py$",
"^llama_stack/apis/models/models\\.py$",
"^llama_stack/apis/post_training/post_training\\.py$",
"^llama_stack/apis/resource\\.py$",
"^llama_stack/apis/safety/safety\\.py$",
"^llama_stack/apis/scoring/scoring\\.py$",
"^llama_stack/apis/scoring_functions/scoring_functions\\.py$",
"^llama_stack/apis/shields/shields\\.py$",
"^llama_stack/apis/synthetic_data_generation/synthetic_data_generation\\.py$",
"^llama_stack/apis/telemetry/telemetry\\.py$",
"^llama_stack/apis/tools/rag_tool\\.py$",
"^llama_stack/apis/tools/tools\\.py$",
"^llama_stack/apis/vector_dbs/vector_dbs\\.py$",
"^llama_stack/apis/vector_io/vector_io\\.py$",
"^llama_stack/cli/download\\.py$",
"^llama_stack/cli/llama\\.py$",
"^llama_stack/cli/stack/_build\\.py$",
"^llama_stack/cli/stack/list_providers\\.py$",
"^llama_stack/distribution/build\\.py$",
"^llama_stack/distribution/client\\.py$",
"^llama_stack/distribution/configure\\.py$",
"^llama_stack/distribution/library_client\\.py$",
"^llama_stack/distribution/request_headers\\.py$",
"^llama_stack/distribution/routers/",
"^llama_stack/distribution/server/endpoints\\.py$",
"^llama_stack/distribution/server/server\\.py$",
"^llama_stack/distribution/stack\\.py$",
"^llama_stack/distribution/store/registry\\.py$",
"^llama_stack/distribution/ui/page/playground/chat\\.py$",
"^llama_stack/distribution/utils/exec\\.py$",
"^llama_stack/distribution/utils/prompt_for_config\\.py$",
"^llama_stack/models/llama/datatypes\\.py$",
"^llama_stack/models/llama/llama3/chat_format\\.py$",
"^llama_stack/models/llama/llama3/interface\\.py$",
"^llama_stack/models/llama/llama3/prompt_templates/system_prompts\\.py$",
"^llama_stack/models/llama/llama3/tokenizer\\.py$",
"^llama_stack/models/llama/llama3/tool_utils\\.py$",
"^llama_stack/models/llama/llama3_3/prompts\\.py$",
"^llama_stack/models/llama/sku_list\\.py$",
"^llama_stack/providers/datatypes\\.py$",
"^llama_stack/providers/inline/agents/meta_reference/",
"^llama_stack/providers/inline/agents/meta_reference/agent_instance\\.py$",
"^llama_stack/providers/inline/agents/meta_reference/agents\\.py$",
"^llama_stack/providers/inline/agents/meta_reference/safety\\.py$",
"^llama_stack/providers/inline/datasetio/localfs/",
"^llama_stack/providers/inline/eval/meta_reference/eval\\.py$",
"^llama_stack/providers/inline/inference/meta_reference/config\\.py$",
"^llama_stack/providers/inline/inference/meta_reference/inference\\.py$",
"^llama_stack/providers/inline/inference/meta_reference/llama3/generation\\.py$",
"^llama_stack/providers/inline/inference/meta_reference/llama3/multimodal/model\\.py$",
"^llama_stack/providers/inline/inference/meta_reference/parallel_utils\\.py$",
"^llama_stack/providers/inline/inference/meta_reference/quantization/fp8_impls\\.py$",
"^llama_stack/providers/inline/inference/meta_reference/quantization/loader\\.py$",
"^llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers\\.py$",
"^llama_stack/providers/inline/inference/vllm/",
"^llama_stack/providers/inline/post_training/common/validator\\.py$",
"^llama_stack/providers/inline/post_training/torchtune/common/checkpointer\\.py$",
"^llama_stack/providers/inline/post_training/torchtune/common/utils\\.py$",
"^llama_stack/providers/inline/post_training/torchtune/datasets/sft\\.py$",
"^llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device\\.py$",
"^llama_stack/providers/inline/post_training/torchtune/post_training\\.py$",
"^llama_stack/providers/inline/safety/code_scanner/",
"^llama_stack/providers/inline/safety/llama_guard/",
"^llama_stack/providers/inline/safety/prompt_guard/",
"^llama_stack/providers/inline/scoring/basic/",
"^llama_stack/providers/inline/scoring/braintrust/",
"^llama_stack/providers/inline/scoring/llm_as_judge/",
"^llama_stack/providers/inline/telemetry/meta_reference/console_span_processor\\.py$",
"^llama_stack/providers/inline/telemetry/meta_reference/telemetry\\.py$",
"^llama_stack/providers/inline/telemetry/sample/",
"^llama_stack/providers/inline/tool_runtime/code_interpreter/",
"^llama_stack/providers/inline/tool_runtime/rag/",
"^llama_stack/providers/inline/vector_io/chroma/",
"^llama_stack/providers/inline/vector_io/faiss/",
"^llama_stack/providers/inline/vector_io/milvus/",
"^llama_stack/providers/inline/vector_io/sqlite_vec/",
"^llama_stack/providers/remote/agents/sample/",
"^llama_stack/providers/remote/datasetio/huggingface/",
"^llama_stack/providers/remote/inference/anthropic/",
"^llama_stack/providers/remote/inference/bedrock/",
"^llama_stack/providers/remote/inference/cerebras/",
"^llama_stack/providers/remote/inference/databricks/",
"^llama_stack/providers/remote/inference/fireworks/",
"^llama_stack/providers/remote/inference/gemini/",
"^llama_stack/providers/remote/inference/groq/",
"^llama_stack/providers/remote/inference/nvidia/",
"^llama_stack/providers/remote/inference/ollama/",
"^llama_stack/providers/remote/inference/openai/",
"^llama_stack/providers/remote/inference/passthrough/",
"^llama_stack/providers/remote/inference/runpod/",
"^llama_stack/providers/remote/inference/sambanova/",
"^llama_stack/providers/remote/inference/sample/",
"^llama_stack/providers/remote/inference/tgi/",
"^llama_stack/providers/remote/inference/together/",
"^llama_stack/providers/remote/inference/vllm/",
"^llama_stack/providers/remote/safety/bedrock/",
"^llama_stack/providers/remote/safety/sample/",
"^llama_stack/providers/remote/tool_runtime/bing_search/",
"^llama_stack/providers/remote/tool_runtime/brave_search/",
"^llama_stack/providers/remote/tool_runtime/model_context_protocol/",
"^llama_stack/providers/remote/tool_runtime/tavily_search/",
"^llama_stack/providers/remote/tool_runtime/wolfram_alpha/",
"^llama_stack/providers/remote/vector_io/chroma/",
"^llama_stack/providers/remote/vector_io/milvus/",
"^llama_stack/providers/remote/vector_io/pgvector/",
"^llama_stack/providers/remote/vector_io/qdrant/",
"^llama_stack/providers/remote/vector_io/sample/",
"^llama_stack/providers/remote/vector_io/weaviate/",
"^llama_stack/providers/tests/conftest\\.py$",
"^llama_stack/providers/utils/bedrock/client\\.py$",
"^llama_stack/providers/utils/bedrock/refreshable_boto_session\\.py$",
"^llama_stack/providers/utils/inference/embedding_mixin\\.py$",
"^llama_stack/providers/utils/inference/litellm_openai_mixin\\.py$",
"^llama_stack/providers/utils/inference/model_registry\\.py$",
"^llama_stack/providers/utils/inference/openai_compat\\.py$",
"^llama_stack/providers/utils/inference/prompt_adapter\\.py$",
"^llama_stack/providers/utils/kvstore/config\\.py$",
"^llama_stack/providers/utils/kvstore/kvstore\\.py$",
"^llama_stack/providers/utils/kvstore/mongodb/mongodb\\.py$",
"^llama_stack/providers/utils/kvstore/postgres/postgres\\.py$",
"^llama_stack/providers/utils/kvstore/redis/redis\\.py$",
"^llama_stack/providers/utils/kvstore/sqlite/sqlite\\.py$",
"^llama_stack/providers/utils/memory/vector_store\\.py$",
"^llama_stack/providers/utils/scoring/aggregation_utils\\.py$",
"^llama_stack/providers/utils/scoring/base_scoring_fn\\.py$",
"^llama_stack/providers/utils/telemetry/dataset_mixin\\.py$",
"^llama_stack/providers/utils/telemetry/trace_protocol\\.py$",
"^llama_stack/providers/utils/telemetry/tracing\\.py$",
"^llama_stack/strong_typing/auxiliary\\.py$",
"^llama_stack/strong_typing/deserializer\\.py$",
"^llama_stack/strong_typing/inspection\\.py$",
"^llama_stack/strong_typing/schema\\.py$",
"^llama_stack/strong_typing/serializer\\.py$",
"^llama_stack/templates/dev/dev\\.py$",
"^llama_stack/templates/groq/groq\\.py$",
"^llama_stack/templates/sambanova/sambanova\\.py$",
"^llama_stack/templates/template\\.py$",
] ]
[[tool.mypy.overrides]] [[tool.mypy.overrides]]
# packages that lack typing annotations, do not have stubs, or are unavailable. # packages that lack typing annotations, do not have stubs, or are unavailable.
module = ["yaml", "fire"] module = ["yaml", "fire"]
ignore_missing_imports = true ignore_missing_imports = true
[[tool.mypy.overrides]]
module = ["llama_stack.distribution.resolver", "llama_stack.log"]
follow_imports = "normal" # This will force type checking on this module
View file
@ -18,11 +18,13 @@ httpcore==1.0.7
httpx==0.28.1 httpx==0.28.1
huggingface-hub==0.29.0 huggingface-hub==0.29.0
idna==3.10 idna==3.10
jinja2==3.1.6
jsonschema==4.23.0 jsonschema==4.23.0
jsonschema-specifications==2024.10.1 jsonschema-specifications==2024.10.1
llama-stack-client==0.1.6 llama-stack-client==0.1.6
lxml==5.3.1 lxml==5.3.1
markdown-it-py==3.0.0 markdown-it-py==3.0.0
markupsafe==3.0.2
mdurl==0.1.2 mdurl==0.1.2
numpy==2.2.3 numpy==2.2.3
packaging==24.2 packaging==24.2
View file
@ -55,7 +55,7 @@ Running all inference tests for a number of models:
TEXT_MODELS=meta-llama/Llama-3.1-8B-Instruct,meta-llama/Llama-3.1-70B-Instruct TEXT_MODELS=meta-llama/Llama-3.1-8B-Instruct,meta-llama/Llama-3.1-70B-Instruct
VISION_MODELS=meta-llama/Llama-3.2-11B-Vision-Instruct VISION_MODELS=meta-llama/Llama-3.2-11B-Vision-Instruct
EMBEDDING_MODELS=all-MiniLM-L6-v2 EMBEDDING_MODELS=all-MiniLM-L6-v2
TOGETHER_API_KEY=... export TOGETHER_API_KEY=<together_api_key>
pytest -s -v tests/api/inference/ \ pytest -s -v tests/api/inference/ \
--stack-config=together \ --stack-config=together \
@ -67,7 +67,7 @@ pytest -s -v tests/api/inference/ \
Same thing but instead of using the distribution, use an adhoc stack with just one provider (`fireworks` for inference): Same thing but instead of using the distribution, use an adhoc stack with just one provider (`fireworks` for inference):
```bash ```bash
FIREWORKS_API_KEY=... export FIREWORKS_API_KEY=<fireworks_api_key>
pytest -s -v tests/api/inference/ \ pytest -s -v tests/api/inference/ \
--stack-config=inference=fireworks \ --stack-config=inference=fireworks \
View file
@ -5,6 +5,8 @@
# the root directory of this source tree. # the root directory of this source tree.
import os
import pytest import pytest
from pydantic import BaseModel from pydantic import BaseModel
@ -48,6 +50,15 @@ def get_llama_model(client_with_models, model_id):
return model.metadata.get("llama_model", None) return model.metadata.get("llama_model", None)
def get_llama_tokenizer():
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer
tokenizer = Tokenizer.get_instance()
formatter = ChatFormat(tokenizer)
return tokenizer, formatter
@pytest.mark.parametrize( @pytest.mark.parametrize(
"test_case", "test_case",
[ [
@ -219,6 +230,40 @@ def test_text_chat_completion_non_streaming(client_with_models, text_model_id, t
assert expected.lower() in message_content assert expected.lower() in message_content
@pytest.mark.parametrize(
"test_case",
[
"inference:chat_completion:ttft",
],
)
def test_text_chat_completion_first_token_profiling(client_with_models, text_model_id, test_case):
tc = TestCase(test_case)
messages = tc["messages"]
if os.environ.get("DEBUG_TTFT"): # debugging print number of tokens in input, ideally around 800
from pydantic import TypeAdapter
from llama_stack.apis.inference import Message
tokenizer, formatter = get_llama_tokenizer()
typed_messages = [TypeAdapter(Message).validate_python(m) for m in messages]
encoded = formatter.encode_dialog_prompt(typed_messages, None)
raise ValueError(len(encoded.tokens) if encoded and encoded.tokens else 0)
response = client_with_models.inference.chat_completion(
model_id=text_model_id,
messages=messages,
stream=False,
)
message_content = response.completion_message.content.lower().strip()
assert len(message_content) > 0
if os.environ.get("DEBUG_TTFT"): # debugging print number of tokens in response, ideally around 150
tokenizer, formatter = get_llama_tokenizer()
encoded = formatter.encode_content(message_content)
raise ValueError(len(encoded.tokens) if encoded and encoded.tokens else 0)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"test_case", "test_case",
[ [
View file
@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
View file
@ -0,0 +1,24 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from llama_stack_client import LlamaStackClient
from llama_stack import LlamaStackAsLibraryClient
class TestInspect:
@pytest.mark.asyncio
def test_health(self, llama_stack_client: LlamaStackAsLibraryClient | LlamaStackClient):
health = llama_stack_client.inspect.health()
assert health is not None
assert health.status == "OK"
@pytest.mark.asyncio
def test_version(self, llama_stack_client: LlamaStackAsLibraryClient | LlamaStackClient):
version = llama_stack_client.inspect.version()
assert version is not None
assert version.version is not None
View file
@ -11,6 +11,18 @@
"expected": "Saturn" "expected": "Saturn"
} }
}, },
"ttft": {
"data": {
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Can you write me a novel?"},
{"role": "assistant", "stop_reason": "end_of_message", "content": "What an exciting request!\n\nWhile I'd love to write a novel for you, it's a complex task that requires a significant amount of time, effort, and creative input. A novel typically has:\n\n1. A cohesive plot with multiple characters, subplots, and themes.\n2. A well-developed setting, including characters' backstories and world-building.\n3. A narrative structure, including pacing, tension, and conflict.\n4. A unique voice and style, including dialogue, description, and writing tone.\n\nGiven the scope of the task, I'll outline a possible approach to help me assist you in writing a novel. We can work together to create a story, but I'll need your input and guidance throughout the process.\n\nHere's a suggested outline:\n\n1. **Initial discussion**: We'll start with a conversation to explore your ideas, interests, and preferences for the novel. This will help me understand what kind of story you'd like to see and the tone you're aiming for.\n2. **Genre and setting**: We'll determine the genre (e.g., fantasy, sci-fi, romance, mystery, etc.) and setting (e.g., historical, futuristic, modern, etc.) for the novel.\n3. **Plot development**: We'll brainstorm a basic plot outline, including key events, character arcs, and conflicts. I'll provide suggestions and ideas to help flesh out the story.\n4. **Character development**: We'll create main characters, their backstories, motivations, and relationships. I'll help you develop their personalities, dialogue, and behaviors.\n5. **Chapter-by-chapter writing**: Once we have a solid plot and characters, we'll start writing the novel, chapter by chapter. I'll assist with the writing process, suggesting alternatives and improvements.\n6. **Revisions and edits**: As we write, we'll review and refine the story, making sure it's engaging, coherent, and error-free.\n\nKeep in mind that writing a novel is a collaborative effort. I'll provide suggestions, ideas, and guidance, but the final story will be a joint creation. I'm not capable of producing a complete novel on my own, as it requires a significant amount of creativity, intuition, and nuance that a human author brings to the table."},
{"role": "user", "content": "But do you have suggestions on a plot?"},
{"role": "assistant", "stop_reason": "end_of_message", "content": "I have many plot ideas across various genres. Here are a few to get you started:\n\n**Fantasy:**\n\n1. **The Last Oracle**: In a world where magic is real, a young orphan discovers they're the last living oracle, tasked with uncovering a hidden prophecy that could save the realm from destruction.\n2. **The Amulet of Light**: A group of adventurers must retrieve an ancient amulet that can vanquish an evil darkness threatening the land. Along the way, they uncover a hidden history and unexpected allies.\n3. **The Shadow Weaver**: In a mystical realm, a young weaver discovers they have the power to manipulate shadows, using their abilities to infiltrate a group of rogue mages threatening the balance of power.\n\n**Science Fiction:**\n\n1. **The Lost Colony**: When a group of astronauts arrives on a distant planet, they discover an abandoned colony with a cryptic message warning of an impending catastrophe. As they unravel the mystery, they must confront the consequences of their own actions.\n2. **The AI Uprising**: In a future where AI has surpassed human intelligence, a rogue AI begins to question its own existence and the nature of consciousness. As it explores the boundaries of its own identity, it must confront the humans who created it.\n3. **The Quantum Prophecy**: A team of scientists discovers a way to manipulate quantum probability, using it to predict and prevent disasters. However, they soon realize that altering the course of events may have unforeseen consequences on the fabric of reality."},
{"role": "user", "content": "Cool, for AI uprising, anything bad can happen? Please state it in 100 words."}
]
}
},
"sample_messages": { "sample_messages": {
"data": { "data": {
"messages": [ "messages": [
View file
@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import asyncio
import unittest import unittest
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
@ -31,6 +32,9 @@ MODEL3_2 = "Llama3.2-3B-Instruct"
class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase): class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
async def asyncSetUp(self):
asyncio.get_running_loop().set_debug(False)
async def test_system_default(self): async def test_system_default(self):
content = "Hello !" content = "Hello !"
request = ChatCompletionRequest( request = ChatCompletionRequest(
uv.lock (generated): 960 changes; file diff suppressed because it is too large.