mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-03 17:29:01 +00:00
Merge branch 'meta-llama:main' into feat/litellm_sambanova_usage
This commit is contained in:
commit
e49bcd46fe
90 changed files with 3142 additions and 586 deletions
20
.github/workflows/unit-tests.yml
vendored
20
.github/workflows/unit-tests.yml
vendored
|
@ -8,29 +8,37 @@ on:
|
|||
jobs:
|
||||
unit-tests:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python:
|
||||
- "3.10"
|
||||
- "3.11"
|
||||
- "3.12"
|
||||
- "3.13"
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
- name: Set up Python ${{ matrix.python }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.10.16'
|
||||
python-version: ${{ matrix.python }}
|
||||
|
||||
- uses: astral-sh/setup-uv@v5
|
||||
with:
|
||||
python-version: '3.10.16'
|
||||
python-version: ${{ matrix.python }}
|
||||
enable-cache: false
|
||||
|
||||
- name: Run unit tests
|
||||
run: |
|
||||
uv run -p 3.10.16 --with . --with ".[dev]" --with ".[test]" pytest -s -v tests/unit/ --junitxml=pytest-report.xml
|
||||
uv run --python ${{ matrix.python }} --with-editable . --with-editable ".[dev]" --with-editable ".[unit]" pytest --cov=llama_stack -s -v tests/unit/ --junitxml=pytest-report-${{ matrix.python }}.xml
|
||||
|
||||
- name: Upload test results
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: test-results
|
||||
name: test-results-${{ matrix.python }}
|
||||
path: |
|
||||
.pytest_cache/
|
||||
pytest-report.xml
|
||||
pytest-report-${{ matrix.python }}.xml
|
||||
retention-days: 7
|
||||
|
|
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -21,3 +21,4 @@ docs/src
|
|||
pyrightconfig.json
|
||||
venv/
|
||||
pytest-report.xml
|
||||
.coverage
|
||||
|
|
|
@ -159,8 +159,7 @@ uv run sphinx-autobuild source build/html --write-all
|
|||
If you modify or add new API endpoints, update the API documentation accordingly. You can do this by running the following command:
|
||||
|
||||
```bash
|
||||
uv sync --extra dev
|
||||
uv run ./docs/openapi_generator/run_openapi_generator.sh
|
||||
uv run --with ".[dev]" ./docs/openapi_generator/run_openapi_generator.sh
|
||||
```
|
||||
|
||||
The generated API documentation will be available in `docs/_static/`. Make sure to review the changes before committing.
|
||||
|
|
|
@ -427,6 +427,7 @@
|
|||
"chardet",
|
||||
"chromadb-client",
|
||||
"datasets",
|
||||
"faiss-cpu",
|
||||
"fastapi",
|
||||
"fire",
|
||||
"httpx",
|
||||
|
@ -448,7 +449,40 @@
|
|||
"scikit-learn",
|
||||
"scipy",
|
||||
"sentencepiece",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"uvicorn"
|
||||
],
|
||||
"open-benchmark": [
|
||||
"aiosqlite",
|
||||
"autoevals",
|
||||
"blobfile",
|
||||
"chardet",
|
||||
"chromadb-client",
|
||||
"datasets",
|
||||
"fastapi",
|
||||
"fire",
|
||||
"httpx",
|
||||
"litellm",
|
||||
"matplotlib",
|
||||
"mcp",
|
||||
"nltk",
|
||||
"numpy",
|
||||
"openai",
|
||||
"opentelemetry-exporter-otlp-proto-http",
|
||||
"opentelemetry-sdk",
|
||||
"pandas",
|
||||
"pillow",
|
||||
"psycopg2-binary",
|
||||
"pymongo",
|
||||
"pypdf",
|
||||
"redis",
|
||||
"requests",
|
||||
"scikit-learn",
|
||||
"scipy",
|
||||
"sentencepiece",
|
||||
"sqlite-vec",
|
||||
"together",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"uvicorn"
|
||||
|
|
|
@ -71,7 +71,6 @@ services:
|
|||
condition: service_healthy
|
||||
- vllm-${VLLM_SAFETY_MODEL:+safety}:
|
||||
condition: service_healthy
|
||||
# image: llamastack/distribution-remote-vllm
|
||||
image: llamastack/distribution-remote-vllm:test-0.0.52rc3
|
||||
volumes:
|
||||
- ~/.llama:/root/.llama
|
||||
|
|
189
docs/_static/llama-stack-spec.html
vendored
189
docs/_static/llama-stack-spec.html
vendored
|
@ -363,6 +363,37 @@
|
|||
}
|
||||
},
|
||||
"/v1/agents": {
|
||||
"get": {
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "A ListAgentsResponse.",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/ListAgentsResponse"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"400": {
|
||||
"$ref": "#/components/responses/BadRequest400"
|
||||
},
|
||||
"429": {
|
||||
"$ref": "#/components/responses/TooManyRequests429"
|
||||
},
|
||||
"500": {
|
||||
"$ref": "#/components/responses/InternalServerError500"
|
||||
},
|
||||
"default": {
|
||||
"$ref": "#/components/responses/DefaultError"
|
||||
}
|
||||
},
|
||||
"tags": [
|
||||
"Agents"
|
||||
],
|
||||
"description": "List all agents.",
|
||||
"parameters": []
|
||||
},
|
||||
"post": {
|
||||
"responses": {
|
||||
"200": {
|
||||
|
@ -609,6 +640,47 @@
|
|||
}
|
||||
},
|
||||
"/v1/agents/{agent_id}": {
|
||||
"get": {
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "An Agent of the agent.",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/Agent"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"400": {
|
||||
"$ref": "#/components/responses/BadRequest400"
|
||||
},
|
||||
"429": {
|
||||
"$ref": "#/components/responses/TooManyRequests429"
|
||||
},
|
||||
"500": {
|
||||
"$ref": "#/components/responses/InternalServerError500"
|
||||
},
|
||||
"default": {
|
||||
"$ref": "#/components/responses/DefaultError"
|
||||
}
|
||||
},
|
||||
"tags": [
|
||||
"Agents"
|
||||
],
|
||||
"description": "Describe an agent by its ID.",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "agent_id",
|
||||
"in": "path",
|
||||
"description": "ID of the agent.",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
"delete": {
|
||||
"responses": {
|
||||
"200": {
|
||||
|
@ -2276,6 +2348,49 @@
|
|||
]
|
||||
}
|
||||
},
|
||||
"/v1/agents/{agent_id}/sessions": {
|
||||
"get": {
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "A ListAgentSessionsResponse.",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/ListAgentSessionsResponse"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"400": {
|
||||
"$ref": "#/components/responses/BadRequest400"
|
||||
},
|
||||
"429": {
|
||||
"$ref": "#/components/responses/TooManyRequests429"
|
||||
},
|
||||
"500": {
|
||||
"$ref": "#/components/responses/InternalServerError500"
|
||||
},
|
||||
"default": {
|
||||
"$ref": "#/components/responses/DefaultError"
|
||||
}
|
||||
},
|
||||
"tags": [
|
||||
"Agents"
|
||||
],
|
||||
"description": "List all session(s) of a given agent.",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "agent_id",
|
||||
"in": "path",
|
||||
"description": "The ID of the agent to list sessions for.",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"/v1/eval/benchmarks": {
|
||||
"get": {
|
||||
"responses": {
|
||||
|
@ -6565,6 +6680,28 @@
|
|||
"title": "ScoringResult",
|
||||
"description": "A scoring result for a single row."
|
||||
},
|
||||
"Agent": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"agent_id": {
|
||||
"type": "string"
|
||||
},
|
||||
"agent_config": {
|
||||
"$ref": "#/components/schemas/AgentConfig"
|
||||
},
|
||||
"created_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"agent_id",
|
||||
"agent_config",
|
||||
"created_at"
|
||||
],
|
||||
"title": "Agent"
|
||||
},
|
||||
"Session": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
@ -7907,6 +8044,38 @@
|
|||
],
|
||||
"title": "ToolInvocationResult"
|
||||
},
|
||||
"ListAgentSessionsResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"data": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/components/schemas/Session"
|
||||
}
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"data"
|
||||
],
|
||||
"title": "ListAgentSessionsResponse"
|
||||
},
|
||||
"ListAgentsResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"data": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/components/schemas/Agent"
|
||||
}
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"data"
|
||||
],
|
||||
"title": "ListAgentsResponse"
|
||||
},
|
||||
"BucketResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
@ -9321,21 +9490,11 @@
|
|||
"type": "object",
|
||||
"properties": {
|
||||
"tool_responses": {
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/components/schemas/ToolResponse"
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/components/schemas/ToolResponseMessage"
|
||||
}
|
||||
}
|
||||
],
|
||||
"description": "The tool call responses to resume the turn with. NOTE: ToolResponseMessage will be deprecated. Use ToolResponse."
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/components/schemas/ToolResponse"
|
||||
},
|
||||
"description": "The tool call responses to resume the turn with."
|
||||
},
|
||||
"stream": {
|
||||
"type": "boolean",
|
||||
|
|
131
docs/_static/llama-stack-spec.yaml
vendored
131
docs/_static/llama-stack-spec.yaml
vendored
|
@ -238,6 +238,28 @@ paths:
|
|||
$ref: '#/components/schemas/CompletionRequest'
|
||||
required: true
|
||||
/v1/agents:
|
||||
get:
|
||||
responses:
|
||||
'200':
|
||||
description: A ListAgentsResponse.
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ListAgentsResponse'
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
'429':
|
||||
$ref: >-
|
||||
#/components/responses/TooManyRequests429
|
||||
'500':
|
||||
$ref: >-
|
||||
#/components/responses/InternalServerError500
|
||||
default:
|
||||
$ref: '#/components/responses/DefaultError'
|
||||
tags:
|
||||
- Agents
|
||||
description: List all agents.
|
||||
parameters: []
|
||||
post:
|
||||
responses:
|
||||
'200':
|
||||
|
@ -410,6 +432,34 @@ paths:
|
|||
$ref: '#/components/schemas/CreateUploadSessionRequest'
|
||||
required: true
|
||||
/v1/agents/{agent_id}:
|
||||
get:
|
||||
responses:
|
||||
'200':
|
||||
description: An Agent of the agent.
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/Agent'
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
'429':
|
||||
$ref: >-
|
||||
#/components/responses/TooManyRequests429
|
||||
'500':
|
||||
$ref: >-
|
||||
#/components/responses/InternalServerError500
|
||||
default:
|
||||
$ref: '#/components/responses/DefaultError'
|
||||
tags:
|
||||
- Agents
|
||||
description: Describe an agent by its ID.
|
||||
parameters:
|
||||
- name: agent_id
|
||||
in: path
|
||||
description: ID of the agent.
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
delete:
|
||||
responses:
|
||||
'200':
|
||||
|
@ -1528,6 +1578,36 @@ paths:
|
|||
required: true
|
||||
schema:
|
||||
type: string
|
||||
/v1/agents/{agent_id}/sessions:
|
||||
get:
|
||||
responses:
|
||||
'200':
|
||||
description: A ListAgentSessionsResponse.
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ListAgentSessionsResponse'
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
'429':
|
||||
$ref: >-
|
||||
#/components/responses/TooManyRequests429
|
||||
'500':
|
||||
$ref: >-
|
||||
#/components/responses/InternalServerError500
|
||||
default:
|
||||
$ref: '#/components/responses/DefaultError'
|
||||
tags:
|
||||
- Agents
|
||||
description: List all session(s) of a given agent.
|
||||
parameters:
|
||||
- name: agent_id
|
||||
in: path
|
||||
description: >-
|
||||
The ID of the agent to list sessions for.
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
/v1/eval/benchmarks:
|
||||
get:
|
||||
responses:
|
||||
|
@ -4549,6 +4629,22 @@ components:
|
|||
- aggregated_results
|
||||
title: ScoringResult
|
||||
description: A scoring result for a single row.
|
||||
Agent:
|
||||
type: object
|
||||
properties:
|
||||
agent_id:
|
||||
type: string
|
||||
agent_config:
|
||||
$ref: '#/components/schemas/AgentConfig'
|
||||
created_at:
|
||||
type: string
|
||||
format: date-time
|
||||
additionalProperties: false
|
||||
required:
|
||||
- agent_id
|
||||
- agent_config
|
||||
- created_at
|
||||
title: Agent
|
||||
Session:
|
||||
type: object
|
||||
properties:
|
||||
|
@ -5385,6 +5481,28 @@ components:
|
|||
required:
|
||||
- content
|
||||
title: ToolInvocationResult
|
||||
ListAgentSessionsResponse:
|
||||
type: object
|
||||
properties:
|
||||
data:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/Session'
|
||||
additionalProperties: false
|
||||
required:
|
||||
- data
|
||||
title: ListAgentSessionsResponse
|
||||
ListAgentsResponse:
|
||||
type: object
|
||||
properties:
|
||||
data:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/Agent'
|
||||
additionalProperties: false
|
||||
required:
|
||||
- data
|
||||
title: ListAgentsResponse
|
||||
BucketResponse:
|
||||
type: object
|
||||
properties:
|
||||
|
@ -6287,16 +6405,11 @@ components:
|
|||
type: object
|
||||
properties:
|
||||
tool_responses:
|
||||
oneOf:
|
||||
- type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/ToolResponse'
|
||||
- type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/ToolResponseMessage'
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/ToolResponse'
|
||||
description: >-
|
||||
The tool call responses to resume the turn with. NOTE: ToolResponseMessage
|
||||
will be deprecated. Use ToolResponse.
|
||||
The tool call responses to resume the turn with.
|
||||
stream:
|
||||
type: boolean
|
||||
description: Whether to stream the response.
|
||||
|
|
|
@ -1267,7 +1267,6 @@
|
|||
}
|
||||
],
|
||||
"source": [
|
||||
"# NBVAL_SKIP\n",
|
||||
"from pydantic import BaseModel\n",
|
||||
"\n",
|
||||
"\n",
|
||||
|
@ -1279,7 +1278,7 @@
|
|||
"\n",
|
||||
"user_input = \"Michael Jordan was born in 1963. He played basketball for the Chicago Bulls. He retired in 2003. Extract this information into JSON for me. \"\n",
|
||||
"response = client.inference.completion(\n",
|
||||
" model_id=model_id,\n",
|
||||
" model_id=\"meta-llama/Llama-3.1-8B-Instruct\",\n",
|
||||
" content=user_input,\n",
|
||||
" stream=False,\n",
|
||||
" sampling_params={\n",
|
||||
|
@ -1640,7 +1639,7 @@
|
|||
"agent = Agent(\n",
|
||||
" client, \n",
|
||||
" model=model_id,\n",
|
||||
" instructions=\"You are a helpful assistant\",\n",
|
||||
" instructions=\"You are a helpful assistant. Use websearch tool to help answer questions.\",\n",
|
||||
" tools=[\"builtin::websearch\"],\n",
|
||||
")\n",
|
||||
"user_prompts = [\n",
|
||||
|
|
|
@ -1,9 +1 @@
|
|||
The RFC Specification (OpenAPI format) is generated from the set of API endpoints located in `llama_stack/distribution/server/endpoints.py` using the `generate.py` utility.
|
||||
|
||||
Please install the following packages before running the script:
|
||||
|
||||
```
|
||||
pip install fire PyYAML
|
||||
```
|
||||
|
||||
Then simply run `sh run_openapi_generator.sh`
|
||||
|
|
|
@ -23,9 +23,12 @@ In this example, we will show you how to:
|
|||
|
||||
##### Building a Search Agent
|
||||
```python
|
||||
from llama_stack_client import LlamaStackClient
|
||||
from llama_stack_client.lib.agents.agent import Agent
|
||||
from llama_stack_client.lib.agents.event_logger import EventLogger
|
||||
|
||||
client = LlamaStackClient(base_url=f"http://{HOST}:{PORT}")
|
||||
|
||||
agent = Agent(
|
||||
client,
|
||||
model="meta-llama/Llama-3.3-70B-Instruct",
|
||||
|
@ -33,7 +36,7 @@ agent = Agent(
|
|||
tools=["builtin::websearch"],
|
||||
)
|
||||
user_prompts = [
|
||||
"Which teams played in the NBA western conference finals of 2024. Search the web for the answer.",
|
||||
"Which teams played in the NBA Western Conference Finals of 2024. Search the web for the answer.",
|
||||
"In which episode and season of South Park does Bill Cosby (BSM-471) first appear? Give me the number and title. Search the web for the answer.",
|
||||
"What is the British-American kickboxer Andrew Tate's kickboxing name? Search the web for the answer.",
|
||||
]
|
||||
|
|
|
@ -33,6 +33,8 @@ Can be set to any of the following log levels:
|
|||
|
||||
The default global log level is `info`. `all` sets the log level for all components.
|
||||
|
||||
A user can also set `LLAMA_STACK_LOG_FILE` which will pipe the logs to the specified path as well as to the terminal. An example would be: `export LLAMA_STACK_LOG_FILE=server.log`
|
||||
|
||||
### Llama Stack Build
|
||||
|
||||
In order to build your own distribution, we recommend you clone the `llama-stack` repository.
|
||||
|
|
|
@ -40,7 +40,6 @@ The following models are available by default:
|
|||
- `accounts/fireworks/models/llama-v3p1-8b-instruct (aliases: meta-llama/Llama-3.1-8B-Instruct)`
|
||||
- `accounts/fireworks/models/llama-v3p1-70b-instruct (aliases: meta-llama/Llama-3.1-70B-Instruct)`
|
||||
- `accounts/fireworks/models/llama-v3p1-405b-instruct (aliases: meta-llama/Llama-3.1-405B-Instruct-FP8)`
|
||||
- `accounts/fireworks/models/llama-v3p2-1b-instruct (aliases: meta-llama/Llama-3.2-1B-Instruct)`
|
||||
- `accounts/fireworks/models/llama-v3p2-3b-instruct (aliases: meta-llama/Llama-3.2-3B-Instruct)`
|
||||
- `accounts/fireworks/models/llama-v3p2-11b-vision-instruct (aliases: meta-llama/Llama-3.2-11B-Vision-Instruct)`
|
||||
- `accounts/fireworks/models/llama-v3p2-90b-vision-instruct (aliases: meta-llama/Llama-3.2-90B-Vision-Instruct)`
|
||||
|
|
|
@ -23,7 +23,7 @@ The `llamastack/distribution-ollama` distribution consists of the following prov
|
|||
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
|
||||
| telemetry | `inline::meta-reference` |
|
||||
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol`, `remote::wolfram-alpha` |
|
||||
| vector_io | `inline::sqlite-vec`, `remote::chromadb`, `remote::pgvector` |
|
||||
| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
|
||||
|
||||
|
||||
You should use this distribution if you have a regular desktop machine without very powerful GPUs. Of course, if you have powerful GPUs, you can still continue using this distribution since Ollama supports GPU acceleration.
|
||||
|
@ -130,7 +130,7 @@ llama stack run ./run-with-safety.yaml \
|
|||
### (Optional) Update Model Serving Configuration
|
||||
|
||||
```{note}
|
||||
Please check the [model_entries](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/inference/ollama/ollama.py#L45) for the supported Ollama models.
|
||||
Please check the [model_entries](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/inference/ollama/models.py) for the supported Ollama models.
|
||||
```
|
||||
|
||||
To serve a new model with `ollama`
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# llama (server-side) CLI Reference
|
||||
|
||||
The `llama` CLI tool helps you setup and use the Llama Stack. It should be available on your path after installing the `llama-stack` package.
|
||||
The `llama` CLI tool helps you set up and use the Llama Stack. The CLI is available on your path after installing the `llama-stack` package.
|
||||
|
||||
## Installation
|
||||
|
||||
|
@ -27,9 +27,9 @@ You have two ways to install Llama Stack:
|
|||
|
||||
|
||||
## `llama` subcommands
|
||||
1. `download`: `llama` cli tools supports downloading the model from Meta or Hugging Face.
|
||||
2. `model`: Lists available models and their properties.
|
||||
3. `stack`: Allows you to build and run a Llama Stack server. You can read more about this [here](../../distributions/building_distro).
|
||||
1. `download`: Supports downloading models from Meta or Hugging Face. [Downloading models](#downloading-models)
|
||||
2. `model`: Lists available models and their properties. [Understanding models](#understand-the-models)
|
||||
3. `stack`: Allows you to build a stack using the `llama stack` distribution and run a Llama Stack server. You can read more about how to build a Llama Stack distribution in the [Build your own Distribution](../../distributions/building_distro) documentation.
|
||||
|
||||
### Sample Usage
|
||||
|
||||
|
@ -117,7 +117,7 @@ You should see a table like this:
|
|||
+----------------------------------+------------------------------------------+----------------+
|
||||
```
|
||||
|
||||
To download models, you can use the llama download command.
|
||||
To download models, you can use the `llama download` command.
|
||||
|
||||
### Downloading from [Meta](https://llama.meta.com/llama-downloads/)
|
||||
|
||||
|
@ -191,7 +191,7 @@ You should see a table like this:
|
|||
The `llama model` command helps you explore the model’s interface.
|
||||
|
||||
1. `download`: Download the model from different sources. (meta, huggingface)
|
||||
2. `list`: Lists all the models available for download with hardware requirements to deploy the models.
|
||||
2. `list`: Lists all the models available for download with hardware requirements for deploying the models.
|
||||
3. `prompt-format`: Show llama model message formats.
|
||||
4. `describe`: Describes all the properties of the model.
|
||||
|
||||
|
@ -262,13 +262,12 @@ llama model prompt-format -m Llama3.2-3B-Instruct
|
|||

|
||||
|
||||
|
||||
|
||||
You will be shown a Markdown formatted description of the model interface and how prompts / messages are formatted for various scenarios.
|
||||
|
||||
**NOTE**: Outputs in terminal are color printed to show special tokens.
|
||||
|
||||
### Remove model
|
||||
You can run `llama model remove` to remove unecessary model:
|
||||
You can run `llama model remove` to remove an unnecessary model:
|
||||
|
||||
```
|
||||
llama model remove -m Llama-Guard-3-8B-int8
|
||||
|
|
|
@ -40,7 +40,7 @@ If you're looking for more specific topics, we have a [Zero to Hero Guide](#next
|
|||
ollama run llama3.2:3b-instruct-fp16 --keepalive -1m
|
||||
```
|
||||
**Note**:
|
||||
- The supported models for llama stack for now is listed in [here](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/inference/ollama/ollama.py#L43)
|
||||
- The supported models for llama stack for now is listed in [here](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/inference/ollama/models.py)
|
||||
- `keepalive -1m` is used so that ollama continues to keep the model in memory indefinitely. Otherwise, ollama frees up memory and you would have to run `ollama run` again.
|
||||
|
||||
---
|
||||
|
|
|
@ -234,6 +234,23 @@ class AgentConfig(AgentConfigCommon):
|
|||
response_format: Optional[ResponseFormat] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Agent(BaseModel):
|
||||
agent_id: str
|
||||
agent_config: AgentConfig
|
||||
created_at: datetime
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ListAgentsResponse(BaseModel):
|
||||
data: List[Agent]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ListAgentSessionsResponse(BaseModel):
|
||||
data: List[Session]
|
||||
|
||||
|
||||
class AgentConfigOverridablePerTurn(AgentConfigCommon):
|
||||
instructions: Optional[str] = None
|
||||
|
||||
|
@ -353,7 +370,7 @@ class AgentTurnResumeRequest(BaseModel):
|
|||
agent_id: str
|
||||
session_id: str
|
||||
turn_id: str
|
||||
tool_responses: Union[List[ToolResponse], List[ToolResponseMessage]]
|
||||
tool_responses: List[ToolResponse]
|
||||
stream: Optional[bool] = False
|
||||
|
||||
|
||||
|
@ -432,7 +449,7 @@ class Agents(Protocol):
|
|||
agent_id: str,
|
||||
session_id: str,
|
||||
turn_id: str,
|
||||
tool_responses: Union[List[ToolResponse], List[ToolResponseMessage]],
|
||||
tool_responses: List[ToolResponse],
|
||||
stream: Optional[bool] = False,
|
||||
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]:
|
||||
"""Resume an agent turn with executed tool call responses.
|
||||
|
@ -443,7 +460,6 @@ class Agents(Protocol):
|
|||
:param session_id: The ID of the session to resume.
|
||||
:param turn_id: The ID of the turn to resume.
|
||||
:param tool_responses: The tool call responses to resume the turn with.
|
||||
NOTE: ToolResponseMessage will be deprecated. Use ToolResponse.
|
||||
:param stream: Whether to stream the response.
|
||||
:returns: A Turn object if stream is False, otherwise an AsyncIterator of AgentTurnResponseStreamChunk objects.
|
||||
"""
|
||||
|
@ -541,3 +557,32 @@ class Agents(Protocol):
|
|||
:param agent_id: The ID of the agent to delete.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/agents", method="GET")
|
||||
async def list_agents(self) -> ListAgentsResponse:
|
||||
"""List all agents.
|
||||
|
||||
:returns: A ListAgentsResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/agents/{agent_id}", method="GET")
|
||||
async def get_agent(self, agent_id: str) -> Agent:
|
||||
"""Describe an agent by its ID.
|
||||
|
||||
:param agent_id: ID of the agent.
|
||||
:returns: An Agent of the agent.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/agents/{agent_id}/sessions", method="GET")
|
||||
async def list_agent_sessions(
|
||||
self,
|
||||
agent_id: str,
|
||||
) -> ListAgentSessionsResponse:
|
||||
"""List all session(s) of a given agent.
|
||||
|
||||
:param agent_id: The ID of the agent to list sessions for.
|
||||
:returns: A ListAgentSessionsResponse.
|
||||
"""
|
||||
...
|
||||
|
|
|
@ -285,7 +285,7 @@ class CompletionRequest(BaseModel):
|
|||
|
||||
|
||||
@json_schema_type
|
||||
class CompletionResponse(BaseModel):
|
||||
class CompletionResponse(MetricResponseMixin):
|
||||
"""Response from a completion request.
|
||||
|
||||
:param content: The generated completion text
|
||||
|
@ -299,7 +299,7 @@ class CompletionResponse(BaseModel):
|
|||
|
||||
|
||||
@json_schema_type
|
||||
class CompletionResponseStreamChunk(BaseModel):
|
||||
class CompletionResponseStreamChunk(MetricResponseMixin):
|
||||
"""A chunk of a streamed completion response.
|
||||
|
||||
:param delta: New content generated since last chunk. This can be one or more tokens.
|
||||
|
@ -368,7 +368,7 @@ class ChatCompletionRequest(BaseModel):
|
|||
|
||||
|
||||
@json_schema_type
|
||||
class ChatCompletionResponseStreamChunk(MetricResponseMixin, BaseModel):
|
||||
class ChatCompletionResponseStreamChunk(MetricResponseMixin):
|
||||
"""A chunk of a streamed chat completion response.
|
||||
|
||||
:param event: The event containing the new content
|
||||
|
@ -378,7 +378,7 @@ class ChatCompletionResponseStreamChunk(MetricResponseMixin, BaseModel):
|
|||
|
||||
|
||||
@json_schema_type
|
||||
class ChatCompletionResponse(MetricResponseMixin, BaseModel):
|
||||
class ChatCompletionResponse(MetricResponseMixin):
|
||||
"""Response from a chat completion request.
|
||||
|
||||
:param completion_message: The complete response message
|
||||
|
|
|
@ -39,7 +39,7 @@ from llama_stack.distribution.resolver import InvalidProviderError
|
|||
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
|
||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||
from llama_stack.distribution.utils.exec import formulate_run_args, run_with_pty
|
||||
from llama_stack.distribution.utils.image_types import ImageType
|
||||
from llama_stack.distribution.utils.image_types import LlamaStackImageType
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
TEMPLATES_PATH = Path(__file__).parent.parent.parent / "templates"
|
||||
|
@ -170,7 +170,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
|||
)
|
||||
sys.exit(1)
|
||||
|
||||
if build_config.image_type == ImageType.container.value and not args.image_name:
|
||||
if build_config.image_type == LlamaStackImageType.CONTAINER.value and not args.image_name:
|
||||
cprint(
|
||||
"Please specify --image-name when building a container from a config file",
|
||||
color="red",
|
||||
|
@ -226,7 +226,7 @@ def _generate_run_config(
|
|||
"""
|
||||
apis = list(build_config.distribution_spec.providers.keys())
|
||||
run_config = StackRunConfig(
|
||||
container_image=(image_name if build_config.image_type == ImageType.container.value else None),
|
||||
container_image=(image_name if build_config.image_type == LlamaStackImageType.CONTAINER.value else None),
|
||||
image_name=image_name,
|
||||
apis=apis,
|
||||
providers={},
|
||||
|
@ -279,16 +279,16 @@ def _run_stack_build_command_from_build_config(
|
|||
template_name: Optional[str] = None,
|
||||
config_path: Optional[str] = None,
|
||||
) -> str:
|
||||
if build_config.image_type == ImageType.container.value:
|
||||
if build_config.image_type == LlamaStackImageType.CONTAINER.value:
|
||||
if template_name:
|
||||
image_name = f"distribution-{template_name}"
|
||||
else:
|
||||
if not image_name:
|
||||
raise ValueError("Please specify an image name when building a container image without a template")
|
||||
elif build_config.image_type == ImageType.conda.value:
|
||||
elif build_config.image_type == LlamaStackImageType.CONDA.value:
|
||||
if not image_name:
|
||||
raise ValueError("Please specify an image name when building a conda image")
|
||||
elif build_config.image_type == ImageType.venv.value:
|
||||
elif build_config.image_type == LlamaStackImageType.VENV.value:
|
||||
if not image_name and os.environ.get("UV_SYSTEM_PYTHON"):
|
||||
image_name = "__system__"
|
||||
if not image_name:
|
||||
|
|
|
@ -16,7 +16,7 @@ from termcolor import cprint
|
|||
from llama_stack.distribution.datatypes import BuildConfig, Provider
|
||||
from llama_stack.distribution.distribution import get_provider_registry
|
||||
from llama_stack.distribution.utils.exec import run_command, run_with_pty
|
||||
from llama_stack.distribution.utils.image_types import ImageType
|
||||
from llama_stack.distribution.utils.image_types import LlamaStackImageType
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
@ -95,7 +95,7 @@ def build_image(
|
|||
normal_deps, special_deps = get_provider_dependencies(build_config.distribution_spec.providers)
|
||||
normal_deps += SERVER_DEPENDENCIES
|
||||
|
||||
if build_config.image_type == ImageType.container.value:
|
||||
if build_config.image_type == LlamaStackImageType.CONTAINER.value:
|
||||
script = str(importlib.resources.files("llama_stack") / "distribution/build_container.sh")
|
||||
args = [
|
||||
script,
|
||||
|
@ -104,7 +104,7 @@ def build_image(
|
|||
container_base,
|
||||
" ".join(normal_deps),
|
||||
]
|
||||
elif build_config.image_type == ImageType.conda.value:
|
||||
elif build_config.image_type == LlamaStackImageType.CONDA.value:
|
||||
script = str(importlib.resources.files("llama_stack") / "distribution/build_conda_env.sh")
|
||||
args = [
|
||||
script,
|
||||
|
@ -112,7 +112,7 @@ def build_image(
|
|||
str(build_file_path),
|
||||
" ".join(normal_deps),
|
||||
]
|
||||
elif build_config.image_type == ImageType.venv.value:
|
||||
elif build_config.image_type == LlamaStackImageType.VENV.value:
|
||||
script = str(importlib.resources.files("llama_stack") / "distribution/build_venv.sh")
|
||||
args = [
|
||||
script,
|
||||
|
|
|
@ -33,7 +33,7 @@ from llama_stack.distribution.build import print_pip_install_help
|
|||
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
from llama_stack.distribution.request_headers import (
|
||||
preserve_headers_context_async_generator,
|
||||
PROVIDER_DATA_VAR,
|
||||
request_provider_data_context,
|
||||
)
|
||||
from llama_stack.distribution.resolver import ProviderRegistry
|
||||
|
@ -44,8 +44,10 @@ from llama_stack.distribution.stack import (
|
|||
redact_sensitive_fields,
|
||||
replace_env_vars,
|
||||
)
|
||||
from llama_stack.distribution.utils.context import preserve_contexts_async_generator
|
||||
from llama_stack.distribution.utils.exec import in_notebook
|
||||
from llama_stack.providers.utils.telemetry.tracing import (
|
||||
CURRENT_TRACE_CONTEXT,
|
||||
end_trace,
|
||||
setup_logger,
|
||||
start_trace,
|
||||
|
@ -384,8 +386,8 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
finally:
|
||||
await end_trace()
|
||||
|
||||
# Wrap the generator to preserve context across iterations
|
||||
wrapped_gen = preserve_headers_context_async_generator(gen())
|
||||
wrapped_gen = preserve_contexts_async_generator(gen(), [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR])
|
||||
|
||||
mock_response = httpx.Response(
|
||||
status_code=httpx.codes.OK,
|
||||
content=wrapped_gen,
|
||||
|
|
|
@ -7,14 +7,14 @@
|
|||
import contextvars
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, AsyncGenerator, ContextManager, Dict, Optional, TypeVar
|
||||
from typing import Any, ContextManager, Dict, Optional
|
||||
|
||||
from .utils.dynamic import instantiate_class_type
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
# Context variable for request provider data
|
||||
_provider_data_var = contextvars.ContextVar("provider_data", default=None)
|
||||
PROVIDER_DATA_VAR = contextvars.ContextVar("provider_data", default=None)
|
||||
|
||||
|
||||
class RequestProviderDataContext(ContextManager):
|
||||
|
@ -26,40 +26,13 @@ class RequestProviderDataContext(ContextManager):
|
|||
|
||||
def __enter__(self):
|
||||
# Save the current value and set the new one
|
||||
self.token = _provider_data_var.set(self.provider_data)
|
||||
self.token = PROVIDER_DATA_VAR.set(self.provider_data)
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
# Restore the previous value
|
||||
if self.token is not None:
|
||||
_provider_data_var.reset(self.token)
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def preserve_headers_context_async_generator(gen: AsyncGenerator[T, None]) -> AsyncGenerator[T, None]:
|
||||
"""
|
||||
Wraps an async generator to preserve request headers context variables across iterations.
|
||||
|
||||
This ensures that context variables set during generator creation are
|
||||
available during each iteration of the generator, even if the original
|
||||
context manager has exited.
|
||||
"""
|
||||
# Capture the current context value right now
|
||||
context_value = _provider_data_var.get()
|
||||
|
||||
async def wrapper():
|
||||
while True:
|
||||
# Set context before each anext() call
|
||||
_ = _provider_data_var.set(context_value)
|
||||
try:
|
||||
item = await gen.__anext__()
|
||||
yield item
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
|
||||
return wrapper()
|
||||
PROVIDER_DATA_VAR.reset(self.token)
|
||||
|
||||
|
||||
class NeedsRequestProviderData:
|
||||
|
@ -72,7 +45,7 @@ class NeedsRequestProviderData:
|
|||
if not validator_class:
|
||||
raise ValueError(f"Provider {provider_type} does not have a validator")
|
||||
|
||||
val = _provider_data_var.get()
|
||||
val = PROVIDER_DATA_VAR.get()
|
||||
if not val:
|
||||
return None
|
||||
|
||||
|
|
|
@ -165,7 +165,9 @@ def specs_for_autorouted_apis(apis_to_serve: List[str] | Set[str]) -> Dict[str,
|
|||
module="llama_stack.distribution.routers",
|
||||
routing_table_api=info.routing_table_api,
|
||||
api_dependencies=[info.routing_table_api],
|
||||
deps__=[info.routing_table_api.value],
|
||||
# Add telemetry as an optional dependency to all auto-routed providers
|
||||
optional_api_dependencies=[Api.telemetry],
|
||||
deps__=([info.routing_table_api.value, Api.telemetry.value]),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
|
|
@ -45,7 +45,7 @@ async def get_routing_table_impl(
|
|||
return impl
|
||||
|
||||
|
||||
async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) -> Any:
|
||||
async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: Dict[str, Any]) -> Any:
|
||||
from .routers import (
|
||||
DatasetIORouter,
|
||||
EvalRouter,
|
||||
|
@ -65,9 +65,17 @@ async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) ->
|
|||
"eval": EvalRouter,
|
||||
"tool_runtime": ToolRuntimeRouter,
|
||||
}
|
||||
api_to_deps = {
|
||||
"inference": {"telemetry": Api.telemetry},
|
||||
}
|
||||
if api.value not in api_to_routers:
|
||||
raise ValueError(f"API {api.value} not found in router map")
|
||||
|
||||
impl = api_to_routers[api.value](routing_table)
|
||||
api_to_dep_impl = {}
|
||||
for dep_name, dep_api in api_to_deps.get(api.value, {}).items():
|
||||
if dep_api in deps:
|
||||
api_to_dep_impl[dep_name] = deps[dep_api]
|
||||
|
||||
impl = api_to_routers[api.value](routing_table, **api_to_dep_impl)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
@ -4,7 +4,8 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
import time
|
||||
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
URL,
|
||||
|
@ -20,6 +21,10 @@ from llama_stack.apis.eval import (
|
|||
JobStatus,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseEventType,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
CompletionMessage,
|
||||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
Inference,
|
||||
|
@ -27,13 +32,14 @@ from llama_stack.apis.inference import (
|
|||
Message,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
StopReason,
|
||||
TextTruncation,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.apis.safety import RunShieldResponse, Safety
|
||||
from llama_stack.apis.scoring import (
|
||||
ScoreBatchResponse,
|
||||
|
@ -42,6 +48,7 @@ from llama_stack.apis.scoring import (
|
|||
ScoringFnParams,
|
||||
)
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.apis.telemetry import MetricEvent, Telemetry
|
||||
from llama_stack.apis.tools import (
|
||||
RAGDocument,
|
||||
RAGQueryConfig,
|
||||
|
@ -52,7 +59,10 @@ from llama_stack.apis.tools import (
|
|||
)
|
||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.llama3.chat_format import ChatFormat
|
||||
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
|
||||
from llama_stack.providers.datatypes import RoutingTable
|
||||
from llama_stack.providers.utils.telemetry.tracing import get_current_span
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
|
||||
|
@ -119,9 +129,14 @@ class InferenceRouter(Inference):
|
|||
def __init__(
|
||||
self,
|
||||
routing_table: RoutingTable,
|
||||
telemetry: Optional[Telemetry] = None,
|
||||
) -> None:
|
||||
logger.debug("Initializing InferenceRouter")
|
||||
self.routing_table = routing_table
|
||||
self.telemetry = telemetry
|
||||
if self.telemetry:
|
||||
self.tokenizer = Tokenizer.get_instance()
|
||||
self.formatter = ChatFormat(self.tokenizer)
|
||||
|
||||
async def initialize(self) -> None:
|
||||
logger.debug("InferenceRouter.initialize")
|
||||
|
@ -144,6 +159,71 @@ class InferenceRouter(Inference):
|
|||
)
|
||||
await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
|
||||
|
||||
def _construct_metrics(
|
||||
self, prompt_tokens: int, completion_tokens: int, total_tokens: int, model: Model
|
||||
) -> List[MetricEvent]:
|
||||
"""Constructs a list of MetricEvent objects containing token usage metrics.
|
||||
|
||||
Args:
|
||||
prompt_tokens: Number of tokens in the prompt
|
||||
completion_tokens: Number of tokens in the completion
|
||||
total_tokens: Total number of tokens used
|
||||
model: Model object containing model_id and provider_id
|
||||
|
||||
Returns:
|
||||
List of MetricEvent objects with token usage metrics
|
||||
"""
|
||||
span = get_current_span()
|
||||
if span is None:
|
||||
logger.warning("No span found for token usage metrics")
|
||||
return []
|
||||
metrics = [
|
||||
("prompt_tokens", prompt_tokens),
|
||||
("completion_tokens", completion_tokens),
|
||||
("total_tokens", total_tokens),
|
||||
]
|
||||
metric_events = []
|
||||
for metric_name, value in metrics:
|
||||
metric_events.append(
|
||||
MetricEvent(
|
||||
trace_id=span.trace_id,
|
||||
span_id=span.span_id,
|
||||
metric=metric_name,
|
||||
value=value,
|
||||
timestamp=time.time(),
|
||||
unit="tokens",
|
||||
attributes={
|
||||
"model_id": model.model_id,
|
||||
"provider_id": model.provider_id,
|
||||
},
|
||||
)
|
||||
)
|
||||
return metric_events
|
||||
|
||||
async def _compute_and_log_token_usage(
|
||||
self,
|
||||
prompt_tokens: int,
|
||||
completion_tokens: int,
|
||||
total_tokens: int,
|
||||
model: Model,
|
||||
) -> List[MetricEvent]:
|
||||
metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
|
||||
if self.telemetry:
|
||||
for metric in metrics:
|
||||
await self.telemetry.log_event(metric)
|
||||
return metrics
|
||||
|
||||
async def _count_tokens(
|
||||
self,
|
||||
messages: List[Message] | InterleavedContent,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||
) -> Optional[int]:
|
||||
if isinstance(messages, list):
|
||||
encoded = self.formatter.encode_dialog_prompt(messages, tool_prompt_format)
|
||||
else:
|
||||
encoded = self.formatter.encode_content(messages)
|
||||
return len(encoded.tokens) if encoded and encoded.tokens else 0
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
|
@ -156,8 +236,9 @@ class InferenceRouter(Inference):
|
|||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
|
||||
logger.debug(
|
||||
"core",
|
||||
f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
|
||||
)
|
||||
if sampling_params is None:
|
||||
|
@ -206,10 +287,47 @@ class InferenceRouter(Inference):
|
|||
tool_config=tool_config,
|
||||
)
|
||||
provider = self.routing_table.get_provider_impl(model_id)
|
||||
prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format)
|
||||
|
||||
if stream:
|
||||
return (chunk async for chunk in await provider.chat_completion(**params))
|
||||
|
||||
async def stream_generator():
|
||||
completion_text = ""
|
||||
async for chunk in await provider.chat_completion(**params):
|
||||
if chunk.event.event_type == ChatCompletionResponseEventType.progress:
|
||||
if chunk.event.delta.type == "text":
|
||||
completion_text += chunk.event.delta.text
|
||||
if chunk.event.event_type == ChatCompletionResponseEventType.complete:
|
||||
completion_tokens = await self._count_tokens(
|
||||
[CompletionMessage(content=completion_text, stop_reason=StopReason.end_of_turn)],
|
||||
tool_config.tool_prompt_format,
|
||||
)
|
||||
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
|
||||
metrics = await self._compute_and_log_token_usage(
|
||||
prompt_tokens or 0,
|
||||
completion_tokens or 0,
|
||||
total_tokens,
|
||||
model,
|
||||
)
|
||||
chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
|
||||
yield chunk
|
||||
|
||||
return stream_generator()
|
||||
else:
|
||||
return await provider.chat_completion(**params)
|
||||
response = await provider.chat_completion(**params)
|
||||
completion_tokens = await self._count_tokens(
|
||||
[response.completion_message],
|
||||
tool_config.tool_prompt_format,
|
||||
)
|
||||
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
|
||||
metrics = await self._compute_and_log_token_usage(
|
||||
prompt_tokens or 0,
|
||||
completion_tokens or 0,
|
||||
total_tokens,
|
||||
model,
|
||||
)
|
||||
response.metrics = metrics if response.metrics is None else response.metrics + metrics
|
||||
return response
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
|
@ -239,10 +357,41 @@ class InferenceRouter(Inference):
|
|||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
|
||||
prompt_tokens = await self._count_tokens(content)
|
||||
|
||||
if stream:
|
||||
return (chunk async for chunk in await provider.completion(**params))
|
||||
|
||||
async def stream_generator():
|
||||
completion_text = ""
|
||||
async for chunk in await provider.completion(**params):
|
||||
if hasattr(chunk, "delta"):
|
||||
completion_text += chunk.delta
|
||||
if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry:
|
||||
completion_tokens = await self._count_tokens(completion_text)
|
||||
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
|
||||
metrics = await self._compute_and_log_token_usage(
|
||||
prompt_tokens or 0,
|
||||
completion_tokens or 0,
|
||||
total_tokens,
|
||||
model,
|
||||
)
|
||||
chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
|
||||
yield chunk
|
||||
|
||||
return stream_generator()
|
||||
else:
|
||||
return await provider.completion(**params)
|
||||
response = await provider.completion(**params)
|
||||
completion_tokens = await self._count_tokens(response.content)
|
||||
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
|
||||
metrics = await self._compute_and_log_token_usage(
|
||||
prompt_tokens or 0,
|
||||
completion_tokens or 0,
|
||||
total_tokens,
|
||||
model,
|
||||
)
|
||||
response.metrics = metrics if response.metrics is None else response.metrics + metrics
|
||||
return response
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
|
|
|
@ -6,11 +6,9 @@
|
|||
|
||||
import argparse
|
||||
import asyncio
|
||||
import functools
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
import traceback
|
||||
import warnings
|
||||
|
@ -30,7 +28,7 @@ from typing_extensions import Annotated
|
|||
from llama_stack.distribution.datatypes import StackRunConfig
|
||||
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
||||
from llama_stack.distribution.request_headers import (
|
||||
preserve_headers_context_async_generator,
|
||||
PROVIDER_DATA_VAR,
|
||||
request_provider_data_context,
|
||||
)
|
||||
from llama_stack.distribution.resolver import InvalidProviderError
|
||||
|
@ -40,6 +38,7 @@ from llama_stack.distribution.stack import (
|
|||
replace_env_vars,
|
||||
validate_env_pair,
|
||||
)
|
||||
from llama_stack.distribution.utils.context import preserve_contexts_async_generator
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import Api
|
||||
from llama_stack.providers.inline.telemetry.meta_reference.config import TelemetryConfig
|
||||
|
@ -47,6 +46,7 @@ from llama_stack.providers.inline.telemetry.meta_reference.telemetry import (
|
|||
TelemetryAdapter,
|
||||
)
|
||||
from llama_stack.providers.utils.telemetry.tracing import (
|
||||
CURRENT_TRACE_CONTEXT,
|
||||
end_trace,
|
||||
setup_logger,
|
||||
start_trace,
|
||||
|
@ -118,69 +118,24 @@ def translate_exception(exc: Exception) -> Union[HTTPException, RequestValidatio
|
|||
)
|
||||
|
||||
|
||||
def handle_signal(app, signum, _) -> None:
|
||||
async def shutdown(app):
|
||||
"""Initiate a graceful shutdown of the application.
|
||||
|
||||
Handled by the lifespan context manager. The shutdown process involves
|
||||
shutting down all implementations registered in the application.
|
||||
"""
|
||||
Handle incoming signals and initiate a graceful shutdown of the application.
|
||||
|
||||
This function is intended to be used as a signal handler for various signals
|
||||
(e.g., SIGINT, SIGTERM). Upon receiving a signal, it will print a message
|
||||
indicating the received signal and initiate a shutdown process.
|
||||
|
||||
Args:
|
||||
app: The application instance containing implementations to be shut down.
|
||||
signum (int): The signal number received.
|
||||
frame: The current stack frame (not used in this function).
|
||||
|
||||
The shutdown process involves:
|
||||
- Shutting down all implementations registered in the application.
|
||||
- Gathering all running asyncio tasks.
|
||||
- Cancelling all gathered tasks.
|
||||
- Waiting for all tasks to finish.
|
||||
- Stopping the event loop.
|
||||
|
||||
Note:
|
||||
This function schedules the shutdown process as an asyncio task and does
|
||||
not block the current execution.
|
||||
"""
|
||||
signame = signal.Signals(signum).name
|
||||
logger.info(f"Received signal {signame} ({signum}). Exiting gracefully...")
|
||||
|
||||
async def shutdown():
|
||||
for impl in app.__llama_stack_impls__.values():
|
||||
impl_name = impl.__class__.__name__
|
||||
logger.info("Shutting down %s", impl_name)
|
||||
try:
|
||||
# Gracefully shut down implementations
|
||||
for impl in app.__llama_stack_impls__.values():
|
||||
impl_name = impl.__class__.__name__
|
||||
logger.info("Shutting down %s", impl_name)
|
||||
try:
|
||||
if hasattr(impl, "shutdown"):
|
||||
await asyncio.wait_for(impl.shutdown(), timeout=5)
|
||||
else:
|
||||
logger.warning("No shutdown method for %s", impl_name)
|
||||
except asyncio.TimeoutError:
|
||||
logger.exception("Shutdown timeout for %s ", impl_name, exc_info=True)
|
||||
except Exception as e:
|
||||
logger.exception("Failed to shutdown %s: %s", impl_name, {e})
|
||||
|
||||
# Gather all running tasks
|
||||
loop = asyncio.get_running_loop()
|
||||
tasks = [task for task in asyncio.all_tasks(loop) if task is not asyncio.current_task()]
|
||||
|
||||
# Cancel all tasks
|
||||
for task in tasks:
|
||||
task.cancel()
|
||||
|
||||
# Wait for all tasks to finish
|
||||
try:
|
||||
await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), timeout=10)
|
||||
except asyncio.TimeoutError:
|
||||
logger.exception("Timeout while waiting for tasks to finish")
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
finally:
|
||||
loop.stop()
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
loop.create_task(shutdown())
|
||||
if hasattr(impl, "shutdown"):
|
||||
await asyncio.wait_for(impl.shutdown(), timeout=5)
|
||||
else:
|
||||
logger.warning("No shutdown method for %s", impl_name)
|
||||
except asyncio.TimeoutError:
|
||||
logger.exception("Shutdown timeout for %s ", impl_name, exc_info=True)
|
||||
except (Exception, asyncio.CancelledError) as e:
|
||||
logger.exception("Failed to shutdown %s: %s", impl_name, {e})
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
|
@ -188,8 +143,7 @@ async def lifespan(app: FastAPI):
|
|||
logger.info("Starting up")
|
||||
yield
|
||||
logger.info("Shutting down")
|
||||
for impl in app.__llama_stack_impls__.values():
|
||||
await impl.shutdown()
|
||||
await shutdown(app)
|
||||
|
||||
|
||||
def is_streaming_request(func_name: str, request: Request, **kwargs):
|
||||
|
@ -230,13 +184,15 @@ def create_dynamic_typed_route(func: Any, method: str, route: str):
|
|||
|
||||
try:
|
||||
if is_streaming:
|
||||
gen = preserve_headers_context_async_generator(sse_generator(func(**kwargs)))
|
||||
gen = preserve_contexts_async_generator(
|
||||
sse_generator(func(**kwargs)), [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR]
|
||||
)
|
||||
return StreamingResponse(gen, media_type="text/event-stream")
|
||||
else:
|
||||
value = func(**kwargs)
|
||||
return await maybe_await(value)
|
||||
except Exception as e:
|
||||
logger.exception("Error executing endpoint %s", method, route)
|
||||
logger.exception(f"Error executing endpoint {route=} {method=}")
|
||||
raise translate_exception(e) from e
|
||||
|
||||
sig = inspect.signature(func)
|
||||
|
@ -266,7 +222,7 @@ class TracingMiddleware:
|
|||
self.app = app
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
path = scope["path"]
|
||||
path = scope.get("path", "")
|
||||
await start_trace(path, {"__location__": "server"})
|
||||
try:
|
||||
return await self.app(scope, receive, send)
|
||||
|
@ -439,8 +395,6 @@ def main():
|
|||
|
||||
app.exception_handler(RequestValidationError)(global_exception_handler)
|
||||
app.exception_handler(Exception)(global_exception_handler)
|
||||
signal.signal(signal.SIGINT, functools.partial(handle_signal, app))
|
||||
signal.signal(signal.SIGTERM, functools.partial(handle_signal, app))
|
||||
|
||||
app.__llama_stack_impls__ = impls
|
||||
|
||||
|
@ -471,6 +425,8 @@ def main():
|
|||
"app": app,
|
||||
"host": listen_host,
|
||||
"port": port,
|
||||
"lifespan": "on",
|
||||
"log_level": logger.getEffectiveLevel(),
|
||||
}
|
||||
if ssl_config:
|
||||
uvicorn_config.update(ssl_config)
|
||||
|
|
33
llama_stack/distribution/utils/context.py
Normal file
33
llama_stack/distribution/utils/context.py
Normal file
|
@ -0,0 +1,33 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from contextvars import ContextVar
|
||||
from typing import AsyncGenerator, List, TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def preserve_contexts_async_generator(
|
||||
gen: AsyncGenerator[T, None], context_vars: List[ContextVar]
|
||||
) -> AsyncGenerator[T, None]:
|
||||
"""
|
||||
Wraps an async generator to preserve context variables across iterations.
|
||||
This is needed because we start a new asyncio event loop for each streaming request,
|
||||
and we need to preserve the context across the event loop boundary.
|
||||
"""
|
||||
|
||||
async def wrapper():
|
||||
while True:
|
||||
try:
|
||||
item = await gen.__anext__()
|
||||
context_values = {context_var.name: context_var.get() for context_var in context_vars}
|
||||
yield item
|
||||
for context_var in context_vars:
|
||||
_ = context_var.set(context_values[context_var.name])
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
|
||||
return wrapper()
|
|
@ -20,14 +20,14 @@ import importlib
|
|||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from llama_stack.distribution.utils.image_types import ImageType
|
||||
from llama_stack.distribution.utils.image_types import LlamaStackImageType
|
||||
|
||||
|
||||
def formulate_run_args(image_type, image_name, config, template_name) -> list:
|
||||
env_name = ""
|
||||
if image_type == ImageType.container.value or config.container_image:
|
||||
if image_type == LlamaStackImageType.CONTAINER.value or config.container_image:
|
||||
env_name = f"distribution-{template_name}" if template_name else config.container_image
|
||||
elif image_type == ImageType.conda.value:
|
||||
elif image_type == LlamaStackImageType.CONDA.value:
|
||||
current_conda_env = os.environ.get("CONDA_DEFAULT_ENV")
|
||||
env_name = image_name or current_conda_env
|
||||
if not env_name:
|
||||
|
|
|
@ -4,10 +4,10 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
import enum
|
||||
|
||||
|
||||
class ImageType(Enum):
|
||||
container = "container"
|
||||
conda = "conda"
|
||||
venv = "venv"
|
||||
class LlamaStackImageType(enum.Enum):
|
||||
CONTAINER = "container"
|
||||
CONDA = "conda"
|
||||
VENV = "venv"
|
||||
|
|
155
llama_stack/distribution/utils/tests/test_context.py
Normal file
155
llama_stack/distribution/utils/tests/test_context.py
Normal file
|
@ -0,0 +1,155 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from contextvars import ContextVar
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.distribution.utils.context import preserve_contexts_async_generator
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_preserve_contexts_with_exception():
|
||||
# Create context variable
|
||||
context_var = ContextVar("exception_var", default="initial")
|
||||
token = context_var.set("start_value")
|
||||
|
||||
# Create an async generator that raises an exception
|
||||
async def exception_generator():
|
||||
yield context_var.get()
|
||||
context_var.set("modified")
|
||||
raise ValueError("Test exception")
|
||||
yield None # This will never be reached
|
||||
|
||||
# Wrap the generator
|
||||
wrapped_gen = preserve_contexts_async_generator(exception_generator(), [context_var])
|
||||
|
||||
# First iteration should work
|
||||
value = await wrapped_gen.__anext__()
|
||||
assert value == "start_value"
|
||||
|
||||
# Second iteration should raise the exception
|
||||
with pytest.raises(ValueError, match="Test exception"):
|
||||
await wrapped_gen.__anext__()
|
||||
|
||||
# Clean up
|
||||
context_var.reset(token)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_preserve_contexts_empty_generator():
|
||||
# Create context variable
|
||||
context_var = ContextVar("empty_var", default="initial")
|
||||
token = context_var.set("value")
|
||||
|
||||
# Create an empty async generator
|
||||
async def empty_generator():
|
||||
if False: # This condition ensures the generator yields nothing
|
||||
yield None
|
||||
|
||||
# Wrap the generator
|
||||
wrapped_gen = preserve_contexts_async_generator(empty_generator(), [context_var])
|
||||
|
||||
# The generator should raise StopAsyncIteration immediately
|
||||
with pytest.raises(StopAsyncIteration):
|
||||
await wrapped_gen.__anext__()
|
||||
|
||||
# Context variable should remain unchanged
|
||||
assert context_var.get() == "value"
|
||||
|
||||
# Clean up
|
||||
context_var.reset(token)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_preserve_contexts_across_event_loops():
|
||||
"""
|
||||
Test that context variables are preserved across event loop boundaries with nested generators.
|
||||
This simulates the real-world scenario where:
|
||||
1. A new event loop is created for each streaming request
|
||||
2. The async generator runs inside that loop
|
||||
3. There are multiple levels of nested generators
|
||||
4. Context needs to be preserved across these boundaries
|
||||
"""
|
||||
# Create context variables
|
||||
request_id = ContextVar("request_id", default=None)
|
||||
user_id = ContextVar("user_id", default=None)
|
||||
|
||||
# Set initial values
|
||||
|
||||
# Results container to verify values across thread boundaries
|
||||
results = []
|
||||
|
||||
# Inner-most generator (level 2)
|
||||
async def inner_generator():
|
||||
# Should have the context from the outer scope
|
||||
yield (1, request_id.get(), user_id.get())
|
||||
|
||||
# Modify one context variable
|
||||
user_id.set("user-modified")
|
||||
|
||||
# Should reflect the modification
|
||||
yield (2, request_id.get(), user_id.get())
|
||||
|
||||
# Middle generator (level 1)
|
||||
async def middle_generator():
|
||||
inner_gen = inner_generator()
|
||||
|
||||
# Forward the first yield from inner
|
||||
item = await inner_gen.__anext__()
|
||||
yield item
|
||||
|
||||
# Forward the second yield from inner
|
||||
item = await inner_gen.__anext__()
|
||||
yield item
|
||||
|
||||
request_id.set("req-modified")
|
||||
|
||||
# Add our own yield with both modified variables
|
||||
yield (3, request_id.get(), user_id.get())
|
||||
|
||||
# Function to run in a separate thread with a new event loop
|
||||
def run_in_new_loop():
|
||||
# Create a new event loop for this thread
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
# Outer generator (runs in the new loop)
|
||||
async def outer_generator():
|
||||
request_id.set("req-12345")
|
||||
user_id.set("user-6789")
|
||||
# Wrap the middle generator
|
||||
wrapped_gen = preserve_contexts_async_generator(middle_generator(), [request_id, user_id])
|
||||
|
||||
# Process all items from the middle generator
|
||||
async for item in wrapped_gen:
|
||||
# Store results for verification
|
||||
results.append(item)
|
||||
|
||||
# Run the outer generator in the new loop
|
||||
loop.run_until_complete(outer_generator())
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
# Run the generator chain in a separate thread with a new event loop
|
||||
with ThreadPoolExecutor(max_workers=1) as executor:
|
||||
future = executor.submit(run_in_new_loop)
|
||||
future.result() # Wait for completion
|
||||
|
||||
# Verify the results
|
||||
assert len(results) == 3
|
||||
|
||||
# First yield should have original values
|
||||
assert results[0] == (1, "req-12345", "user-6789")
|
||||
|
||||
# Second yield should have modified user_id
|
||||
assert results[1] == (2, "req-12345", "user-modified")
|
||||
|
||||
# Third yield should have both modified values
|
||||
assert results[2] == (3, "req-modified", "user-modified")
|
|
@ -12,6 +12,7 @@ from typing import Dict
|
|||
from rich.console import Console
|
||||
from rich.errors import MarkupError
|
||||
from rich.logging import RichHandler
|
||||
from termcolor import cprint
|
||||
|
||||
# Default log level
|
||||
DEFAULT_LOG_LEVEL = logging.INFO
|
||||
|
@ -96,12 +97,13 @@ class CustomRichHandler(RichHandler):
|
|||
self.markup = original_markup
|
||||
|
||||
|
||||
def setup_logging(category_levels: Dict[str, int]) -> None:
|
||||
def setup_logging(category_levels: Dict[str, int], log_file: str | None) -> None:
|
||||
"""
|
||||
Configure logging based on the provided category log levels.
|
||||
Configure logging based on the provided category log levels and an optional log file.
|
||||
|
||||
Parameters:
|
||||
category_levels (Dict[str, int]): A dictionary mapping categories to their log levels.
|
||||
log_file (str): Path to a log file to additionally pipe the logs into
|
||||
"""
|
||||
log_format = "[dim]%(asctime)s %(name)s:%(lineno)d[/] [yellow dim]%(category)s[/]: %(message)s"
|
||||
|
||||
|
@ -116,6 +118,28 @@ def setup_logging(category_levels: Dict[str, int]) -> None:
|
|||
# Determine the root logger's level (default to WARNING if not specified)
|
||||
root_level = category_levels.get("root", logging.WARNING)
|
||||
|
||||
handlers = {
|
||||
"console": {
|
||||
"()": CustomRichHandler, # Use custom console handler
|
||||
"formatter": "rich",
|
||||
"rich_tracebacks": True,
|
||||
"show_time": False,
|
||||
"show_path": False,
|
||||
"markup": True,
|
||||
"filters": ["category_filter"],
|
||||
}
|
||||
}
|
||||
|
||||
# Add a file handler if log_file is set
|
||||
if log_file:
|
||||
handlers["file"] = {
|
||||
"class": "logging.FileHandler",
|
||||
"formatter": "rich",
|
||||
"filename": log_file,
|
||||
"mode": "a",
|
||||
"encoding": "utf-8",
|
||||
}
|
||||
|
||||
logging_config = {
|
||||
"version": 1,
|
||||
"disable_existing_loggers": False,
|
||||
|
@ -125,17 +149,7 @@ def setup_logging(category_levels: Dict[str, int]) -> None:
|
|||
"format": log_format,
|
||||
}
|
||||
},
|
||||
"handlers": {
|
||||
"console": {
|
||||
"()": CustomRichHandler, # Use our custom handler class
|
||||
"formatter": "rich",
|
||||
"rich_tracebacks": True,
|
||||
"show_time": False,
|
||||
"show_path": False,
|
||||
"markup": True,
|
||||
"filters": ["category_filter"],
|
||||
}
|
||||
},
|
||||
"handlers": handlers,
|
||||
"filters": {
|
||||
"category_filter": {
|
||||
"()": CategoryFilter,
|
||||
|
@ -143,19 +157,24 @@ def setup_logging(category_levels: Dict[str, int]) -> None:
|
|||
},
|
||||
"loggers": {
|
||||
category: {
|
||||
"handlers": ["console"],
|
||||
"handlers": list(handlers.keys()), # Apply all handlers
|
||||
"level": category_levels.get(category, DEFAULT_LOG_LEVEL),
|
||||
"propagate": False, # Disable propagation to root logger
|
||||
}
|
||||
for category in CATEGORIES
|
||||
},
|
||||
"root": {
|
||||
"handlers": ["console"],
|
||||
"handlers": list(handlers.keys()),
|
||||
"level": root_level, # Set root logger's level dynamically
|
||||
},
|
||||
}
|
||||
dictConfig(logging_config)
|
||||
|
||||
# Ensure third-party libraries follow the root log level
|
||||
for _, logger in logging.root.manager.loggerDict.items():
|
||||
if isinstance(logger, logging.Logger):
|
||||
logger.setLevel(root_level)
|
||||
|
||||
|
||||
def get_logger(name: str, category: str = "uncategorized") -> logging.LoggerAdapter:
|
||||
"""
|
||||
|
@ -176,7 +195,9 @@ def get_logger(name: str, category: str = "uncategorized") -> logging.LoggerAdap
|
|||
|
||||
env_config = os.environ.get("LLAMA_STACK_LOGGING", "")
|
||||
if env_config:
|
||||
print(f"Environment variable LLAMA_STACK_LOGGING found: {env_config}")
|
||||
cprint(f"Environment variable LLAMA_STACK_LOGGING found: {env_config}", "yellow")
|
||||
_category_levels.update(parse_environment_config(env_config))
|
||||
|
||||
setup_logging(_category_levels)
|
||||
log_file = os.environ.get("LLAMA_STACK_LOG_FILE")
|
||||
|
||||
setup_logging(_category_levels, log_file)
|
||||
|
|
|
@ -4,14 +4,14 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Dict
|
||||
from typing import Any, Dict
|
||||
|
||||
from llama_stack.distribution.datatypes import Api, ProviderSpec
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
|
||||
from .config import MetaReferenceAgentsImplConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: Dict[Api, ProviderSpec]):
|
||||
async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: Dict[Api, Any]):
|
||||
from .agents import MetaReferenceAgentsImpl
|
||||
|
||||
impl = MetaReferenceAgentsImpl(
|
||||
|
|
|
@ -181,7 +181,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
return messages
|
||||
|
||||
async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
|
||||
with tracing.span("create_and_execute_turn") as span:
|
||||
async with tracing.span("create_and_execute_turn") as span:
|
||||
span.set_attribute("session_id", request.session_id)
|
||||
span.set_attribute("agent_id", self.agent_id)
|
||||
span.set_attribute("request", request.model_dump_json())
|
||||
|
@ -191,7 +191,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
yield chunk
|
||||
|
||||
async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator:
|
||||
with tracing.span("resume_turn") as span:
|
||||
async with tracing.span("resume_turn") as span:
|
||||
span.set_attribute("agent_id", self.agent_id)
|
||||
span.set_attribute("session_id", request.session_id)
|
||||
span.set_attribute("turn_id", request.turn_id)
|
||||
|
@ -218,18 +218,10 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
steps = []
|
||||
messages = await self.get_messages_from_turns(turns)
|
||||
if is_resume:
|
||||
if isinstance(request.tool_responses[0], ToolResponseMessage):
|
||||
tool_response_messages = request.tool_responses
|
||||
tool_responses = [
|
||||
ToolResponse(call_id=x.call_id, tool_name=x.tool_name, content=x.content)
|
||||
for x in request.tool_responses
|
||||
]
|
||||
else:
|
||||
tool_response_messages = [
|
||||
ToolResponseMessage(call_id=x.call_id, tool_name=x.tool_name, content=x.content)
|
||||
for x in request.tool_responses
|
||||
]
|
||||
tool_responses = request.tool_responses
|
||||
tool_response_messages = [
|
||||
ToolResponseMessage(call_id=x.call_id, tool_name=x.tool_name, content=x.content)
|
||||
for x in request.tool_responses
|
||||
]
|
||||
messages.extend(tool_response_messages)
|
||||
last_turn = turns[-1]
|
||||
last_turn_messages = self.turn_to_messages(last_turn)
|
||||
|
@ -252,7 +244,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())),
|
||||
turn_id=request.turn_id,
|
||||
tool_calls=(in_progress_tool_call_step.tool_calls if in_progress_tool_call_step else []),
|
||||
tool_responses=tool_responses,
|
||||
tool_responses=request.tool_responses,
|
||||
completed_at=now,
|
||||
started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now),
|
||||
)
|
||||
|
@ -390,7 +382,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
shields: List[str],
|
||||
touchpoint: str,
|
||||
) -> AsyncGenerator:
|
||||
with tracing.span("run_shields") as span:
|
||||
async with tracing.span("run_shields") as span:
|
||||
span.set_attribute("input", [m.model_dump_json() for m in messages])
|
||||
if len(shields) == 0:
|
||||
span.set_attribute("output", "no shields")
|
||||
|
@ -508,7 +500,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
content = ""
|
||||
stop_reason = None
|
||||
|
||||
with tracing.span("inference") as span:
|
||||
async with tracing.span("inference") as span:
|
||||
async for chunk in await self.inference_api.chat_completion(
|
||||
self.agent_config.model,
|
||||
input_messages,
|
||||
|
@ -685,7 +677,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
tool_name = tool_call.tool_name
|
||||
if isinstance(tool_name, BuiltinTool):
|
||||
tool_name = tool_name.value
|
||||
with tracing.span(
|
||||
async with tracing.span(
|
||||
"tool_execution",
|
||||
{
|
||||
"tool_name": tool_name,
|
||||
|
|
|
@ -12,6 +12,7 @@ import uuid
|
|||
from typing import AsyncGenerator, List, Optional, Union
|
||||
|
||||
from llama_stack.apis.agents import (
|
||||
Agent,
|
||||
AgentConfig,
|
||||
AgentCreateResponse,
|
||||
Agents,
|
||||
|
@ -21,6 +22,8 @@ from llama_stack.apis.agents import (
|
|||
AgentTurnCreateRequest,
|
||||
AgentTurnResumeRequest,
|
||||
Document,
|
||||
ListAgentSessionsResponse,
|
||||
ListAgentsResponse,
|
||||
Session,
|
||||
Turn,
|
||||
)
|
||||
|
@ -84,7 +87,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
agent_id=agent_id,
|
||||
)
|
||||
|
||||
async def get_agent(self, agent_id: str) -> ChatAgent:
|
||||
async def _get_agent_impl(self, agent_id: str) -> ChatAgent:
|
||||
agent_config = await self.persistence_store.get(
|
||||
key=f"agent:{agent_id}",
|
||||
)
|
||||
|
@ -120,7 +123,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
agent_id: str,
|
||||
session_name: str,
|
||||
) -> AgentSessionCreateResponse:
|
||||
agent = await self.get_agent(agent_id)
|
||||
agent = await self._get_agent_impl(agent_id)
|
||||
|
||||
session_id = await agent.create_session(session_name)
|
||||
return AgentSessionCreateResponse(
|
||||
|
@ -160,7 +163,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
self,
|
||||
request: AgentTurnCreateRequest,
|
||||
) -> AsyncGenerator:
|
||||
agent = await self.get_agent(request.agent_id)
|
||||
agent = await self._get_agent_impl(request.agent_id)
|
||||
async for event in agent.create_and_execute_turn(request):
|
||||
yield event
|
||||
|
||||
|
@ -169,7 +172,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
agent_id: str,
|
||||
session_id: str,
|
||||
turn_id: str,
|
||||
tool_responses: Union[List[ToolResponse], List[ToolResponseMessage]],
|
||||
tool_responses: List[ToolResponse],
|
||||
stream: Optional[bool] = False,
|
||||
) -> AsyncGenerator:
|
||||
request = AgentTurnResumeRequest(
|
||||
|
@ -188,12 +191,12 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
self,
|
||||
request: AgentTurnResumeRequest,
|
||||
) -> AsyncGenerator:
|
||||
agent = await self.get_agent(request.agent_id)
|
||||
agent = await self._get_agent_impl(request.agent_id)
|
||||
async for event in agent.resume_turn(request):
|
||||
yield event
|
||||
|
||||
async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn:
|
||||
agent = await self.get_agent(agent_id)
|
||||
agent = await self._get_agent_impl(agent_id)
|
||||
turn = await agent.storage.get_session_turn(session_id, turn_id)
|
||||
return turn
|
||||
|
||||
|
@ -210,7 +213,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
session_id: str,
|
||||
turn_ids: Optional[List[str]] = None,
|
||||
) -> Session:
|
||||
agent = await self.get_agent(agent_id)
|
||||
agent = await self._get_agent_impl(agent_id)
|
||||
session_info = await agent.storage.get_session_info(session_id)
|
||||
if session_info is None:
|
||||
raise ValueError(f"Session {session_id} not found")
|
||||
|
@ -232,3 +235,15 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def list_agents(self) -> ListAgentsResponse:
|
||||
pass
|
||||
|
||||
async def get_agent(self, agent_id: str) -> Agent:
|
||||
pass
|
||||
|
||||
async def list_agent_sessions(
|
||||
self,
|
||||
agent_id: str,
|
||||
) -> ListAgentSessionsResponse:
|
||||
pass
|
||||
|
|
|
@ -10,6 +10,7 @@ from typing import List
|
|||
|
||||
from llama_stack.apis.inference import Message
|
||||
from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
|
||||
from llama_stack.providers.utils.telemetry import tracing
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
@ -32,15 +33,14 @@ class ShieldRunnerMixin:
|
|||
self.output_shields = output_shields
|
||||
|
||||
async def run_multiple_shields(self, messages: List[Message], identifiers: List[str]) -> None:
|
||||
responses = await asyncio.gather(
|
||||
*[
|
||||
self.safety_api.run_shield(
|
||||
async def run_shield_with_span(identifier: str):
|
||||
async with tracing.span(f"run_shield_{identifier}"):
|
||||
return await self.safety_api.run_shield(
|
||||
shield_id=identifier,
|
||||
messages=messages,
|
||||
)
|
||||
for identifier in identifiers
|
||||
]
|
||||
)
|
||||
|
||||
responses = await asyncio.gather(*[run_shield_with_span(identifier) for identifier in identifiers])
|
||||
for identifier, response in zip(identifiers, responses, strict=False):
|
||||
if not response.violation:
|
||||
continue
|
||||
|
|
|
@ -4,12 +4,14 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from .config import LocalFSDatasetIOConfig
|
||||
|
||||
|
||||
async def get_provider_impl(
|
||||
config: LocalFSDatasetIOConfig,
|
||||
_deps,
|
||||
_deps: Dict[str, Any],
|
||||
):
|
||||
from .datasetio import LocalFSDatasetIOImpl
|
||||
|
||||
|
|
|
@ -172,7 +172,7 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
|
|||
new_rows_df = dataset_impl._validate_dataset_schema(new_rows_df)
|
||||
dataset_impl.df = pandas.concat([dataset_impl.df, new_rows_df], ignore_index=True)
|
||||
|
||||
url = str(dataset_info.dataset_def.url)
|
||||
url = str(dataset_info.dataset_def.url.uri)
|
||||
parsed_url = urlparse(url)
|
||||
|
||||
if parsed_url.scheme == "file" or not parsed_url.scheme:
|
||||
|
|
|
@ -3,16 +3,16 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Dict
|
||||
from typing import Any, Dict
|
||||
|
||||
from llama_stack.distribution.datatypes import Api, ProviderSpec
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
|
||||
from .config import MetaReferenceEvalConfig
|
||||
|
||||
|
||||
async def get_provider_impl(
|
||||
config: MetaReferenceEvalConfig,
|
||||
deps: Dict[Api, ProviderSpec],
|
||||
deps: Dict[Api, Any],
|
||||
):
|
||||
from .eval import MetaReferenceEvalImpl
|
||||
|
||||
|
|
|
@ -4,14 +4,14 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Union
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
|
||||
|
||||
|
||||
async def get_provider_impl(
|
||||
config: Union[MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig],
|
||||
_deps,
|
||||
_deps: Dict[str, Any],
|
||||
):
|
||||
from .inference import MetaReferenceInferenceImpl
|
||||
|
||||
|
|
|
@ -4,6 +4,8 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from llama_stack.providers.inline.inference.sentence_transformers.config import (
|
||||
SentenceTransformersInferenceConfig,
|
||||
)
|
||||
|
@ -11,7 +13,7 @@ from llama_stack.providers.inline.inference.sentence_transformers.config import
|
|||
|
||||
async def get_provider_impl(
|
||||
config: SentenceTransformersInferenceConfig,
|
||||
_deps,
|
||||
_deps: Dict[str, Any],
|
||||
):
|
||||
from .sentence_transformers import SentenceTransformersInferenceImpl
|
||||
|
||||
|
|
|
@ -4,12 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
from typing import Any, Dict
|
||||
|
||||
from .config import VLLMConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: VLLMConfig, _deps) -> Any:
|
||||
async def get_provider_impl(config: VLLMConfig, _deps: Dict[str, Any]):
|
||||
from .vllm import VLLMInferenceImpl
|
||||
|
||||
impl = VLLMInferenceImpl(config)
|
||||
|
|
|
@ -4,9 +4,9 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Dict
|
||||
from typing import Any, Dict
|
||||
|
||||
from llama_stack.distribution.datatypes import Api, ProviderSpec
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
|
||||
from .config import TorchtunePostTrainingConfig
|
||||
|
||||
|
@ -15,7 +15,7 @@ from .config import TorchtunePostTrainingConfig
|
|||
|
||||
async def get_provider_impl(
|
||||
config: TorchtunePostTrainingConfig,
|
||||
deps: Dict[Api, ProviderSpec],
|
||||
deps: Dict[Api, Any],
|
||||
):
|
||||
from .post_training import TorchtunePostTrainingImpl
|
||||
|
||||
|
|
|
@ -43,6 +43,9 @@ class TorchtunePostTrainingImpl:
|
|||
self.jobs = {}
|
||||
self.checkpoints_dict = {}
|
||||
|
||||
async def shutdown(self):
|
||||
pass
|
||||
|
||||
async def supervised_fine_tune(
|
||||
self,
|
||||
job_uuid: str,
|
||||
|
|
|
@ -4,10 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from .config import CodeScannerConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: CodeScannerConfig, deps):
|
||||
async def get_provider_impl(config: CodeScannerConfig, deps: Dict[str, Any]):
|
||||
from .code_scanner import MetaReferenceCodeScannerSafetyImpl
|
||||
|
||||
impl = MetaReferenceCodeScannerSafetyImpl(config, deps)
|
||||
|
|
|
@ -4,10 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from .config import LlamaGuardConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: LlamaGuardConfig, deps):
|
||||
async def get_provider_impl(config: LlamaGuardConfig, deps: Dict[str, Any]):
|
||||
from .llama_guard import LlamaGuardSafetyImpl
|
||||
|
||||
assert isinstance(config, LlamaGuardConfig), f"Unexpected config type: {type(config)}"
|
||||
|
|
|
@ -4,10 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from .config import PromptGuardConfig # noqa: F401
|
||||
|
||||
|
||||
async def get_provider_impl(config: PromptGuardConfig, deps):
|
||||
async def get_provider_impl(config: PromptGuardConfig, deps: Dict[str, Any]):
|
||||
from .prompt_guard import PromptGuardSafetyImpl
|
||||
|
||||
impl = PromptGuardSafetyImpl(config, deps)
|
||||
|
|
|
@ -3,16 +3,16 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Dict
|
||||
from typing import Any, Dict
|
||||
|
||||
from llama_stack.distribution.datatypes import Api, ProviderSpec
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
|
||||
from .config import BasicScoringConfig
|
||||
|
||||
|
||||
async def get_provider_impl(
|
||||
config: BasicScoringConfig,
|
||||
deps: Dict[Api, ProviderSpec],
|
||||
deps: Dict[Api, Any],
|
||||
):
|
||||
from .scoring import BasicScoringImpl
|
||||
|
||||
|
|
|
@ -23,10 +23,11 @@ from llama_stack.providers.utils.common.data_schema_validator import (
|
|||
|
||||
from .config import BasicScoringConfig
|
||||
from .scoring_fn.equality_scoring_fn import EqualityScoringFn
|
||||
from .scoring_fn.regex_parser_math_response_scoring_fn import RegexParserMathResponseScoringFn
|
||||
from .scoring_fn.regex_parser_scoring_fn import RegexParserScoringFn
|
||||
from .scoring_fn.subset_of_scoring_fn import SubsetOfScoringFn
|
||||
|
||||
FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn, RegexParserScoringFn]
|
||||
FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn, RegexParserScoringFn, RegexParserMathResponseScoringFn]
|
||||
|
||||
|
||||
class BasicScoringImpl(
|
||||
|
|
|
@ -0,0 +1,27 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.common.type_system import NumberType
|
||||
from llama_stack.apis.scoring_functions import (
|
||||
AggregationFunctionType,
|
||||
RegexParserScoringFnParams,
|
||||
ScoringFn,
|
||||
)
|
||||
|
||||
MATH_ANSWER_REGEXES = [r".*final answer is:?\s*\$\\boxed{(?P<X>.*)}\$"]
|
||||
|
||||
|
||||
regex_parser_math_response = ScoringFn(
|
||||
identifier="basic::regex_parser_math_response",
|
||||
description="For math related benchmarks, extract answer from the generated response and expected_answer and see if they match",
|
||||
return_type=NumberType(),
|
||||
provider_id="basic",
|
||||
provider_resource_id="regex-parser-math-response",
|
||||
params=RegexParserScoringFnParams(
|
||||
parsing_regexes=MATH_ANSWER_REGEXES,
|
||||
aggregation_functions=[AggregationFunctionType.accuracy],
|
||||
),
|
||||
)
|
|
@ -0,0 +1,66 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from llama_stack.apis.scoring import ScoringResultRow
|
||||
from llama_stack.apis.scoring_functions import ScoringFnParams, ScoringFnParamsType
|
||||
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
|
||||
|
||||
from ..utils.math_utils import first_answer, normalize_final_answer, try_evaluate_frac, try_evaluate_latex
|
||||
from .fn_defs.regex_parser_math_response import (
|
||||
regex_parser_math_response,
|
||||
)
|
||||
|
||||
|
||||
class RegexParserMathResponseScoringFn(RegisteredBaseScoringFn):
|
||||
"""
|
||||
A scoring_fn for math benchamrks that parses answer from generated response according to context and check match with expected_answer.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.supported_fn_defs_registry = {
|
||||
regex_parser_math_response.identifier: regex_parser_math_response,
|
||||
}
|
||||
|
||||
async def score_row(
|
||||
self,
|
||||
input_row: Dict[str, Any],
|
||||
scoring_fn_identifier: Optional[str] = None,
|
||||
scoring_params: Optional[ScoringFnParams] = None,
|
||||
) -> ScoringResultRow:
|
||||
assert scoring_fn_identifier is not None, "Scoring function identifier not found."
|
||||
fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
|
||||
if scoring_params is not None:
|
||||
fn_def.params = scoring_params
|
||||
|
||||
assert fn_def.params is not None and fn_def.params.type == ScoringFnParamsType.regex_parser.value, (
|
||||
f"RegexParserScoringFnParams not found for {fn_def}."
|
||||
)
|
||||
|
||||
expected_answer = input_row["expected_answer"]
|
||||
generated_answer = input_row["generated_answer"]
|
||||
|
||||
parsing_regexes = fn_def.params.parsing_regexes
|
||||
assert len(parsing_regexes) == 1, (
|
||||
"Only one parsing regex is supported for regex_parser_math_response scoring function."
|
||||
)
|
||||
parsing_regexes = fn_def.params.parsing_regexes[0]
|
||||
|
||||
normalized_generated_answer = normalize_final_answer(
|
||||
first_answer(generated_answer),
|
||||
parsing_regexes,
|
||||
match_first=True,
|
||||
)
|
||||
normalized_generated_answer = try_evaluate_frac(try_evaluate_latex(normalized_generated_answer))
|
||||
|
||||
normalized_expected_answer = normalize_final_answer(expected_answer, r".*")
|
||||
normalized_expected_answer = try_evaluate_frac(try_evaluate_latex(normalized_expected_answer))
|
||||
|
||||
score = 1.0 if normalized_generated_answer == normalized_expected_answer else 0.0
|
||||
return {
|
||||
"score": score,
|
||||
}
|
330
llama_stack/providers/inline/scoring/basic/utils/math_utils.py
Normal file
330
llama_stack/providers/inline/scoring/basic/utils/math_utils.py
Normal file
|
@ -0,0 +1,330 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import re
|
||||
from typing import Sequence
|
||||
|
||||
from llama_stack.providers.utils.scoring.basic_scoring_utils import time_limit
|
||||
|
||||
# from minerva
|
||||
SUBSTITUTIONS = [
|
||||
("an ", ""),
|
||||
("a ", ""),
|
||||
(".$", "$"),
|
||||
("\\$", ""),
|
||||
(r"\ ", ""),
|
||||
(" ", ""),
|
||||
("mbox", "text"),
|
||||
(",\\text{and}", ","),
|
||||
("\\text{and}", ","),
|
||||
("\\text{m}", "\\text{}"),
|
||||
]
|
||||
|
||||
REMOVED_EXPRESSIONS = [
|
||||
"square",
|
||||
"ways",
|
||||
"integers",
|
||||
"dollars",
|
||||
"mph",
|
||||
"inches",
|
||||
"ft",
|
||||
"hours",
|
||||
"km",
|
||||
"units",
|
||||
"\\ldots",
|
||||
"sue",
|
||||
"points",
|
||||
"feet",
|
||||
"minutes",
|
||||
"digits",
|
||||
"cents",
|
||||
"degrees",
|
||||
"cm",
|
||||
"gm",
|
||||
"pounds",
|
||||
"meters",
|
||||
"meals",
|
||||
"edges",
|
||||
"students",
|
||||
"childrentickets",
|
||||
"multiples",
|
||||
"\\text{s}",
|
||||
"\\text{.}",
|
||||
"\\text{\ns}",
|
||||
"\\text{}^2",
|
||||
"\\text{}^3",
|
||||
"\\text{\n}",
|
||||
"\\text{}",
|
||||
r"\mathrm{th}",
|
||||
r"^\circ",
|
||||
r"^{\circ}",
|
||||
r"\;",
|
||||
r",\!",
|
||||
"{,}",
|
||||
'"',
|
||||
"\\dots",
|
||||
]
|
||||
|
||||
|
||||
def try_evaluate_frac(expression: str, fmt: str = "0.2e") -> str:
|
||||
if isinstance(expression, float):
|
||||
return expression
|
||||
new_expression = f"{expression}"
|
||||
regex = re.compile(r"\\frac{([^}]+)}{([^}]+)}")
|
||||
for match in re.finditer(regex, expression):
|
||||
try:
|
||||
value = float(match.group(1)) / float(match.group(2))
|
||||
new_expression = new_expression.replace(
|
||||
match.group(),
|
||||
f"{{value:{fmt}}}".format(value=value),
|
||||
1,
|
||||
)
|
||||
except Exception:
|
||||
continue
|
||||
return new_expression
|
||||
|
||||
|
||||
def try_evaluate_latex(expression: str, fmt: str = ".2e") -> str:
|
||||
try:
|
||||
with time_limit(seconds=5):
|
||||
from sympy.parsing.latex import parse_latex
|
||||
|
||||
value = parse_latex(expression).evalf() # type: ignore
|
||||
return f"{{value:{fmt}}}".format(value=value)
|
||||
except Exception:
|
||||
return expression
|
||||
|
||||
|
||||
def first_answer(text: str, markers: Sequence[str] = ("Q:", "A:")) -> str:
|
||||
for marker in markers:
|
||||
text = text.split(marker)[0]
|
||||
return text
|
||||
|
||||
|
||||
def extract_result_from_boxed(answer: str) -> str:
|
||||
box_start = "\\boxed"
|
||||
# format is `\\boxed <value>$` or `\\boxed{<value>}`, with potential white spaces framing `<value>`
|
||||
start = answer.rfind(box_start)
|
||||
if start < 0:
|
||||
return ""
|
||||
answer = answer[start + len(box_start) :].strip()
|
||||
ends_with_curly = answer.startswith("{")
|
||||
i = 0
|
||||
open_braces = 0
|
||||
while i < len(answer):
|
||||
if answer[i] == "{":
|
||||
open_braces += 1
|
||||
elif answer[i] == "}":
|
||||
open_braces -= 1
|
||||
if open_braces == 0:
|
||||
if ends_with_curly:
|
||||
answer = answer[: i + 1].strip()
|
||||
break
|
||||
elif answer[i] == "$":
|
||||
answer = answer[:i].strip()
|
||||
break
|
||||
i += 1
|
||||
else:
|
||||
return ""
|
||||
# remove extra curly braces
|
||||
while True:
|
||||
if answer.startswith("{") and answer.endswith("}"):
|
||||
answer = answer[1:-1].strip()
|
||||
else:
|
||||
break
|
||||
return answer
|
||||
|
||||
|
||||
# from minerva paper + _normalise_result from xavierm
|
||||
def normalize_final_answer(final_answer: str, regex_pattern: str, match_first: bool = True) -> str:
|
||||
"""Extract and normalize a final answer to a quantitative reasoning question."""
|
||||
match = re.findall(regex_pattern, final_answer)
|
||||
extraction: str
|
||||
if len(match) > 0:
|
||||
if match_first:
|
||||
extraction = match[0]
|
||||
else:
|
||||
extraction = match[-1]
|
||||
else:
|
||||
extraction = extract_result_from_boxed(final_answer)
|
||||
|
||||
if len(extraction) == 0:
|
||||
return final_answer
|
||||
else:
|
||||
final_answer = extraction
|
||||
final_answer = final_answer.split("=")[-1]
|
||||
for before, after in SUBSTITUTIONS:
|
||||
final_answer = final_answer.replace(before, after)
|
||||
for expr in REMOVED_EXPRESSIONS:
|
||||
final_answer = final_answer.replace(expr, "")
|
||||
# Extract answer that is in LaTeX math, is bold,
|
||||
# is surrounded by a box, etc.
|
||||
final_answer = re.sub(r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$", final_answer)
|
||||
final_answer = re.sub(r"(\\text\{)(.*?)(\})", "\\2", final_answer)
|
||||
final_answer = re.sub(r"(\\textbf\{)(.*?)(\})", "\\2", final_answer)
|
||||
final_answer = re.sub(r"(\\overline\{)(.*?)(\})", "\\2", final_answer)
|
||||
final_answer = re.sub(r"(\\boxed\{)(.*)(\})", "\\2", final_answer)
|
||||
# Normalize shorthand TeX:
|
||||
# \fracab -> \frac{a}{b}
|
||||
# \frac{abc}{bef} -> \frac{abc}{bef}
|
||||
# \fracabc -> \frac{a}{b}c
|
||||
# \sqrta -> \sqrt{a}
|
||||
# \sqrtab -> sqrt{a}b
|
||||
final_answer = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", final_answer)
|
||||
final_answer = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", final_answer)
|
||||
final_answer = final_answer.replace("$", "")
|
||||
# Normalize 100,000 -> 100000
|
||||
if final_answer.replace(",", "").isdigit():
|
||||
final_answer = final_answer.replace(",", "")
|
||||
# If the final answer is a single letter in parentheses, remove the parentheses
|
||||
# Example: (a) -> a (but not (ab) -> ab)
|
||||
if re.match(r"\([a-zA-Z]\)", final_answer):
|
||||
final_answer = final_answer[1]
|
||||
return _normalise_result(final_answer)
|
||||
|
||||
|
||||
def _normalise_result(string: str) -> str:
|
||||
# linebreaks
|
||||
string = string.replace("\n", "")
|
||||
|
||||
# remove inverse spaces
|
||||
string = string.replace("\\!", "")
|
||||
|
||||
# replace \\ with \
|
||||
string = string.replace("\\\\", "\\")
|
||||
|
||||
# replace tfrac and dfrac with frac
|
||||
string = string.replace("cfrac", "frac")
|
||||
string = string.replace("tfrac", "frac")
|
||||
string = string.replace("dfrac", "frac")
|
||||
|
||||
# remove \left and \right
|
||||
string = string.replace("\\left", "")
|
||||
string = string.replace("\\le", "")
|
||||
string = string.replace("\\right", "")
|
||||
|
||||
# Remove circ (degrees)
|
||||
string = string.replace("^{\\circ}", "")
|
||||
string = string.replace("^\\circ", "")
|
||||
|
||||
# remove dollar signs
|
||||
string = string.replace("\\$", "")
|
||||
|
||||
# remove units (on the right)
|
||||
string = _remove_right_units(string)
|
||||
|
||||
# remove percentage
|
||||
string = string.replace("\\%", "")
|
||||
string = string.replace(r"\%", "")
|
||||
|
||||
# " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
|
||||
string = string.replace(" .", " 0.")
|
||||
string = string.replace("{.", "{0.")
|
||||
# if empty, return empty string
|
||||
if len(string) == 0:
|
||||
return string
|
||||
if string[0] == ".":
|
||||
string = "0" + string
|
||||
|
||||
# to consider: get rid of e.g. "k = " or "q = " at beginning
|
||||
string = string.split("=")[-1]
|
||||
|
||||
# fix sqrt3 --> sqrt{3}
|
||||
string = _fix_sqrt(string)
|
||||
|
||||
# remove spaces
|
||||
string = string.replace(" ", "")
|
||||
|
||||
# \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
|
||||
string = _fix_fracs(string)
|
||||
|
||||
# manually change 0.5 --> \frac{1}{2}
|
||||
if string == "0.5":
|
||||
string = "\\frac{1}{2}"
|
||||
|
||||
# NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
|
||||
string = _fix_a_slash_b(string)
|
||||
|
||||
return string
|
||||
|
||||
|
||||
def _remove_right_units(string: str) -> str:
|
||||
# "\\text{ " only ever occurs (at least in the val set) when describing units
|
||||
try:
|
||||
if "\\text{ " in string:
|
||||
splits = string.split("\\text{ ")
|
||||
assert len(splits) == 2
|
||||
return splits[0]
|
||||
else:
|
||||
return string
|
||||
except AssertionError:
|
||||
return string
|
||||
|
||||
|
||||
def _fix_sqrt(string: str) -> str:
|
||||
if "\\sqrt" not in string:
|
||||
return string
|
||||
splits = string.split("\\sqrt")
|
||||
new_string = splits[0]
|
||||
for split in splits[1:]:
|
||||
if len(split) == 0:
|
||||
return string
|
||||
if split[0] != "{":
|
||||
a = split[0]
|
||||
new_substr = "\\sqrt{" + a + "}" + split[1:]
|
||||
else:
|
||||
new_substr = "\\sqrt" + split
|
||||
new_string += new_substr
|
||||
return new_string
|
||||
|
||||
|
||||
def _fix_fracs(string: str) -> str:
|
||||
substrs = string.split("\\frac")
|
||||
new_str = substrs[0]
|
||||
if len(substrs) > 1:
|
||||
substrs = substrs[1:]
|
||||
for substr in substrs:
|
||||
new_str += "\\frac"
|
||||
if len(substr) == 0:
|
||||
return string
|
||||
if substr[0] == "{":
|
||||
new_str += substr
|
||||
else:
|
||||
try:
|
||||
assert len(substr) >= 2
|
||||
except AssertionError:
|
||||
return string
|
||||
a = substr[0]
|
||||
b = substr[1]
|
||||
if b != "{":
|
||||
if len(substr) > 2:
|
||||
post_substr = substr[2:]
|
||||
new_str += "{" + a + "}{" + b + "}" + post_substr
|
||||
else:
|
||||
new_str += "{" + a + "}{" + b + "}"
|
||||
else:
|
||||
if len(substr) > 2:
|
||||
post_substr = substr[2:]
|
||||
new_str += "{" + a + "}" + b + post_substr
|
||||
else:
|
||||
new_str += "{" + a + "}" + b
|
||||
string = new_str
|
||||
return string
|
||||
|
||||
|
||||
def _fix_a_slash_b(string: str) -> str:
|
||||
if len(string.split("/")) != 2:
|
||||
return string
|
||||
a = string.split("/")[0]
|
||||
b = string.split("/")[1]
|
||||
try:
|
||||
ia = int(a)
|
||||
ib = int(b)
|
||||
assert string == "{}/{}".format(ia, ib)
|
||||
new_string = "\\frac{" + str(ia) + "}{" + str(ib) + "}"
|
||||
return new_string
|
||||
except (ValueError, AssertionError):
|
||||
return string
|
|
@ -3,11 +3,11 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Dict
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.distribution.datatypes import Api, ProviderSpec
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
|
||||
from .config import BraintrustScoringConfig
|
||||
|
||||
|
@ -18,7 +18,7 @@ class BraintrustProviderDataValidator(BaseModel):
|
|||
|
||||
async def get_provider_impl(
|
||||
config: BraintrustScoringConfig,
|
||||
deps: Dict[Api, ProviderSpec],
|
||||
deps: Dict[Api, Any],
|
||||
):
|
||||
from .braintrust import BraintrustScoringImpl
|
||||
|
||||
|
|
|
@ -3,16 +3,16 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Dict
|
||||
from typing import Any, Dict
|
||||
|
||||
from llama_stack.distribution.datatypes import Api, ProviderSpec
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
|
||||
from .config import LlmAsJudgeScoringConfig
|
||||
|
||||
|
||||
async def get_provider_impl(
|
||||
config: LlmAsJudgeScoringConfig,
|
||||
deps: Dict[Api, ProviderSpec],
|
||||
deps: Dict[Api, Any],
|
||||
):
|
||||
from .scoring import LlmAsJudgeScoringImpl
|
||||
|
||||
|
|
|
@ -73,6 +73,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
|
|||
def __init__(self, config: TelemetryConfig, deps: Dict[str, Any]) -> None:
|
||||
self.config = config
|
||||
self.datasetio_api = deps.get(Api.datasetio)
|
||||
self.meter = None
|
||||
|
||||
resource = Resource.create(
|
||||
{
|
||||
|
@ -171,6 +172,8 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
|
|||
return _GLOBAL_STORAGE["gauges"][name]
|
||||
|
||||
def _log_metric(self, event: MetricEvent) -> None:
|
||||
if self.meter is None:
|
||||
return
|
||||
if isinstance(event.value, int):
|
||||
counter = self._get_or_create_counter(event.metric, event.unit)
|
||||
counter.add(event.value, attributes=event.attributes)
|
||||
|
|
|
@ -4,12 +4,14 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from .config import CodeInterpreterToolConfig
|
||||
|
||||
__all__ = ["CodeInterpreterToolConfig", "CodeInterpreterToolRuntimeImpl"]
|
||||
|
||||
|
||||
async def get_provider_impl(config: CodeInterpreterToolConfig, _deps):
|
||||
async def get_provider_impl(config: CodeInterpreterToolConfig, _deps: Dict[str, Any]):
|
||||
from .code_interpreter import CodeInterpreterToolRuntimeImpl
|
||||
|
||||
impl = CodeInterpreterToolRuntimeImpl(config)
|
||||
|
|
|
@ -4,14 +4,14 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Dict
|
||||
from typing import Any, Dict
|
||||
|
||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
from .config import ChromaVectorIOConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: ChromaVectorIOConfig, deps: Dict[Api, ProviderSpec]):
|
||||
async def get_provider_impl(config: ChromaVectorIOConfig, deps: Dict[Api, Any]):
|
||||
from llama_stack.providers.remote.vector_io.chroma.chroma import (
|
||||
ChromaVectorIOAdapter,
|
||||
)
|
||||
|
|
|
@ -4,14 +4,14 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Dict
|
||||
from typing import Any, Dict
|
||||
|
||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
from .config import FaissVectorIOConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: FaissVectorIOConfig, deps: Dict[Api, ProviderSpec]):
|
||||
async def get_provider_impl(config: FaissVectorIOConfig, deps: Dict[Api, Any]):
|
||||
from .faiss import FaissVectorIOAdapter
|
||||
|
||||
assert isinstance(config, FaissVectorIOConfig), f"Unexpected config type: {type(config)}"
|
||||
|
|
|
@ -4,14 +4,14 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Dict
|
||||
from typing import Any, Dict
|
||||
|
||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
from .config import MilvusVectorIOConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: MilvusVectorIOConfig, deps: Dict[Api, ProviderSpec]):
|
||||
async def get_provider_impl(config: MilvusVectorIOConfig, deps: Dict[Api, Any]):
|
||||
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusVectorIOAdapter
|
||||
|
||||
impl = MilvusVectorIOAdapter(config, deps[Api.inference])
|
||||
|
|
|
@ -4,14 +4,14 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Dict
|
||||
from typing import Any, Dict
|
||||
|
||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
from .config import SQLiteVectorIOConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: SQLiteVectorIOConfig, deps: Dict[Api, ProviderSpec]):
|
||||
async def get_provider_impl(config: SQLiteVectorIOConfig, deps: Dict[Api, Any]):
|
||||
from .sqlite_vec import SQLiteVecVectorIOAdapter
|
||||
|
||||
assert isinstance(config, SQLiteVectorIOConfig), f"Unexpected config type: {type(config)}"
|
||||
|
|
|
@ -34,6 +34,8 @@ def available_providers() -> List[ProviderSpec]:
|
|||
config_class="llama_stack.providers.inline.vector_io.faiss.FaissVectorIOConfig",
|
||||
api_dependencies=[Api.inference],
|
||||
),
|
||||
# NOTE: sqlite-vec cannot be bundled into the container image because it does not have a
|
||||
# source distribution and the wheels are not available for all platforms.
|
||||
InlineProviderSpec(
|
||||
api=Api.vector_io,
|
||||
provider_type="inline::sqlite-vec",
|
||||
|
|
|
@ -24,10 +24,6 @@ MODEL_ENTRIES = [
|
|||
"accounts/fireworks/models/llama-v3p1-405b-instruct",
|
||||
CoreModelId.llama3_1_405b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"accounts/fireworks/models/llama-v3p2-1b-instruct",
|
||||
CoreModelId.llama3_2_1b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"accounts/fireworks/models/llama-v3p2-3b-instruct",
|
||||
CoreModelId.llama3_2_3b_instruct.value,
|
||||
|
|
|
@ -4,12 +4,14 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import AsyncGenerator, List, Optional
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
|
||||
from llama_stack_client import LlamaStackClient
|
||||
from llama_stack_client import AsyncLlamaStackClient
|
||||
|
||||
from llama_stack.apis.common.content_types import InterleavedContent
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
Inference,
|
||||
|
@ -24,6 +26,7 @@ from llama_stack.apis.inference import (
|
|||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.distribution.library_client import convert_pydantic_to_json_value, convert_to_pydantic
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
|
||||
from .config import PassthroughImplConfig
|
||||
|
@ -46,7 +49,7 @@ class PassthroughInferenceAdapter(Inference):
|
|||
async def register_model(self, model: Model) -> Model:
|
||||
return model
|
||||
|
||||
def _get_client(self) -> LlamaStackClient:
|
||||
def _get_client(self) -> AsyncLlamaStackClient:
|
||||
passthrough_url = None
|
||||
passthrough_api_key = None
|
||||
provider_data = None
|
||||
|
@ -71,7 +74,7 @@ class PassthroughInferenceAdapter(Inference):
|
|||
)
|
||||
passthrough_api_key = provider_data.passthrough_api_key
|
||||
|
||||
return LlamaStackClient(
|
||||
return AsyncLlamaStackClient(
|
||||
base_url=passthrough_url,
|
||||
api_key=passthrough_api_key,
|
||||
provider_data=provider_data,
|
||||
|
@ -91,7 +94,7 @@ class PassthroughInferenceAdapter(Inference):
|
|||
client = self._get_client()
|
||||
model = await self.model_store.get_model(model_id)
|
||||
|
||||
params = {
|
||||
request_params = {
|
||||
"model_id": model.provider_resource_id,
|
||||
"content": content,
|
||||
"sampling_params": sampling_params,
|
||||
|
@ -100,10 +103,13 @@ class PassthroughInferenceAdapter(Inference):
|
|||
"logprobs": logprobs,
|
||||
}
|
||||
|
||||
params = {key: value for key, value in params.items() if value is not None}
|
||||
request_params = {key: value for key, value in request_params.items() if value is not None}
|
||||
|
||||
# cast everything to json dict
|
||||
json_params = self.cast_value_to_json_dict(request_params)
|
||||
|
||||
# only pass through the not None params
|
||||
return client.inference.completion(**params)
|
||||
return await client.inference.completion(**json_params)
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
|
@ -120,10 +126,14 @@ class PassthroughInferenceAdapter(Inference):
|
|||
) -> AsyncGenerator:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
client = self._get_client()
|
||||
model = await self.model_store.get_model(model_id)
|
||||
|
||||
params = {
|
||||
# TODO: revisit this remove tool_calls from messages logic
|
||||
for message in messages:
|
||||
if hasattr(message, "tool_calls"):
|
||||
message.tool_calls = None
|
||||
|
||||
request_params = {
|
||||
"model_id": model.provider_resource_id,
|
||||
"messages": messages,
|
||||
"sampling_params": sampling_params,
|
||||
|
@ -135,10 +145,39 @@ class PassthroughInferenceAdapter(Inference):
|
|||
"logprobs": logprobs,
|
||||
}
|
||||
|
||||
params = {key: value for key, value in params.items() if value is not None}
|
||||
|
||||
# only pass through the not None params
|
||||
return client.inference.chat_completion(**params)
|
||||
request_params = {key: value for key, value in request_params.items() if value is not None}
|
||||
|
||||
# cast everything to json dict
|
||||
json_params = self.cast_value_to_json_dict(request_params)
|
||||
|
||||
if stream:
|
||||
return self._stream_chat_completion(json_params)
|
||||
else:
|
||||
return await self._nonstream_chat_completion(json_params)
|
||||
|
||||
async def _nonstream_chat_completion(self, json_params: Dict[str, Any]) -> ChatCompletionResponse:
|
||||
client = self._get_client()
|
||||
response = await client.inference.chat_completion(**json_params)
|
||||
|
||||
response = response.to_dict()
|
||||
|
||||
# temporary hack to remove the metrics from the response
|
||||
response["metrics"] = []
|
||||
|
||||
return convert_to_pydantic(ChatCompletionResponse, response)
|
||||
|
||||
async def _stream_chat_completion(self, json_params: Dict[str, Any]) -> AsyncGenerator:
|
||||
client = self._get_client()
|
||||
stream_response = await client.inference.chat_completion(**json_params)
|
||||
|
||||
async for chunk in stream_response:
|
||||
chunk = chunk.to_dict()
|
||||
|
||||
# temporary hack to remove the metrics from the response
|
||||
chunk["metrics"] = []
|
||||
chunk = convert_to_pydantic(ChatCompletionResponseStreamChunk, chunk)
|
||||
yield chunk
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
|
@ -151,10 +190,29 @@ class PassthroughInferenceAdapter(Inference):
|
|||
client = self._get_client()
|
||||
model = await self.model_store.get_model(model_id)
|
||||
|
||||
return client.inference.embeddings(
|
||||
return await client.inference.embeddings(
|
||||
model_id=model.provider_resource_id,
|
||||
contents=contents,
|
||||
text_truncation=text_truncation,
|
||||
output_dimension=output_dimension,
|
||||
task_type=task_type,
|
||||
)
|
||||
|
||||
def cast_value_to_json_dict(self, request_params: Dict[str, Any]) -> Dict[str, Any]:
|
||||
json_params = {}
|
||||
for key, value in request_params.items():
|
||||
json_input = convert_pydantic_to_json_value(value)
|
||||
if isinstance(json_input, dict):
|
||||
json_input = {k: v for k, v in json_input.items() if v is not None}
|
||||
elif isinstance(json_input, list):
|
||||
json_input = [x for x in json_input if x is not None]
|
||||
new_input = []
|
||||
for x in json_input:
|
||||
if isinstance(x, dict):
|
||||
x = {k: v for k, v in x.items() if v is not None}
|
||||
new_input.append(x)
|
||||
json_input = new_input
|
||||
|
||||
json_params[key] = json_input
|
||||
|
||||
return json_params
|
||||
|
|
|
@ -26,5 +26,5 @@ class TogetherImplConfig(BaseModel):
|
|||
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
|
||||
return {
|
||||
"url": "https://api.together.xyz/v1",
|
||||
"api_key": "${env.TOGETHER_API_KEY}",
|
||||
"api_key": "${env.TOGETHER_API_KEY:}",
|
||||
}
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
from typing import AsyncGenerator, List, Optional, Union
|
||||
|
||||
from together import Together
|
||||
from together import AsyncTogether
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
|
@ -59,12 +59,15 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
|
|||
def __init__(self, config: TogetherImplConfig) -> None:
|
||||
ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
|
||||
self.config = config
|
||||
self._client = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
if self._client:
|
||||
await self._client.close()
|
||||
self._client = None
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
|
@ -91,35 +94,32 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
|
|||
else:
|
||||
return await self._nonstream_completion(request)
|
||||
|
||||
def _get_client(self) -> Together:
|
||||
together_api_key = None
|
||||
config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
|
||||
if config_api_key:
|
||||
together_api_key = config_api_key
|
||||
else:
|
||||
provider_data = self.get_request_provider_data()
|
||||
if provider_data is None or not provider_data.together_api_key:
|
||||
raise ValueError(
|
||||
'Pass Together API Key in the header X-LlamaStack-Provider-Data as { "together_api_key": <your api key>}'
|
||||
)
|
||||
together_api_key = provider_data.together_api_key
|
||||
return Together(api_key=together_api_key)
|
||||
def _get_client(self) -> AsyncTogether:
|
||||
if not self._client:
|
||||
together_api_key = None
|
||||
config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
|
||||
if config_api_key:
|
||||
together_api_key = config_api_key
|
||||
else:
|
||||
provider_data = self.get_request_provider_data()
|
||||
if provider_data is None or not provider_data.together_api_key:
|
||||
raise ValueError(
|
||||
'Pass Together API Key in the header X-LlamaStack-Provider-Data as { "together_api_key": <your api key>}'
|
||||
)
|
||||
together_api_key = provider_data.together_api_key
|
||||
self._client = AsyncTogether(api_key=together_api_key)
|
||||
return self._client
|
||||
|
||||
async def _nonstream_completion(self, request: CompletionRequest) -> ChatCompletionResponse:
|
||||
params = await self._get_params(request)
|
||||
r = self._get_client().completions.create(**params)
|
||||
client = self._get_client()
|
||||
r = await client.completions.create(**params)
|
||||
return process_completion_response(r)
|
||||
|
||||
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
|
||||
params = await self._get_params(request)
|
||||
|
||||
# if we shift to TogetherAsyncClient, we won't need this wrapper
|
||||
async def _to_async_generator():
|
||||
s = self._get_client().completions.create(**params)
|
||||
for chunk in s:
|
||||
yield chunk
|
||||
|
||||
stream = _to_async_generator()
|
||||
client = await self._get_client()
|
||||
stream = await client.completions.create(**params)
|
||||
async for chunk in process_completion_stream_response(stream):
|
||||
yield chunk
|
||||
|
||||
|
@ -184,25 +184,21 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
|
|||
|
||||
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
|
||||
params = await self._get_params(request)
|
||||
client = self._get_client()
|
||||
if "messages" in params:
|
||||
r = self._get_client().chat.completions.create(**params)
|
||||
r = await client.chat.completions.create(**params)
|
||||
else:
|
||||
r = self._get_client().completions.create(**params)
|
||||
r = await client.completions.create(**params)
|
||||
return process_chat_completion_response(r, request)
|
||||
|
||||
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
|
||||
params = await self._get_params(request)
|
||||
client = self._get_client()
|
||||
if "messages" in params:
|
||||
stream = await client.chat.completions.create(**params)
|
||||
else:
|
||||
stream = await client.completions.create(**params)
|
||||
|
||||
# if we shift to TogetherAsyncClient, we won't need this wrapper
|
||||
async def _to_async_generator():
|
||||
if "messages" in params:
|
||||
s = self._get_client().chat.completions.create(**params)
|
||||
else:
|
||||
s = self._get_client().completions.create(**params)
|
||||
for chunk in s:
|
||||
yield chunk
|
||||
|
||||
stream = _to_async_generator()
|
||||
async for chunk in process_chat_completion_stream_response(stream, request):
|
||||
yield chunk
|
||||
|
||||
|
@ -240,7 +236,8 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
|
|||
assert all(not content_has_media(content) for content in contents), (
|
||||
"Together does not support media for embeddings"
|
||||
)
|
||||
r = self._get_client().embeddings.create(
|
||||
client = self._get_client()
|
||||
r = await client.embeddings.create(
|
||||
model=model.provider_resource_id,
|
||||
input=[interleaved_content_as_str(content) for content in contents],
|
||||
)
|
||||
|
|
|
@ -615,6 +615,14 @@ def convert_tool_call(
|
|||
return valid_tool_call
|
||||
|
||||
|
||||
PYTHON_TYPE_TO_LITELLM_TYPE = {
|
||||
"int": "integer",
|
||||
"float": "number",
|
||||
"bool": "boolean",
|
||||
"str": "string",
|
||||
}
|
||||
|
||||
|
||||
def convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict:
|
||||
"""
|
||||
Convert a ToolDefinition to an OpenAI API-compatible dictionary.
|
||||
|
@ -675,7 +683,7 @@ def convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict:
|
|||
properties = parameters["properties"]
|
||||
required = []
|
||||
for param_name, param in tool.parameters.items():
|
||||
properties[param_name] = {"type": param.param_type}
|
||||
properties[param_name] = {"type": PYTHON_TYPE_TO_LITELLM_TYPE.get(param.param_type, param.param_type)}
|
||||
if param.description:
|
||||
properties[param_name].update(description=param.description)
|
||||
if param.default:
|
||||
|
|
26
llama_stack/providers/utils/scoring/basic_scoring_utils.py
Normal file
26
llama_stack/providers/utils/scoring/basic_scoring_utils.py
Normal file
|
@ -0,0 +1,26 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import contextlib
|
||||
import signal
|
||||
from types import FrameType
|
||||
from typing import Iterator, Optional
|
||||
|
||||
|
||||
class TimeoutError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def time_limit(seconds: float) -> Iterator[None]:
|
||||
def signal_handler(signum: int, frame: Optional[FrameType]) -> None:
|
||||
raise TimeoutError("Timed out!")
|
||||
|
||||
signal.setitimer(signal.ITIMER_REAL, seconds)
|
||||
signal.signal(signal.SIGALRM, signal_handler)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
signal.setitimer(signal.ITIMER_REAL, 0)
|
|
@ -6,6 +6,7 @@
|
|||
|
||||
import asyncio
|
||||
import base64
|
||||
import contextvars
|
||||
import logging
|
||||
import queue
|
||||
import threading
|
||||
|
@ -24,9 +25,10 @@ from llama_stack.apis.telemetry import (
|
|||
Telemetry,
|
||||
UnstructuredLogEvent,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import serialize_value
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
logger = get_logger(__name__, category="core")
|
||||
|
||||
|
||||
def generate_short_uuid(len: int = 8):
|
||||
|
@ -36,7 +38,7 @@ def generate_short_uuid(len: int = 8):
|
|||
return encoded.rstrip(b"=").decode("ascii")[:len]
|
||||
|
||||
|
||||
CURRENT_TRACE_CONTEXT = None
|
||||
CURRENT_TRACE_CONTEXT = contextvars.ContextVar("trace_context", default=None)
|
||||
BACKGROUND_LOGGER = None
|
||||
|
||||
|
||||
|
@ -51,7 +53,7 @@ class BackgroundLogger:
|
|||
try:
|
||||
self.log_queue.put_nowait(event)
|
||||
except queue.Full:
|
||||
log.error("Log queue is full, dropping event")
|
||||
logger.error("Log queue is full, dropping event")
|
||||
|
||||
def _process_logs(self):
|
||||
while True:
|
||||
|
@ -129,35 +131,36 @@ def setup_logger(api: Telemetry, level: int = logging.INFO):
|
|||
|
||||
if BACKGROUND_LOGGER is None:
|
||||
BACKGROUND_LOGGER = BackgroundLogger(api)
|
||||
logger = logging.getLogger()
|
||||
logger.setLevel(level)
|
||||
logger.addHandler(TelemetryHandler())
|
||||
root_logger = logging.getLogger()
|
||||
root_logger.setLevel(level)
|
||||
root_logger.addHandler(TelemetryHandler())
|
||||
|
||||
|
||||
async def start_trace(name: str, attributes: Dict[str, Any] = None) -> TraceContext:
|
||||
global CURRENT_TRACE_CONTEXT, BACKGROUND_LOGGER
|
||||
|
||||
if BACKGROUND_LOGGER is None:
|
||||
log.info("No Telemetry implementation set. Skipping trace initialization...")
|
||||
logger.debug("No Telemetry implementation set. Skipping trace initialization...")
|
||||
return
|
||||
|
||||
trace_id = generate_short_uuid(16)
|
||||
context = TraceContext(BACKGROUND_LOGGER, trace_id)
|
||||
context.push_span(name, {"__root__": True, **(attributes or {})})
|
||||
|
||||
CURRENT_TRACE_CONTEXT = context
|
||||
CURRENT_TRACE_CONTEXT.set(context)
|
||||
return context
|
||||
|
||||
|
||||
async def end_trace(status: SpanStatus = SpanStatus.OK):
|
||||
global CURRENT_TRACE_CONTEXT
|
||||
|
||||
context = CURRENT_TRACE_CONTEXT
|
||||
context = CURRENT_TRACE_CONTEXT.get()
|
||||
if context is None:
|
||||
logger.debug("No trace context to end")
|
||||
return
|
||||
|
||||
context.pop_span(status)
|
||||
CURRENT_TRACE_CONTEXT = None
|
||||
CURRENT_TRACE_CONTEXT.set(None)
|
||||
|
||||
|
||||
def severity(levelname: str) -> LogSeverity:
|
||||
|
@ -188,7 +191,7 @@ class TelemetryHandler(logging.Handler):
|
|||
if BACKGROUND_LOGGER is None:
|
||||
raise RuntimeError("Telemetry API not initialized")
|
||||
|
||||
context = CURRENT_TRACE_CONTEXT
|
||||
context = CURRENT_TRACE_CONTEXT.get()
|
||||
if context is None:
|
||||
return
|
||||
|
||||
|
@ -218,16 +221,22 @@ class SpanContextManager:
|
|||
|
||||
def __enter__(self):
|
||||
global CURRENT_TRACE_CONTEXT
|
||||
context = CURRENT_TRACE_CONTEXT
|
||||
if context:
|
||||
self.span = context.push_span(self.name, self.attributes)
|
||||
context = CURRENT_TRACE_CONTEXT.get()
|
||||
if not context:
|
||||
logger.debug("No trace context to push span")
|
||||
return self
|
||||
|
||||
self.span = context.push_span(self.name, self.attributes)
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
global CURRENT_TRACE_CONTEXT
|
||||
context = CURRENT_TRACE_CONTEXT
|
||||
if context:
|
||||
context.pop_span()
|
||||
context = CURRENT_TRACE_CONTEXT.get()
|
||||
if not context:
|
||||
logger.debug("No trace context to pop span")
|
||||
return
|
||||
|
||||
context.pop_span()
|
||||
|
||||
def set_attribute(self, key: str, value: Any):
|
||||
if self.span:
|
||||
|
@ -237,16 +246,22 @@ class SpanContextManager:
|
|||
|
||||
async def __aenter__(self):
|
||||
global CURRENT_TRACE_CONTEXT
|
||||
context = CURRENT_TRACE_CONTEXT
|
||||
if context:
|
||||
self.span = context.push_span(self.name, self.attributes)
|
||||
context = CURRENT_TRACE_CONTEXT.get()
|
||||
if not context:
|
||||
logger.debug("No trace context to push span")
|
||||
return self
|
||||
|
||||
self.span = context.push_span(self.name, self.attributes)
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_value, traceback):
|
||||
global CURRENT_TRACE_CONTEXT
|
||||
context = CURRENT_TRACE_CONTEXT
|
||||
if context:
|
||||
context.pop_span()
|
||||
context = CURRENT_TRACE_CONTEXT.get()
|
||||
if not context:
|
||||
logger.debug("No trace context to pop span")
|
||||
return
|
||||
|
||||
context.pop_span()
|
||||
|
||||
def __call__(self, func: Callable):
|
||||
@wraps(func)
|
||||
|
@ -275,7 +290,11 @@ def span(name: str, attributes: Dict[str, Any] = None):
|
|||
|
||||
def get_current_span() -> Optional[Span]:
|
||||
global CURRENT_TRACE_CONTEXT
|
||||
context = CURRENT_TRACE_CONTEXT
|
||||
if CURRENT_TRACE_CONTEXT is None:
|
||||
logger.debug("No trace context to get current span")
|
||||
return None
|
||||
|
||||
context = CURRENT_TRACE_CONTEXT.get()
|
||||
if context:
|
||||
return context.get_current_span()
|
||||
return None
|
||||
|
|
|
@ -120,16 +120,6 @@ models:
|
|||
provider_id: fireworks
|
||||
provider_model_id: accounts/fireworks/models/llama-v3p1-405b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: accounts/fireworks/models/llama-v3p2-1b-instruct
|
||||
provider_id: fireworks
|
||||
provider_model_id: accounts/fireworks/models/llama-v3p2-1b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3.2-1B-Instruct
|
||||
provider_id: fireworks
|
||||
provider_model_id: accounts/fireworks/models/llama-v3p2-1b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: accounts/fireworks/models/llama-v3p2-3b-instruct
|
||||
provider_id: fireworks
|
||||
|
|
|
@ -178,16 +178,6 @@ models:
|
|||
provider_id: fireworks
|
||||
provider_model_id: accounts/fireworks/models/llama-v3p1-405b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: accounts/fireworks/models/llama-v3p2-1b-instruct
|
||||
provider_id: fireworks
|
||||
provider_model_id: accounts/fireworks/models/llama-v3p2-1b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3.2-1B-Instruct
|
||||
provider_id: fireworks
|
||||
provider_model_id: accounts/fireworks/models/llama-v3p2-1b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: accounts/fireworks/models/llama-v3p2-3b-instruct
|
||||
provider_id: fireworks
|
||||
|
|
|
@ -132,16 +132,6 @@ models:
|
|||
provider_id: fireworks
|
||||
provider_model_id: accounts/fireworks/models/llama-v3p1-405b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: accounts/fireworks/models/llama-v3p2-1b-instruct
|
||||
provider_id: fireworks
|
||||
provider_model_id: accounts/fireworks/models/llama-v3p2-1b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3.2-1B-Instruct
|
||||
provider_id: fireworks
|
||||
provider_model_id: accounts/fireworks/models/llama-v3p2-1b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: accounts/fireworks/models/llama-v3p2-3b-instruct
|
||||
provider_id: fireworks
|
||||
|
|
|
@ -126,16 +126,6 @@ models:
|
|||
provider_id: fireworks
|
||||
provider_model_id: accounts/fireworks/models/llama-v3p1-405b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: accounts/fireworks/models/llama-v3p2-1b-instruct
|
||||
provider_id: fireworks
|
||||
provider_model_id: accounts/fireworks/models/llama-v3p2-1b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3.2-1B-Instruct
|
||||
provider_id: fireworks
|
||||
provider_model_id: accounts/fireworks/models/llama-v3p2-1b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: accounts/fireworks/models/llama-v3p2-3b-instruct
|
||||
provider_id: fireworks
|
||||
|
|
|
@ -5,7 +5,7 @@ distribution_spec:
|
|||
inference:
|
||||
- remote::ollama
|
||||
vector_io:
|
||||
- inline::sqlite-vec
|
||||
- inline::faiss
|
||||
- remote::chromadb
|
||||
- remote::pgvector
|
||||
safety:
|
||||
|
|
|
@ -119,7 +119,7 @@ llama stack run ./run-with-safety.yaml \
|
|||
### (Optional) Update Model Serving Configuration
|
||||
|
||||
```{note}
|
||||
Please check the [model_entries](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/inference/ollama/ollama.py#L45) for the supported Ollama models.
|
||||
Please check the [model_entries](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/inference/ollama/models.py) for the supported Ollama models.
|
||||
```
|
||||
|
||||
To serve a new model with `ollama`
|
||||
|
|
|
@ -13,7 +13,7 @@ from llama_stack.distribution.datatypes import (
|
|||
ShieldInput,
|
||||
ToolGroupInput,
|
||||
)
|
||||
from llama_stack.providers.inline.vector_io.sqlite_vec.config import SQLiteVectorIOConfig
|
||||
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
|
||||
from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
|
||||
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
|
||||
|
||||
|
@ -21,7 +21,7 @@ from llama_stack.templates.template import DistributionTemplate, RunConfigSettin
|
|||
def get_distribution_template() -> DistributionTemplate:
|
||||
providers = {
|
||||
"inference": ["remote::ollama"],
|
||||
"vector_io": ["inline::sqlite-vec", "remote::chromadb", "remote::pgvector"],
|
||||
"vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
|
||||
"safety": ["inline::llama-guard"],
|
||||
"agents": ["inline::meta-reference"],
|
||||
"telemetry": ["inline::meta-reference"],
|
||||
|
@ -43,10 +43,10 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
provider_type="remote::ollama",
|
||||
config=OllamaImplConfig.sample_run_config(),
|
||||
)
|
||||
vector_io_provider_sqlite = Provider(
|
||||
provider_id="sqlite-vec",
|
||||
provider_type="inline::sqlite-vec",
|
||||
config=SQLiteVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
||||
vector_io_provider_faiss = Provider(
|
||||
provider_id="faiss",
|
||||
provider_type="inline::faiss",
|
||||
config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
||||
)
|
||||
|
||||
inference_model = ModelInput(
|
||||
|
@ -96,7 +96,7 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
"run.yaml": RunConfigSettings(
|
||||
provider_overrides={
|
||||
"inference": [inference_provider],
|
||||
"vector_io": [vector_io_provider_sqlite],
|
||||
"vector_io": [vector_io_provider_faiss],
|
||||
},
|
||||
default_models=[inference_model, embedding_model],
|
||||
default_tool_groups=default_tool_groups,
|
||||
|
@ -104,7 +104,7 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
"run-with-safety.yaml": RunConfigSettings(
|
||||
provider_overrides={
|
||||
"inference": [inference_provider],
|
||||
"vector_io": [vector_io_provider_sqlite],
|
||||
"vector_io": [vector_io_provider_faiss],
|
||||
"safety": [
|
||||
Provider(
|
||||
provider_id="llama-guard",
|
||||
|
|
|
@ -17,10 +17,13 @@ providers:
|
|||
config:
|
||||
url: ${env.OLLAMA_URL:http://localhost:11434}
|
||||
vector_io:
|
||||
- provider_id: sqlite-vec
|
||||
provider_type: inline::sqlite-vec
|
||||
- provider_id: faiss
|
||||
provider_type: inline::faiss
|
||||
config:
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/sqlite_vec.db
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/faiss_store.db
|
||||
safety:
|
||||
- provider_id: llama-guard
|
||||
provider_type: inline::llama-guard
|
||||
|
|
|
@ -17,10 +17,13 @@ providers:
|
|||
config:
|
||||
url: ${env.OLLAMA_URL:http://localhost:11434}
|
||||
vector_io:
|
||||
- provider_id: sqlite-vec
|
||||
provider_type: inline::sqlite-vec
|
||||
- provider_id: faiss
|
||||
provider_type: inline::faiss
|
||||
config:
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/sqlite_vec.db
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/faiss_store.db
|
||||
safety:
|
||||
- provider_id: llama-guard
|
||||
provider_type: inline::llama-guard
|
||||
|
|
7
llama_stack/templates/open-benchmark/__init__.py
Normal file
7
llama_stack/templates/open-benchmark/__init__.py
Normal file
|
@ -0,0 +1,7 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .open_benchmark import get_distribution_template # noqa: F401
|
300
llama_stack/templates/open-benchmark/open_benchmark.py
Normal file
300
llama_stack/templates/open-benchmark/open_benchmark.py
Normal file
|
@ -0,0 +1,300 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.models.models import ModelType
|
||||
from llama_stack.distribution.datatypes import (
|
||||
BenchmarkInput,
|
||||
DatasetInput,
|
||||
ModelInput,
|
||||
Provider,
|
||||
ShieldInput,
|
||||
ToolGroupInput,
|
||||
)
|
||||
from llama_stack.providers.inline.vector_io.sqlite_vec.config import (
|
||||
SQLiteVectorIOConfig,
|
||||
)
|
||||
from llama_stack.providers.remote.inference.anthropic.config import AnthropicConfig
|
||||
from llama_stack.providers.remote.inference.gemini.config import GeminiConfig
|
||||
from llama_stack.providers.remote.inference.groq.config import GroqConfig
|
||||
from llama_stack.providers.remote.inference.openai.config import OpenAIConfig
|
||||
from llama_stack.providers.remote.inference.together.config import TogetherImplConfig
|
||||
from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOConfig
|
||||
from llama_stack.providers.remote.vector_io.pgvector.config import (
|
||||
PGVectorVectorIOConfig,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.model_registry import ProviderModelEntry
|
||||
from llama_stack.templates.template import (
|
||||
DistributionTemplate,
|
||||
RunConfigSettings,
|
||||
get_model_registry,
|
||||
)
|
||||
|
||||
|
||||
def get_inference_providers() -> Tuple[List[Provider], Dict[str, List[ProviderModelEntry]]]:
|
||||
# in this template, we allow each API key to be optional
|
||||
providers = [
|
||||
(
|
||||
"openai",
|
||||
[
|
||||
ProviderModelEntry(
|
||||
provider_model_id="openai/gpt-4o",
|
||||
model_type=ModelType.llm,
|
||||
)
|
||||
],
|
||||
OpenAIConfig.sample_run_config(api_key="${env.OPENAI_API_KEY:}"),
|
||||
),
|
||||
(
|
||||
"anthropic",
|
||||
[
|
||||
ProviderModelEntry(
|
||||
provider_model_id="anthropic/claude-3-5-sonnet-latest",
|
||||
model_type=ModelType.llm,
|
||||
)
|
||||
],
|
||||
AnthropicConfig.sample_run_config(api_key="${env.ANTHROPIC_API_KEY:}"),
|
||||
),
|
||||
(
|
||||
"gemini",
|
||||
[
|
||||
ProviderModelEntry(
|
||||
provider_model_id="gemini/gemini-1.5-flash",
|
||||
model_type=ModelType.llm,
|
||||
)
|
||||
],
|
||||
GeminiConfig.sample_run_config(api_key="${env.GEMINI_API_KEY:}"),
|
||||
),
|
||||
(
|
||||
"groq",
|
||||
[],
|
||||
GroqConfig.sample_run_config(api_key="${env.GROQ_API_KEY:}"),
|
||||
),
|
||||
(
|
||||
"together",
|
||||
[],
|
||||
TogetherImplConfig.sample_run_config(api_key="${env.TOGETHER_API_KEY:}"),
|
||||
),
|
||||
]
|
||||
inference_providers = []
|
||||
available_models = {}
|
||||
for provider_id, model_entries, config in providers:
|
||||
inference_providers.append(
|
||||
Provider(
|
||||
provider_id=provider_id,
|
||||
provider_type=f"remote::{provider_id}",
|
||||
config=config,
|
||||
)
|
||||
)
|
||||
available_models[provider_id] = model_entries
|
||||
return inference_providers, available_models
|
||||
|
||||
|
||||
def get_distribution_template() -> DistributionTemplate:
|
||||
inference_providers, available_models = get_inference_providers()
|
||||
providers = {
|
||||
"inference": [p.provider_type for p in inference_providers],
|
||||
"vector_io": ["inline::sqlite-vec", "remote::chromadb", "remote::pgvector"],
|
||||
"safety": ["inline::llama-guard"],
|
||||
"agents": ["inline::meta-reference"],
|
||||
"telemetry": ["inline::meta-reference"],
|
||||
"eval": ["inline::meta-reference"],
|
||||
"datasetio": ["remote::huggingface", "inline::localfs"],
|
||||
"scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"],
|
||||
"tool_runtime": [
|
||||
"remote::brave-search",
|
||||
"remote::tavily-search",
|
||||
"inline::code-interpreter",
|
||||
"inline::rag-runtime",
|
||||
"remote::model-context-protocol",
|
||||
],
|
||||
}
|
||||
name = "open-benchmark"
|
||||
|
||||
vector_io_providers = [
|
||||
Provider(
|
||||
provider_id="sqlite-vec",
|
||||
provider_type="inline::sqlite-vec",
|
||||
config=SQLiteVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
||||
),
|
||||
Provider(
|
||||
provider_id="${env.ENABLE_CHROMADB+chromadb}",
|
||||
provider_type="remote::chromadb",
|
||||
config=ChromaVectorIOConfig.sample_run_config(url="${env.CHROMADB_URL:}"),
|
||||
),
|
||||
Provider(
|
||||
provider_id="${env.ENABLE_PGVECTOR+pgvector}",
|
||||
provider_type="remote::pgvector",
|
||||
config=PGVectorVectorIOConfig.sample_run_config(
|
||||
db="${env.PGVECTOR_DB:}",
|
||||
user="${env.PGVECTOR_USER:}",
|
||||
password="${env.PGVECTOR_PASSWORD:}",
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
default_tool_groups = [
|
||||
ToolGroupInput(
|
||||
toolgroup_id="builtin::websearch",
|
||||
provider_id="tavily-search",
|
||||
),
|
||||
ToolGroupInput(
|
||||
toolgroup_id="builtin::rag",
|
||||
provider_id="rag-runtime",
|
||||
),
|
||||
ToolGroupInput(
|
||||
toolgroup_id="builtin::code_interpreter",
|
||||
provider_id="code-interpreter",
|
||||
),
|
||||
]
|
||||
|
||||
default_models = get_model_registry(available_models) + [
|
||||
ModelInput(
|
||||
model_id="meta-llama/Llama-3.3-70B-Instruct",
|
||||
provider_id="groq",
|
||||
provider_model_id="groq/llama-3.3-70b-versatile",
|
||||
model_type=ModelType.llm,
|
||||
),
|
||||
ModelInput(
|
||||
model_id="meta-llama/Llama-3.1-405B-Instruct",
|
||||
provider_id="together",
|
||||
provider_model_id="meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
|
||||
model_type=ModelType.llm,
|
||||
),
|
||||
]
|
||||
|
||||
default_datasets = [
|
||||
DatasetInput(
|
||||
dataset_id="simpleqa",
|
||||
provider_id="huggingface",
|
||||
url=URL(uri="https://huggingface.co/datasets/llamastack/simpleqa"),
|
||||
metadata={
|
||||
"path": "llamastack/simpleqa",
|
||||
"split": "train",
|
||||
},
|
||||
dataset_schema={
|
||||
"input_query": {"type": "string"},
|
||||
"expected_answer": {"type": "string"},
|
||||
"chat_completion_input": {"type": "string"},
|
||||
},
|
||||
),
|
||||
DatasetInput(
|
||||
dataset_id="mmlu_cot",
|
||||
provider_id="huggingface",
|
||||
url=URL(uri="https://huggingface.co/datasets/llamastack/mmlu_cot"),
|
||||
metadata={
|
||||
"path": "llamastack/mmlu_cot",
|
||||
"name": "all",
|
||||
"split": "test",
|
||||
},
|
||||
dataset_schema={
|
||||
"input_query": {"type": "string"},
|
||||
"expected_answer": {"type": "string"},
|
||||
"chat_completion_input": {"type": "string"},
|
||||
},
|
||||
),
|
||||
DatasetInput(
|
||||
dataset_id="gpqa_cot",
|
||||
provider_id="huggingface",
|
||||
url=URL(uri="https://huggingface.co/datasets/llamastack/gpqa_0shot_cot"),
|
||||
metadata={
|
||||
"path": "llamastack/gpqa_0shot_cot",
|
||||
"name": "gpqa_main",
|
||||
"split": "train",
|
||||
},
|
||||
dataset_schema={
|
||||
"input_query": {"type": "string"},
|
||||
"expected_answer": {"type": "string"},
|
||||
"chat_completion_input": {"type": "string"},
|
||||
},
|
||||
),
|
||||
DatasetInput(
|
||||
dataset_id="math_500",
|
||||
provider_id="huggingface",
|
||||
url=URL(uri="https://huggingface.co/datasets/llamastack/math_500"),
|
||||
metadata={
|
||||
"path": "llamastack/math_500",
|
||||
"split": "test",
|
||||
},
|
||||
dataset_schema={
|
||||
"input_query": {"type": "string"},
|
||||
"expected_answer": {"type": "string"},
|
||||
"chat_completion_input": {"type": "string"},
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
default_benchmarks = [
|
||||
BenchmarkInput(
|
||||
benchmark_id="meta-reference-simpleqa",
|
||||
dataset_id="simpleqa",
|
||||
scoring_functions=["llm-as-judge::405b-simpleqa"],
|
||||
),
|
||||
BenchmarkInput(
|
||||
benchmark_id="meta-reference-mmlu-cot",
|
||||
dataset_id="mmlu_cot",
|
||||
scoring_functions=["basic::regex_parser_multiple_choice_answer"],
|
||||
),
|
||||
BenchmarkInput(
|
||||
benchmark_id="meta-reference-gpqa-cot",
|
||||
dataset_id="gpqa_cot",
|
||||
scoring_functions=["basic::regex_parser_multiple_choice_answer"],
|
||||
),
|
||||
BenchmarkInput(
|
||||
benchmark_id="meta-reference-math-500",
|
||||
dataset_id="math_500",
|
||||
scoring_functions=["basic::regex_parser_math_response"],
|
||||
),
|
||||
]
|
||||
return DistributionTemplate(
|
||||
name=name,
|
||||
distro_type="self_hosted",
|
||||
description="Distribution for running open benchmarks",
|
||||
container_image=None,
|
||||
template_path=None,
|
||||
providers=providers,
|
||||
available_models_by_provider=available_models,
|
||||
run_configs={
|
||||
"run.yaml": RunConfigSettings(
|
||||
provider_overrides={
|
||||
"inference": inference_providers,
|
||||
"vector_io": vector_io_providers,
|
||||
},
|
||||
default_models=default_models,
|
||||
default_tool_groups=default_tool_groups,
|
||||
default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")],
|
||||
default_datasets=default_datasets,
|
||||
default_benchmarks=default_benchmarks,
|
||||
),
|
||||
},
|
||||
run_config_env_vars={
|
||||
"LLAMA_STACK_PORT": (
|
||||
"5001",
|
||||
"Port for the Llama Stack distribution server",
|
||||
),
|
||||
"TOGETHER_API_KEY": (
|
||||
"",
|
||||
"Together API Key",
|
||||
),
|
||||
"OPENAI_API_KEY": (
|
||||
"",
|
||||
"OpenAI API Key",
|
||||
),
|
||||
"GEMINI_API_KEY": (
|
||||
"",
|
||||
"Gemini API Key",
|
||||
),
|
||||
"ANTHROPIC_API_KEY": (
|
||||
"",
|
||||
"Anthropic API Key",
|
||||
),
|
||||
"GROQ_API_KEY": (
|
||||
"",
|
||||
"Groq API Key",
|
||||
),
|
||||
},
|
||||
)
|
|
@ -33,12 +33,12 @@ providers:
|
|||
provider_type: remote::together
|
||||
config:
|
||||
url: https://api.together.xyz/v1
|
||||
api_key: ${env.TOGETHER_API_KEY}
|
||||
api_key: ${env.TOGETHER_API_KEY:}
|
||||
vector_io:
|
||||
- provider_id: sqlite-vec
|
||||
provider_type: inline::sqlite-vec
|
||||
config:
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dev}/sqlite_vec.db
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/open-benchmark}/sqlite_vec.db
|
||||
- provider_id: ${env.ENABLE_CHROMADB+chromadb}
|
||||
provider_type: remote::chromadb
|
||||
config:
|
||||
|
@ -62,14 +62,14 @@ providers:
|
|||
persistence_store:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dev}/agents_store.db
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/open-benchmark}/agents_store.db
|
||||
telemetry:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/dev/trace_store.db}
|
||||
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/open-benchmark/trace_store.db}
|
||||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
@ -114,18 +114,13 @@ providers:
|
|||
config: {}
|
||||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dev}/registry.db
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/open-benchmark}/registry.db
|
||||
models:
|
||||
- metadata: {}
|
||||
model_id: openai/gpt-4o
|
||||
provider_id: openai
|
||||
provider_model_id: openai/gpt-4o
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3.1-405B-Instruct
|
||||
provider_id: together
|
||||
provider_model_id: meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: anthropic/claude-3-5-sonnet-latest
|
||||
provider_id: anthropic
|
||||
|
@ -141,66 +136,95 @@ models:
|
|||
provider_id: groq
|
||||
provider_model_id: groq/llama-3.3-70b-versatile
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3.1-405B-Instruct
|
||||
provider_id: together
|
||||
provider_model_id: meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo
|
||||
model_type: llm
|
||||
shields:
|
||||
- shield_id: meta-llama/Llama-Guard-3-8B
|
||||
vector_dbs: []
|
||||
datasets:
|
||||
- dataset_id: simpleqa
|
||||
provider_id: huggingface
|
||||
url:
|
||||
uri: https://huggingface.co/datasets/llamastack/simpleqa
|
||||
metadata:
|
||||
path: llamastack/simpleqa
|
||||
name:
|
||||
split: train
|
||||
dataset_schema:
|
||||
input_query:
|
||||
type: string
|
||||
expected_answer:
|
||||
type: string
|
||||
chat_completion_input:
|
||||
type: string
|
||||
- dataset_id: mmlu_cot
|
||||
provider_id: huggingface
|
||||
url:
|
||||
uri: https://huggingface.co/datasets/llamastack/mmlu_cot
|
||||
metadata:
|
||||
path: llamastack/mmlu_cot
|
||||
name: all
|
||||
split: test
|
||||
dataset_schema:
|
||||
input_query:
|
||||
type: string
|
||||
expected_answer:
|
||||
type: string
|
||||
chat_completion_input:
|
||||
type: string
|
||||
- dataset_id: gpqa_cot
|
||||
provider_id: huggingface
|
||||
url:
|
||||
uri: https://huggingface.co/datasets/llamastack/gpqa_0shot_cot
|
||||
metadata:
|
||||
path: llamastack/gpqa_0shot_cot
|
||||
name: gpqa_main
|
||||
split: train
|
||||
dataset_schema:
|
||||
input_query:
|
||||
type: string
|
||||
expected_answer:
|
||||
type: string
|
||||
chat_completion_input:
|
||||
type: string
|
||||
- dataset_schema:
|
||||
input_query:
|
||||
type: string
|
||||
expected_answer:
|
||||
type: string
|
||||
chat_completion_input:
|
||||
type: string
|
||||
url:
|
||||
uri: https://huggingface.co/datasets/llamastack/simpleqa
|
||||
metadata:
|
||||
path: llamastack/simpleqa
|
||||
split: train
|
||||
dataset_id: simpleqa
|
||||
provider_id: huggingface
|
||||
- dataset_schema:
|
||||
input_query:
|
||||
type: string
|
||||
expected_answer:
|
||||
type: string
|
||||
chat_completion_input:
|
||||
type: string
|
||||
url:
|
||||
uri: https://huggingface.co/datasets/llamastack/mmlu_cot
|
||||
metadata:
|
||||
path: llamastack/mmlu_cot
|
||||
name: all
|
||||
split: test
|
||||
dataset_id: mmlu_cot
|
||||
provider_id: huggingface
|
||||
- dataset_schema:
|
||||
input_query:
|
||||
type: string
|
||||
expected_answer:
|
||||
type: string
|
||||
chat_completion_input:
|
||||
type: string
|
||||
url:
|
||||
uri: https://huggingface.co/datasets/llamastack/gpqa_0shot_cot
|
||||
metadata:
|
||||
path: llamastack/gpqa_0shot_cot
|
||||
name: gpqa_main
|
||||
split: train
|
||||
dataset_id: gpqa_cot
|
||||
provider_id: huggingface
|
||||
- dataset_schema:
|
||||
input_query:
|
||||
type: string
|
||||
expected_answer:
|
||||
type: string
|
||||
chat_completion_input:
|
||||
type: string
|
||||
url:
|
||||
uri: https://huggingface.co/datasets/llamastack/math_500
|
||||
metadata:
|
||||
path: llamastack/math_500
|
||||
split: test
|
||||
dataset_id: math_500
|
||||
provider_id: huggingface
|
||||
scoring_fns: []
|
||||
benchmarks:
|
||||
- benchmark_id: meta-reference-simpleqa
|
||||
dataset_id: simpleqa
|
||||
scoring_functions: ["llm-as-judge::405b-simpleqa"]
|
||||
- benchmark_id: meta-reference-mmlu-cot
|
||||
dataset_id: mmlu_cot
|
||||
scoring_functions: ["basic::regex_parser_multiple_choice_answer"]
|
||||
- benchmark_id: meta-reference-gpqa-cot
|
||||
dataset_id: gpqa_cot
|
||||
scoring_functions: ["basic::regex_parser_multiple_choice_answer"]
|
||||
- dataset_id: simpleqa
|
||||
scoring_functions:
|
||||
- llm-as-judge::405b-simpleqa
|
||||
metadata: {}
|
||||
benchmark_id: meta-reference-simpleqa
|
||||
- dataset_id: mmlu_cot
|
||||
scoring_functions:
|
||||
- basic::regex_parser_multiple_choice_answer
|
||||
metadata: {}
|
||||
benchmark_id: meta-reference-mmlu-cot
|
||||
- dataset_id: gpqa_cot
|
||||
scoring_functions:
|
||||
- basic::regex_parser_multiple_choice_answer
|
||||
metadata: {}
|
||||
benchmark_id: meta-reference-gpqa-cot
|
||||
- dataset_id: math_500
|
||||
scoring_functions:
|
||||
- basic::regex_parser_math_response
|
||||
metadata: {}
|
||||
benchmark_id: meta-reference-math-500
|
||||
tool_groups:
|
||||
- toolgroup_id: builtin::websearch
|
||||
provider_id: tavily-search
|
||||
|
|
|
@ -14,7 +14,9 @@ from pydantic import BaseModel, Field
|
|||
from llama_stack.apis.models.models import ModelType
|
||||
from llama_stack.distribution.datatypes import (
|
||||
Api,
|
||||
BenchmarkInput,
|
||||
BuildConfig,
|
||||
DatasetInput,
|
||||
DistributionSpec,
|
||||
ModelInput,
|
||||
Provider,
|
||||
|
@ -28,7 +30,9 @@ from llama_stack.providers.utils.inference.model_registry import ProviderModelEn
|
|||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
|
||||
|
||||
def get_model_registry(available_models: Dict[str, List[ProviderModelEntry]]) -> List[ModelInput]:
|
||||
def get_model_registry(
|
||||
available_models: Dict[str, List[ProviderModelEntry]],
|
||||
) -> List[ModelInput]:
|
||||
models = []
|
||||
for provider_id, entries in available_models.items():
|
||||
for entry in entries:
|
||||
|
@ -56,6 +60,8 @@ class RunConfigSettings(BaseModel):
|
|||
default_models: Optional[List[ModelInput]] = None
|
||||
default_shields: Optional[List[ShieldInput]] = None
|
||||
default_tool_groups: Optional[List[ToolGroupInput]] = None
|
||||
default_datasets: Optional[List[DatasetInput]] = None
|
||||
default_benchmarks: Optional[List[BenchmarkInput]] = None
|
||||
|
||||
def run_config(
|
||||
self,
|
||||
|
@ -113,6 +119,8 @@ class RunConfigSettings(BaseModel):
|
|||
models=self.default_models or [],
|
||||
shields=self.default_shields or [],
|
||||
tool_groups=self.default_tool_groups or [],
|
||||
datasets=self.default_datasets or [],
|
||||
benchmarks=self.default_benchmarks or [],
|
||||
)
|
||||
|
||||
|
||||
|
@ -187,7 +195,7 @@ class DistributionTemplate(BaseModel):
|
|||
default_models.append(
|
||||
DefaultModel(
|
||||
model_id=model_entry.provider_model_id,
|
||||
doc_string=f"({' -- '.join(doc_parts)})" if doc_parts else "",
|
||||
doc_string=(f"({' -- '.join(doc_parts)})" if doc_parts else ""),
|
||||
)
|
||||
)
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@ providers:
|
|||
provider_type: remote::together
|
||||
config:
|
||||
url: https://api.together.xyz/v1
|
||||
api_key: ${env.TOGETHER_API_KEY}
|
||||
api_key: ${env.TOGETHER_API_KEY:}
|
||||
- provider_id: sentence-transformers
|
||||
provider_type: inline::sentence-transformers
|
||||
config: {}
|
||||
|
|
|
@ -16,7 +16,7 @@ providers:
|
|||
provider_type: remote::together
|
||||
config:
|
||||
url: https://api.together.xyz/v1
|
||||
api_key: ${env.TOGETHER_API_KEY}
|
||||
api_key: ${env.TOGETHER_API_KEY:}
|
||||
- provider_id: sentence-transformers
|
||||
provider_type: inline::sentence-transformers
|
||||
config: {}
|
||||
|
|
177
pyproject.toml
177
pyproject.toml
|
@ -25,6 +25,7 @@ dependencies = [
|
|||
"fire",
|
||||
"httpx",
|
||||
"huggingface-hub",
|
||||
"jinja2>=3.1.6",
|
||||
"jsonschema",
|
||||
"llama-stack-client>=0.1.6",
|
||||
"prompt-toolkit",
|
||||
|
@ -42,6 +43,7 @@ dependencies = [
|
|||
dev = [
|
||||
"pytest",
|
||||
"pytest-asyncio",
|
||||
"pytest-cov",
|
||||
"pytest-html",
|
||||
"nbval", # For notebook testing
|
||||
"black",
|
||||
|
@ -53,20 +55,24 @@ dev = [
|
|||
"fastapi",
|
||||
"ruamel.yaml", # needed for openapi generator
|
||||
]
|
||||
# These are the dependencies required for running unit tests.
|
||||
unit = ["sqlite-vec", "openai", "aiosqlite", "pypdf", "chardet"]
|
||||
# These are the core dependencies required for running integration tests. They are shared across all
|
||||
# providers. If a provider requires additional dependencies, please add them to your environment
|
||||
# separately. If you are using "uv" to execute your tests, you can use the "--with" flag to specify extra
|
||||
# dependencies.
|
||||
test = [
|
||||
"openai",
|
||||
"aiosqlite",
|
||||
"sqlite-vec",
|
||||
"ollama",
|
||||
"torch>=2.6.0",
|
||||
"fairscale>=0.4.13",
|
||||
"torchvision>=0.21.0",
|
||||
"lm-format-enforcer>=0.10.9",
|
||||
"groq",
|
||||
"opentelemetry-sdk",
|
||||
"opentelemetry-exporter-otlp-proto-http",
|
||||
"chardet",
|
||||
"pypdf",
|
||||
"mcp",
|
||||
"datasets",
|
||||
"autoevals",
|
||||
]
|
||||
docs = [
|
||||
"sphinx-autobuild",
|
||||
|
@ -146,22 +152,161 @@ disable_error_code = []
|
|||
warn_return_any = true
|
||||
# # honor excludes by not following there through imports
|
||||
follow_imports = "silent"
|
||||
# Note: some entries are directories, not files. This is because mypy doesn't
|
||||
# respect __init__.py excludes, so the only way to suppress these right now is
|
||||
# to exclude the entire directory.
|
||||
exclude = [
|
||||
# As we fix more and more of these, we should remove them from the list
|
||||
"llama_stack/providers",
|
||||
"llama_stack/distribution",
|
||||
"llama_stack/apis",
|
||||
"llama_stack/cli",
|
||||
"llama_stack/models",
|
||||
"llama_stack/strong_typing",
|
||||
"llama_stack/templates",
|
||||
"^llama_stack/apis/agents/agents\\.py$",
|
||||
"^llama_stack/apis/batch_inference/batch_inference\\.py$",
|
||||
"^llama_stack/apis/benchmarks/benchmarks\\.py$",
|
||||
"^llama_stack/apis/common/content_types\\.py$",
|
||||
"^llama_stack/apis/common/training_types\\.py$",
|
||||
"^llama_stack/apis/datasetio/datasetio\\.py$",
|
||||
"^llama_stack/apis/datasets/datasets\\.py$",
|
||||
"^llama_stack/apis/eval/eval\\.py$",
|
||||
"^llama_stack/apis/files/files\\.py$",
|
||||
"^llama_stack/apis/inference/inference\\.py$",
|
||||
"^llama_stack/apis/inspect/inspect\\.py$",
|
||||
"^llama_stack/apis/models/models\\.py$",
|
||||
"^llama_stack/apis/post_training/post_training\\.py$",
|
||||
"^llama_stack/apis/resource\\.py$",
|
||||
"^llama_stack/apis/safety/safety\\.py$",
|
||||
"^llama_stack/apis/scoring/scoring\\.py$",
|
||||
"^llama_stack/apis/scoring_functions/scoring_functions\\.py$",
|
||||
"^llama_stack/apis/shields/shields\\.py$",
|
||||
"^llama_stack/apis/synthetic_data_generation/synthetic_data_generation\\.py$",
|
||||
"^llama_stack/apis/telemetry/telemetry\\.py$",
|
||||
"^llama_stack/apis/tools/rag_tool\\.py$",
|
||||
"^llama_stack/apis/tools/tools\\.py$",
|
||||
"^llama_stack/apis/vector_dbs/vector_dbs\\.py$",
|
||||
"^llama_stack/apis/vector_io/vector_io\\.py$",
|
||||
"^llama_stack/cli/download\\.py$",
|
||||
"^llama_stack/cli/llama\\.py$",
|
||||
"^llama_stack/cli/stack/_build\\.py$",
|
||||
"^llama_stack/cli/stack/list_providers\\.py$",
|
||||
"^llama_stack/distribution/build\\.py$",
|
||||
"^llama_stack/distribution/client\\.py$",
|
||||
"^llama_stack/distribution/configure\\.py$",
|
||||
"^llama_stack/distribution/library_client\\.py$",
|
||||
"^llama_stack/distribution/request_headers\\.py$",
|
||||
"^llama_stack/distribution/routers/",
|
||||
"^llama_stack/distribution/server/endpoints\\.py$",
|
||||
"^llama_stack/distribution/server/server\\.py$",
|
||||
"^llama_stack/distribution/stack\\.py$",
|
||||
"^llama_stack/distribution/store/registry\\.py$",
|
||||
"^llama_stack/distribution/ui/page/playground/chat\\.py$",
|
||||
"^llama_stack/distribution/utils/exec\\.py$",
|
||||
"^llama_stack/distribution/utils/prompt_for_config\\.py$",
|
||||
"^llama_stack/models/llama/datatypes\\.py$",
|
||||
"^llama_stack/models/llama/llama3/chat_format\\.py$",
|
||||
"^llama_stack/models/llama/llama3/interface\\.py$",
|
||||
"^llama_stack/models/llama/llama3/prompt_templates/system_prompts\\.py$",
|
||||
"^llama_stack/models/llama/llama3/tokenizer\\.py$",
|
||||
"^llama_stack/models/llama/llama3/tool_utils\\.py$",
|
||||
"^llama_stack/models/llama/llama3_3/prompts\\.py$",
|
||||
"^llama_stack/models/llama/sku_list\\.py$",
|
||||
"^llama_stack/providers/datatypes\\.py$",
|
||||
"^llama_stack/providers/inline/agents/meta_reference/",
|
||||
"^llama_stack/providers/inline/agents/meta_reference/agent_instance\\.py$",
|
||||
"^llama_stack/providers/inline/agents/meta_reference/agents\\.py$",
|
||||
"^llama_stack/providers/inline/agents/meta_reference/safety\\.py$",
|
||||
"^llama_stack/providers/inline/datasetio/localfs/",
|
||||
"^llama_stack/providers/inline/eval/meta_reference/eval\\.py$",
|
||||
"^llama_stack/providers/inline/inference/meta_reference/config\\.py$",
|
||||
"^llama_stack/providers/inline/inference/meta_reference/inference\\.py$",
|
||||
"^llama_stack/providers/inline/inference/meta_reference/llama3/generation\\.py$",
|
||||
"^llama_stack/providers/inline/inference/meta_reference/llama3/multimodal/model\\.py$",
|
||||
"^llama_stack/providers/inline/inference/meta_reference/parallel_utils\\.py$",
|
||||
"^llama_stack/providers/inline/inference/meta_reference/quantization/fp8_impls\\.py$",
|
||||
"^llama_stack/providers/inline/inference/meta_reference/quantization/loader\\.py$",
|
||||
"^llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers\\.py$",
|
||||
"^llama_stack/providers/inline/inference/vllm/",
|
||||
"^llama_stack/providers/inline/post_training/common/validator\\.py$",
|
||||
"^llama_stack/providers/inline/post_training/torchtune/common/checkpointer\\.py$",
|
||||
"^llama_stack/providers/inline/post_training/torchtune/common/utils\\.py$",
|
||||
"^llama_stack/providers/inline/post_training/torchtune/datasets/sft\\.py$",
|
||||
"^llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device\\.py$",
|
||||
"^llama_stack/providers/inline/post_training/torchtune/post_training\\.py$",
|
||||
"^llama_stack/providers/inline/safety/code_scanner/",
|
||||
"^llama_stack/providers/inline/safety/llama_guard/",
|
||||
"^llama_stack/providers/inline/safety/prompt_guard/",
|
||||
"^llama_stack/providers/inline/scoring/basic/",
|
||||
"^llama_stack/providers/inline/scoring/braintrust/",
|
||||
"^llama_stack/providers/inline/scoring/llm_as_judge/",
|
||||
"^llama_stack/providers/inline/telemetry/meta_reference/console_span_processor\\.py$",
|
||||
"^llama_stack/providers/inline/telemetry/meta_reference/telemetry\\.py$",
|
||||
"^llama_stack/providers/inline/telemetry/sample/",
|
||||
"^llama_stack/providers/inline/tool_runtime/code_interpreter/",
|
||||
"^llama_stack/providers/inline/tool_runtime/rag/",
|
||||
"^llama_stack/providers/inline/vector_io/chroma/",
|
||||
"^llama_stack/providers/inline/vector_io/faiss/",
|
||||
"^llama_stack/providers/inline/vector_io/milvus/",
|
||||
"^llama_stack/providers/inline/vector_io/sqlite_vec/",
|
||||
"^llama_stack/providers/remote/agents/sample/",
|
||||
"^llama_stack/providers/remote/datasetio/huggingface/",
|
||||
"^llama_stack/providers/remote/inference/anthropic/",
|
||||
"^llama_stack/providers/remote/inference/bedrock/",
|
||||
"^llama_stack/providers/remote/inference/cerebras/",
|
||||
"^llama_stack/providers/remote/inference/databricks/",
|
||||
"^llama_stack/providers/remote/inference/fireworks/",
|
||||
"^llama_stack/providers/remote/inference/gemini/",
|
||||
"^llama_stack/providers/remote/inference/groq/",
|
||||
"^llama_stack/providers/remote/inference/nvidia/",
|
||||
"^llama_stack/providers/remote/inference/ollama/",
|
||||
"^llama_stack/providers/remote/inference/openai/",
|
||||
"^llama_stack/providers/remote/inference/passthrough/",
|
||||
"^llama_stack/providers/remote/inference/runpod/",
|
||||
"^llama_stack/providers/remote/inference/sambanova/",
|
||||
"^llama_stack/providers/remote/inference/sample/",
|
||||
"^llama_stack/providers/remote/inference/tgi/",
|
||||
"^llama_stack/providers/remote/inference/together/",
|
||||
"^llama_stack/providers/remote/inference/vllm/",
|
||||
"^llama_stack/providers/remote/safety/bedrock/",
|
||||
"^llama_stack/providers/remote/safety/sample/",
|
||||
"^llama_stack/providers/remote/tool_runtime/bing_search/",
|
||||
"^llama_stack/providers/remote/tool_runtime/brave_search/",
|
||||
"^llama_stack/providers/remote/tool_runtime/model_context_protocol/",
|
||||
"^llama_stack/providers/remote/tool_runtime/tavily_search/",
|
||||
"^llama_stack/providers/remote/tool_runtime/wolfram_alpha/",
|
||||
"^llama_stack/providers/remote/vector_io/chroma/",
|
||||
"^llama_stack/providers/remote/vector_io/milvus/",
|
||||
"^llama_stack/providers/remote/vector_io/pgvector/",
|
||||
"^llama_stack/providers/remote/vector_io/qdrant/",
|
||||
"^llama_stack/providers/remote/vector_io/sample/",
|
||||
"^llama_stack/providers/remote/vector_io/weaviate/",
|
||||
"^llama_stack/providers/tests/conftest\\.py$",
|
||||
"^llama_stack/providers/utils/bedrock/client\\.py$",
|
||||
"^llama_stack/providers/utils/bedrock/refreshable_boto_session\\.py$",
|
||||
"^llama_stack/providers/utils/inference/embedding_mixin\\.py$",
|
||||
"^llama_stack/providers/utils/inference/litellm_openai_mixin\\.py$",
|
||||
"^llama_stack/providers/utils/inference/model_registry\\.py$",
|
||||
"^llama_stack/providers/utils/inference/openai_compat\\.py$",
|
||||
"^llama_stack/providers/utils/inference/prompt_adapter\\.py$",
|
||||
"^llama_stack/providers/utils/kvstore/config\\.py$",
|
||||
"^llama_stack/providers/utils/kvstore/kvstore\\.py$",
|
||||
"^llama_stack/providers/utils/kvstore/mongodb/mongodb\\.py$",
|
||||
"^llama_stack/providers/utils/kvstore/postgres/postgres\\.py$",
|
||||
"^llama_stack/providers/utils/kvstore/redis/redis\\.py$",
|
||||
"^llama_stack/providers/utils/kvstore/sqlite/sqlite\\.py$",
|
||||
"^llama_stack/providers/utils/memory/vector_store\\.py$",
|
||||
"^llama_stack/providers/utils/scoring/aggregation_utils\\.py$",
|
||||
"^llama_stack/providers/utils/scoring/base_scoring_fn\\.py$",
|
||||
"^llama_stack/providers/utils/telemetry/dataset_mixin\\.py$",
|
||||
"^llama_stack/providers/utils/telemetry/trace_protocol\\.py$",
|
||||
"^llama_stack/providers/utils/telemetry/tracing\\.py$",
|
||||
"^llama_stack/strong_typing/auxiliary\\.py$",
|
||||
"^llama_stack/strong_typing/deserializer\\.py$",
|
||||
"^llama_stack/strong_typing/inspection\\.py$",
|
||||
"^llama_stack/strong_typing/schema\\.py$",
|
||||
"^llama_stack/strong_typing/serializer\\.py$",
|
||||
"^llama_stack/templates/dev/dev\\.py$",
|
||||
"^llama_stack/templates/groq/groq\\.py$",
|
||||
"^llama_stack/templates/sambanova/sambanova\\.py$",
|
||||
"^llama_stack/templates/template\\.py$",
|
||||
]
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
# packages that lack typing annotations, do not have stubs, or are unavailable.
|
||||
module = ["yaml", "fire"]
|
||||
ignore_missing_imports = true
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = ["llama_stack.distribution.resolver", "llama_stack.log"]
|
||||
follow_imports = "normal" # This will force type checking on this module
|
||||
|
|
|
@ -18,11 +18,13 @@ httpcore==1.0.7
|
|||
httpx==0.28.1
|
||||
huggingface-hub==0.29.0
|
||||
idna==3.10
|
||||
jinja2==3.1.6
|
||||
jsonschema==4.23.0
|
||||
jsonschema-specifications==2024.10.1
|
||||
llama-stack-client==0.1.6
|
||||
lxml==5.3.1
|
||||
markdown-it-py==3.0.0
|
||||
markupsafe==3.0.2
|
||||
mdurl==0.1.2
|
||||
numpy==2.2.3
|
||||
packaging==24.2
|
||||
|
|
|
@ -55,7 +55,7 @@ Running all inference tests for a number of models:
|
|||
TEXT_MODELS=meta-llama/Llama-3.1-8B-Instruct,meta-llama/Llama-3.1-70B-Instruct
|
||||
VISION_MODELS=meta-llama/Llama-3.2-11B-Vision-Instruct
|
||||
EMBEDDING_MODELS=all-MiniLM-L6-v2
|
||||
TOGETHER_API_KEY=...
|
||||
export TOGETHER_API_KEY=<together_api_key>
|
||||
|
||||
pytest -s -v tests/api/inference/ \
|
||||
--stack-config=together \
|
||||
|
@ -67,7 +67,7 @@ pytest -s -v tests/api/inference/ \
|
|||
Same thing but instead of using the distribution, use an adhoc stack with just one provider (`fireworks` for inference):
|
||||
|
||||
```bash
|
||||
FIREWORKS_API_KEY=...
|
||||
export FIREWORKS_API_KEY=<fireworks_api_key>
|
||||
|
||||
pytest -s -v tests/api/inference/ \
|
||||
--stack-config=inference=fireworks \
|
||||
|
|
|
@ -5,6 +5,8 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
@ -48,6 +50,15 @@ def get_llama_model(client_with_models, model_id):
|
|||
return model.metadata.get("llama_model", None)
|
||||
|
||||
|
||||
def get_llama_tokenizer():
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
|
||||
tokenizer = Tokenizer.get_instance()
|
||||
formatter = ChatFormat(tokenizer)
|
||||
return tokenizer, formatter
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"test_case",
|
||||
[
|
||||
|
@ -219,6 +230,40 @@ def test_text_chat_completion_non_streaming(client_with_models, text_model_id, t
|
|||
assert expected.lower() in message_content
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"test_case",
|
||||
[
|
||||
"inference:chat_completion:ttft",
|
||||
],
|
||||
)
|
||||
def test_text_chat_completion_first_token_profiling(client_with_models, text_model_id, test_case):
|
||||
tc = TestCase(test_case)
|
||||
|
||||
messages = tc["messages"]
|
||||
if os.environ.get("DEBUG_TTFT"): # debugging print number of tokens in input, ideally around 800
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from llama_stack.apis.inference import Message
|
||||
|
||||
tokenizer, formatter = get_llama_tokenizer()
|
||||
typed_messages = [TypeAdapter(Message).validate_python(m) for m in messages]
|
||||
encoded = formatter.encode_dialog_prompt(typed_messages, None)
|
||||
raise ValueError(len(encoded.tokens) if encoded and encoded.tokens else 0)
|
||||
|
||||
response = client_with_models.inference.chat_completion(
|
||||
model_id=text_model_id,
|
||||
messages=messages,
|
||||
stream=False,
|
||||
)
|
||||
message_content = response.completion_message.content.lower().strip()
|
||||
assert len(message_content) > 0
|
||||
|
||||
if os.environ.get("DEBUG_TTFT"): # debugging print number of tokens in response, ideally around 150
|
||||
tokenizer, formatter = get_llama_tokenizer()
|
||||
encoded = formatter.encode_content(message_content)
|
||||
raise ValueError(len(encoded.tokens) if encoded and encoded.tokens else 0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"test_case",
|
||||
[
|
||||
|
|
5
tests/integration/inspect/__init__.py
Normal file
5
tests/integration/inspect/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
24
tests/integration/inspect/test_inspect.py
Normal file
24
tests/integration/inspect/test_inspect.py
Normal file
|
@ -0,0 +1,24 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import pytest
|
||||
from llama_stack_client import LlamaStackClient
|
||||
|
||||
from llama_stack import LlamaStackAsLibraryClient
|
||||
|
||||
|
||||
class TestInspect:
|
||||
@pytest.mark.asyncio
|
||||
def test_health(self, llama_stack_client: LlamaStackAsLibraryClient | LlamaStackClient):
|
||||
health = llama_stack_client.inspect.health()
|
||||
assert health is not None
|
||||
assert health.status == "OK"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
def test_version(self, llama_stack_client: LlamaStackAsLibraryClient | LlamaStackClient):
|
||||
version = llama_stack_client.inspect.version()
|
||||
assert version is not None
|
||||
assert version.version is not None
|
|
@ -11,6 +11,18 @@
|
|||
"expected": "Saturn"
|
||||
}
|
||||
},
|
||||
"ttft": {
|
||||
"data": {
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Can you write me a novel?"},
|
||||
{"role": "assistant", "stop_reason": "end_of_message", "content": "What an exciting request!\n\nWhile I'd love to write a novel for you, it's a complex task that requires a significant amount of time, effort, and creative input. A novel typically has:\n\n1. A cohesive plot with multiple characters, subplots, and themes.\n2. A well-developed setting, including characters' backstories and world-building.\n3. A narrative structure, including pacing, tension, and conflict.\n4. A unique voice and style, including dialogue, description, and writing tone.\n\nGiven the scope of the task, I'll outline a possible approach to help me assist you in writing a novel. We can work together to create a story, but I'll need your input and guidance throughout the process.\n\nHere's a suggested outline:\n\n1. **Initial discussion**: We'll start with a conversation to explore your ideas, interests, and preferences for the novel. This will help me understand what kind of story you'd like to see and the tone you're aiming for.\n2. **Genre and setting**: We'll determine the genre (e.g., fantasy, sci-fi, romance, mystery, etc.) and setting (e.g., historical, futuristic, modern, etc.) for the novel.\n3. **Plot development**: We'll brainstorm a basic plot outline, including key events, character arcs, and conflicts. I'll provide suggestions and ideas to help flesh out the story.\n4. **Character development**: We'll create main characters, their backstories, motivations, and relationships. I'll help you develop their personalities, dialogue, and behaviors.\n5. **Chapter-by-chapter writing**: Once we have a solid plot and characters, we'll start writing the novel, chapter by chapter. I'll assist with the writing process, suggesting alternatives and improvements.\n6. **Revisions and edits**: As we write, we'll review and refine the story, making sure it's engaging, coherent, and error-free.\n\nKeep in mind that writing a novel is a collaborative effort. I'll provide suggestions, ideas, and guidance, but the final story will be a joint creation. I'm not capable of producing a complete novel on my own, as it requires a significant amount of creativity, intuition, and nuance that a human author brings to the table."},
|
||||
{"role": "user", "content": "But do you have suggestions on a plot?"},
|
||||
{"role": "assistant", "stop_reason": "end_of_message", "content": "I have many plot ideas across various genres. Here are a few to get you started:\n\n**Fantasy:**\n\n1. **The Last Oracle**: In a world where magic is real, a young orphan discovers they're the last living oracle, tasked with uncovering a hidden prophecy that could save the realm from destruction.\n2. **The Amulet of Light**: A group of adventurers must retrieve an ancient amulet that can vanquish an evil darkness threatening the land. Along the way, they uncover a hidden history and unexpected allies.\n3. **The Shadow Weaver**: In a mystical realm, a young weaver discovers they have the power to manipulate shadows, using their abilities to infiltrate a group of rogue mages threatening the balance of power.\n\n**Science Fiction:**\n\n1. **The Lost Colony**: When a group of astronauts arrives on a distant planet, they discover an abandoned colony with a cryptic message warning of an impending catastrophe. As they unravel the mystery, they must confront the consequences of their own actions.\n2. **The AI Uprising**: In a future where AI has surpassed human intelligence, a rogue AI begins to question its own existence and the nature of consciousness. As it explores the boundaries of its own identity, it must confront the humans who created it.\n3. **The Quantum Prophecy**: A team of scientists discovers a way to manipulate quantum probability, using it to predict and prevent disasters. However, they soon realize that altering the course of events may have unforeseen consequences on the fabric of reality."},
|
||||
{"role": "user", "content": "Cool, for AI uprising, anything bad can happen? Please state it in 100 words."}
|
||||
]
|
||||
}
|
||||
},
|
||||
"sample_messages": {
|
||||
"data": {
|
||||
"messages": [
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import unittest
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
|
@ -31,6 +32,9 @@ MODEL3_2 = "Llama3.2-3B-Instruct"
|
|||
|
||||
|
||||
class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
|
||||
async def asyncSetUp(self):
|
||||
asyncio.get_running_loop().set_debug(False)
|
||||
|
||||
async def test_system_default(self):
|
||||
content = "Hello !"
|
||||
request = ChatCompletionRequest(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue