Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-11 20:40:40 +00:00)

Commit 8eff285017: Merge remote-tracking branch 'origin/main' into math_500_judge

87 changed files with 2299 additions and 495 deletions
.github/workflows/unit-tests.yml (vendored): 6 lines changed
@@ -14,16 +14,16 @@ jobs:
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.10.16'
          python-version: '3.10'

      - uses: astral-sh/setup-uv@v5
        with:
          python-version: '3.10.16'
          python-version: '3.10'
          enable-cache: false

      - name: Run unit tests
        run: |
          uv run -p 3.10.16 --with . --with ".[dev]" --with ".[test]" pytest -s -v tests/unit/ --junitxml=pytest-report.xml
          uv run -p 3.10 --with-editable . --with-editable ".[dev]" --with-editable ".[unit]" pytest --cov=llama_stack -s -v tests/unit/ --junitxml=pytest-report.xml

      - name: Upload test results
        if: always()
.github/workflows/update-readthedocs.yml (vendored): 2 lines changed
@@ -12,12 +12,14 @@ on:
      - main
    paths:
      - 'docs/**'
      - 'pyproject.toml'
      - '.github/workflows/update-readthedocs.yml'
  pull_request:
    branches:
      - main
    paths:
      - 'docs/**'
      - 'pyproject.toml'
      - '.github/workflows/update-readthedocs.yml'

jobs:
.gitignore (vendored): 1 line changed
@@ -21,3 +21,4 @@ docs/src
pyrightconfig.json
venv/
pytest-report.xml
.coverage
@@ -15,10 +15,6 @@ repos:
      - id: end-of-file-fixer
        exclude: '^(.*\.svg)$'

      # Temporarily disabling this
      # - id: no-commit-to-branch
      #   args: ['--branch=main']

  - repo: https://github.com/Lucas-C/pre-commit-hooks
    rev: v1.5.4
    hooks:

@@ -68,12 +64,6 @@ repos:
          - pydantic
        pass_filenames: false

  # - repo: https://github.com/jsh9/pydoclint
  #   rev: d88180a8632bb1602a4d81344085cf320f288c5a
  #   hooks:
  #     - id: pydoclint
  #       args: [--config=pyproject.toml]

  # - repo: https://github.com/tcort/markdown-link-check
  #   rev: v3.11.2
  #   hooks:
CHANGELOG.md: 33 lines changed
@@ -1,5 +1,38 @@
# Changelog

# v0.1.6
Published on: 2025-03-08T04:35:08Z

## 0.1.6 Release Notes

### Build and Test Agents
* Inference: Fixed support for inline vllm provider
* (**New**) Agent: Build & Monitor Agent Workflows with Llama Stack + Anthropic's Best Practice [Notebook](https://github.com/meta-llama/llama-stack/blob/main/docs/notebooks/Llama_Stack_Agent_Workflows.ipynb)
* (**New**) Agent: Revamped agent [documentation](https://llama-stack.readthedocs.io/en/latest/building_applications/agent.html) with more details and examples
* Agent: Unify tools and Python SDK Agents API
* Agent: AsyncAgent Python SDK wrapper supporting async client tool calls
* Agent: Support python functions without @client_tool decorator as client tools
* Agent: deprecation for allow_resume_turn flag, and remove need to specify tool_prompt_format
* VectorIO: MilvusDB support added

### Agent Evals and Model Customization
* (**New**) Agent: Llama Stack RAG Lifecycle [Notebook](https://github.com/meta-llama/llama-stack/blob/main/docs/notebooks/Llama_Stack_RAG_Lifecycle.ipynb)
* Eval: Documentation for eval, scoring, adding new benchmarks
* Eval: Distribution template to run benchmarks on llama & non-llama models
* Eval: Ability to register new custom LLM-as-judge scoring functions
* (**New**) Looking for contributors for open benchmarks. See [documentation](https://llama-stack.readthedocs.io/en/latest/references/evals_reference/index.html#open-benchmark-contributing-guide) for details.

### Deploy and Monitoring of Agents
* Better support for different log levels across all components for better monitoring

### Better Engineering
* Enhance OpenAPI spec to include Error types across all APIs
* Moved all tests to /tests and created unit tests to run on each PR
* Removed all dependencies on llama-models repo


---

# v0.1.5.1
Published on: 2025-02-28T22:37:44Z
@@ -159,8 +159,7 @@ uv run sphinx-autobuild source build/html --write-all
If you modify or add new API endpoints, update the API documentation accordingly. You can do this by running the following command:

```bash
uv sync --extra dev
uv run ./docs/openapi_generator/run_openapi_generator.sh
uv run --with ".[dev]" ./docs/openapi_generator/run_openapi_generator.sh
```

The generated API documentation will be available in `docs/_static/`. Make sure to review the changes before committing.
@@ -427,6 +427,7 @@
    "chardet",
    "chromadb-client",
    "datasets",
    "faiss-cpu",
    "fastapi",
    "fire",
    "httpx",

@@ -448,7 +449,6 @@
    "scikit-learn",
    "scipy",
    "sentencepiece",
    "sqlite-vec",
    "tqdm",
    "transformers",
    "uvicorn"
@@ -71,7 +71,6 @@ services:
          condition: service_healthy
      - vllm-${VLLM_SAFETY_MODEL:+safety}:
          condition: service_healthy
    # image: llamastack/distribution-remote-vllm
    image: llamastack/distribution-remote-vllm:test-0.0.52rc3
    volumes:
      - ~/.llama:/root/.llama
docs/_static/llama-stack-spec.html (vendored): 169 lines changed
|
@ -363,6 +363,37 @@
|
|||
}
|
||||
},
|
||||
"/v1/agents": {
|
||||
"get": {
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "A ListAgentsResponse.",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/ListAgentsResponse"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"400": {
|
||||
"$ref": "#/components/responses/BadRequest400"
|
||||
},
|
||||
"429": {
|
||||
"$ref": "#/components/responses/TooManyRequests429"
|
||||
},
|
||||
"500": {
|
||||
"$ref": "#/components/responses/InternalServerError500"
|
||||
},
|
||||
"default": {
|
||||
"$ref": "#/components/responses/DefaultError"
|
||||
}
|
||||
},
|
||||
"tags": [
|
||||
"Agents"
|
||||
],
|
||||
"description": "List all agents.",
|
||||
"parameters": []
|
||||
},
|
||||
"post": {
|
||||
"responses": {
|
||||
"200": {
|
||||
|
@ -609,6 +640,47 @@
|
|||
}
|
||||
},
|
||||
"/v1/agents/{agent_id}": {
|
||||
"get": {
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "An Agent of the agent.",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/Agent"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"400": {
|
||||
"$ref": "#/components/responses/BadRequest400"
|
||||
},
|
||||
"429": {
|
||||
"$ref": "#/components/responses/TooManyRequests429"
|
||||
},
|
||||
"500": {
|
||||
"$ref": "#/components/responses/InternalServerError500"
|
||||
},
|
||||
"default": {
|
||||
"$ref": "#/components/responses/DefaultError"
|
||||
}
|
||||
},
|
||||
"tags": [
|
||||
"Agents"
|
||||
],
|
||||
"description": "Describe an agent by its ID.",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "agent_id",
|
||||
"in": "path",
|
||||
"description": "ID of the agent.",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
"delete": {
|
||||
"responses": {
|
||||
"200": {
|
||||
|
@ -2276,6 +2348,49 @@
|
|||
]
|
||||
}
|
||||
},
|
||||
"/v1/agents/{agent_id}/sessions": {
|
||||
"get": {
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "A ListAgentSessionsResponse.",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/ListAgentSessionsResponse"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"400": {
|
||||
"$ref": "#/components/responses/BadRequest400"
|
||||
},
|
||||
"429": {
|
||||
"$ref": "#/components/responses/TooManyRequests429"
|
||||
},
|
||||
"500": {
|
||||
"$ref": "#/components/responses/InternalServerError500"
|
||||
},
|
||||
"default": {
|
||||
"$ref": "#/components/responses/DefaultError"
|
||||
}
|
||||
},
|
||||
"tags": [
|
||||
"Agents"
|
||||
],
|
||||
"description": "List all session(s) of a given agent.",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "agent_id",
|
||||
"in": "path",
|
||||
"description": "The ID of the agent to list sessions for.",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"/v1/eval/benchmarks": {
|
||||
"get": {
|
||||
"responses": {
|
||||
|
@ -6565,6 +6680,28 @@
|
|||
"title": "ScoringResult",
|
||||
"description": "A scoring result for a single row."
|
||||
},
|
||||
"Agent": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"agent_id": {
|
||||
"type": "string"
|
||||
},
|
||||
"agent_config": {
|
||||
"$ref": "#/components/schemas/AgentConfig"
|
||||
},
|
||||
"created_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"agent_id",
|
||||
"agent_config",
|
||||
"created_at"
|
||||
],
|
||||
"title": "Agent"
|
||||
},
|
||||
"Session": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
@ -7907,6 +8044,38 @@
|
|||
],
|
||||
"title": "ToolInvocationResult"
|
||||
},
|
||||
"ListAgentSessionsResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"data": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/components/schemas/Session"
|
||||
}
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"data"
|
||||
],
|
||||
"title": "ListAgentSessionsResponse"
|
||||
},
|
||||
"ListAgentsResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"data": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/components/schemas/Agent"
|
||||
}
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"data"
|
||||
],
|
||||
"title": "ListAgentsResponse"
|
||||
},
|
||||
"BucketResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
|
docs/_static/llama-stack-spec.yaml (vendored): 118 lines changed
|
@ -238,6 +238,28 @@ paths:
|
|||
$ref: '#/components/schemas/CompletionRequest'
|
||||
required: true
|
||||
/v1/agents:
|
||||
get:
|
||||
responses:
|
||||
'200':
|
||||
description: A ListAgentsResponse.
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ListAgentsResponse'
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
'429':
|
||||
$ref: >-
|
||||
#/components/responses/TooManyRequests429
|
||||
'500':
|
||||
$ref: >-
|
||||
#/components/responses/InternalServerError500
|
||||
default:
|
||||
$ref: '#/components/responses/DefaultError'
|
||||
tags:
|
||||
- Agents
|
||||
description: List all agents.
|
||||
parameters: []
|
||||
post:
|
||||
responses:
|
||||
'200':
|
||||
|
@ -410,6 +432,34 @@ paths:
|
|||
$ref: '#/components/schemas/CreateUploadSessionRequest'
|
||||
required: true
|
||||
/v1/agents/{agent_id}:
|
||||
get:
|
||||
responses:
|
||||
'200':
|
||||
description: An Agent of the agent.
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/Agent'
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
'429':
|
||||
$ref: >-
|
||||
#/components/responses/TooManyRequests429
|
||||
'500':
|
||||
$ref: >-
|
||||
#/components/responses/InternalServerError500
|
||||
default:
|
||||
$ref: '#/components/responses/DefaultError'
|
||||
tags:
|
||||
- Agents
|
||||
description: Describe an agent by its ID.
|
||||
parameters:
|
||||
- name: agent_id
|
||||
in: path
|
||||
description: ID of the agent.
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
delete:
|
||||
responses:
|
||||
'200':
|
||||
|
@ -1528,6 +1578,36 @@ paths:
|
|||
required: true
|
||||
schema:
|
||||
type: string
|
||||
/v1/agents/{agent_id}/sessions:
|
||||
get:
|
||||
responses:
|
||||
'200':
|
||||
description: A ListAgentSessionsResponse.
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ListAgentSessionsResponse'
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
'429':
|
||||
$ref: >-
|
||||
#/components/responses/TooManyRequests429
|
||||
'500':
|
||||
$ref: >-
|
||||
#/components/responses/InternalServerError500
|
||||
default:
|
||||
$ref: '#/components/responses/DefaultError'
|
||||
tags:
|
||||
- Agents
|
||||
description: List all session(s) of a given agent.
|
||||
parameters:
|
||||
- name: agent_id
|
||||
in: path
|
||||
description: >-
|
||||
The ID of the agent to list sessions for.
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
/v1/eval/benchmarks:
|
||||
get:
|
||||
responses:
|
||||
|
@ -4549,6 +4629,22 @@ components:
|
|||
- aggregated_results
|
||||
title: ScoringResult
|
||||
description: A scoring result for a single row.
|
||||
Agent:
|
||||
type: object
|
||||
properties:
|
||||
agent_id:
|
||||
type: string
|
||||
agent_config:
|
||||
$ref: '#/components/schemas/AgentConfig'
|
||||
created_at:
|
||||
type: string
|
||||
format: date-time
|
||||
additionalProperties: false
|
||||
required:
|
||||
- agent_id
|
||||
- agent_config
|
||||
- created_at
|
||||
title: Agent
|
||||
Session:
|
||||
type: object
|
||||
properties:
|
||||
|
@ -5385,6 +5481,28 @@ components:
|
|||
required:
|
||||
- content
|
||||
title: ToolInvocationResult
|
||||
ListAgentSessionsResponse:
|
||||
type: object
|
||||
properties:
|
||||
data:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/Session'
|
||||
additionalProperties: false
|
||||
required:
|
||||
- data
|
||||
title: ListAgentSessionsResponse
|
||||
ListAgentsResponse:
|
||||
type: object
|
||||
properties:
|
||||
data:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/Agent'
|
||||
additionalProperties: false
|
||||
required:
|
||||
- data
|
||||
title: ListAgentsResponse
|
||||
BucketResponse:
|
||||
type: object
|
||||
properties:
|
||||
|
|
@@ -1267,7 +1267,6 @@
}
],
"source": [
"# NBVAL_SKIP\n",
"from pydantic import BaseModel\n",
"\n",
"\n",

@@ -1279,7 +1278,7 @@
"\n",
"user_input = \"Michael Jordan was born in 1963. He played basketball for the Chicago Bulls. He retired in 2003. Extract this information into JSON for me. \"\n",
"response = client.inference.completion(\n",
"    model_id=model_id,\n",
"    model_id=\"meta-llama/Llama-3.1-8B-Instruct\",\n",
"    content=user_input,\n",
"    stream=False,\n",
"    sampling_params={\n",

@@ -1640,7 +1639,7 @@
"agent = Agent(\n",
"    client, \n",
"    model=model_id,\n",
"    instructions=\"You are a helpful assistant\",\n",
"    instructions=\"You are a helpful assistant. Use websearch tool to help answer questions.\",\n",
"    tools=[\"builtin::websearch\"],\n",
")\n",
"user_prompts = [\n",
@@ -1,9 +1 @@
The RFC Specification (OpenAPI format) is generated from the set of API endpoints located in `llama_stack/distribution/server/endpoints.py` using the `generate.py` utility.

Please install the following packages before running the script:

```
pip install fire PyYAML
```

Then simply run `sh run_openapi_generator.sh`
@@ -23,9 +23,12 @@ In this example, we will show you how to:

##### Building a Search Agent
```python
from llama_stack_client import LlamaStackClient
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.event_logger import EventLogger

client = LlamaStackClient(base_url=f"http://{HOST}:{PORT}")

agent = Agent(
    client,
    model="meta-llama/Llama-3.3-70B-Instruct",

@@ -33,7 +36,7 @@ agent = Agent(
    tools=["builtin::websearch"],
)
user_prompts = [
    "Which teams played in the NBA western conference finals of 2024. Search the web for the answer.",
    "Which teams played in the NBA Western Conference Finals of 2024. Search the web for the answer.",
    "In which episode and season of South Park does Bill Cosby (BSM-471) first appear? Give me the number and title. Search the web for the answer.",
    "What is the British-American kickboxer Andrew Tate's kickboxing name? Search the web for the answer.",
]
@@ -33,6 +33,8 @@ Can be set to any of the following log levels:

The default global log level is `info`. `all` sets the log level for all components.

A user can also set `LLAMA_STACK_LOG_FILE` which will pipe the logs to the specified path as well as to the terminal. An example would be: `export LLAMA_STACK_LOG_FILE=server.log`

### Llama Stack Build

In order to build your own distribution, we recommend you clone the `llama-stack` repository.
@@ -40,7 +40,6 @@ The following models are available by default:
- `accounts/fireworks/models/llama-v3p1-8b-instruct (aliases: meta-llama/Llama-3.1-8B-Instruct)`
- `accounts/fireworks/models/llama-v3p1-70b-instruct (aliases: meta-llama/Llama-3.1-70B-Instruct)`
- `accounts/fireworks/models/llama-v3p1-405b-instruct (aliases: meta-llama/Llama-3.1-405B-Instruct-FP8)`
- `accounts/fireworks/models/llama-v3p2-1b-instruct (aliases: meta-llama/Llama-3.2-1B-Instruct)`
- `accounts/fireworks/models/llama-v3p2-3b-instruct (aliases: meta-llama/Llama-3.2-3B-Instruct)`
- `accounts/fireworks/models/llama-v3p2-11b-vision-instruct (aliases: meta-llama/Llama-3.2-11B-Vision-Instruct)`
- `accounts/fireworks/models/llama-v3p2-90b-vision-instruct (aliases: meta-llama/Llama-3.2-90B-Vision-Instruct)`
@@ -23,7 +23,7 @@ The `llamastack/distribution-ollama` distribution consists of the following prov
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| telemetry | `inline::meta-reference` |
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol`, `remote::wolfram-alpha` |
| vector_io | `inline::sqlite-vec`, `remote::chromadb`, `remote::pgvector` |
| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |


You should use this distribution if you have a regular desktop machine without very powerful GPUs. Of course, if you have powerful GPUs, you can still continue using this distribution since Ollama supports GPU acceleration.

@@ -130,7 +130,7 @@ llama stack run ./run-with-safety.yaml \
### (Optional) Update Model Serving Configuration

```{note}
Please check the [model_entries](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/inference/ollama/ollama.py#L45) for the supported Ollama models.
Please check the [model_entries](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/inference/ollama/models.py) for the supported Ollama models.
```

To serve a new model with `ollama`
@@ -1,6 +1,6 @@
# llama (server-side) CLI Reference

The `llama` CLI tool helps you setup and use the Llama Stack. It should be available on your path after installing the `llama-stack` package.
The `llama` CLI tool helps you set up and use the Llama Stack. The CLI is available on your path after installing the `llama-stack` package.

## Installation

@@ -27,9 +27,9 @@ You have two ways to install Llama Stack:


## `llama` subcommands
1. `download`: `llama` cli tools supports downloading the model from Meta or Hugging Face.
2. `model`: Lists available models and their properties.
3. `stack`: Allows you to build and run a Llama Stack server. You can read more about this [here](../../distributions/building_distro).
1. `download`: Supports downloading models from Meta or Hugging Face. [Downloading models](#downloading-models)
2. `model`: Lists available models and their properties. [Understanding models](#understand-the-models)
3. `stack`: Allows you to build a stack using the `llama stack` distribution and run a Llama Stack server. You can read more about how to build a Llama Stack distribution in the [Build your own Distribution](../../distributions/building_distro) documentation.

### Sample Usage

@@ -117,7 +117,7 @@ You should see a table like this:
+----------------------------------+------------------------------------------+----------------+
```

To download models, you can use the llama download command.
To download models, you can use the `llama download` command.

### Downloading from [Meta](https://llama.meta.com/llama-downloads/)

@@ -191,7 +191,7 @@ You should see a table like this:
The `llama model` command helps you explore the model’s interface.

1. `download`: Download the model from different sources. (meta, huggingface)
2. `list`: Lists all the models available for download with hardware requirements to deploy the models.
2. `list`: Lists all the models available for download with hardware requirements for deploying the models.
3. `prompt-format`: Show llama model message formats.
4. `describe`: Describes all the properties of the model.

@@ -262,13 +262,12 @@ llama model prompt-format -m Llama3.2-3B-Instruct




You will be shown a Markdown formatted description of the model interface and how prompts / messages are formatted for various scenarios.

**NOTE**: Outputs in terminal are color printed to show special tokens.

### Remove model
You can run `llama model remove` to remove unecessary model:
You can run `llama model remove` to remove an unnecessary model:

```
llama model remove -m Llama-Guard-3-8B-int8
@@ -40,7 +40,7 @@ If you're looking for more specific topics, we have a [Zero to Hero Guide](#next
ollama run llama3.2:3b-instruct-fp16 --keepalive -1m
```
**Note**:
- The supported models for llama stack for now is listed in [here](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/inference/ollama/ollama.py#L43)
- The supported models for llama stack for now is listed in [here](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/inference/ollama/models.py)
- `keepalive -1m` is used so that ollama continues to keep the model in memory indefinitely. Otherwise, ollama frees up memory and you would have to run `ollama run` again.

---
@@ -234,6 +234,23 @@ class AgentConfig(AgentConfigCommon):
    response_format: Optional[ResponseFormat] = None


@json_schema_type
class Agent(BaseModel):
    agent_id: str
    agent_config: AgentConfig
    created_at: datetime


@json_schema_type
class ListAgentsResponse(BaseModel):
    data: List[Agent]


@json_schema_type
class ListAgentSessionsResponse(BaseModel):
    data: List[Session]


class AgentConfigOverridablePerTurn(AgentConfigCommon):
    instructions: Optional[str] = None

@@ -541,3 +558,32 @@ class Agents(Protocol):
        :param agent_id: The ID of the agent to delete.
        """
        ...

    @webmethod(route="/agents", method="GET")
    async def list_agents(self) -> ListAgentsResponse:
        """List all agents.

        :returns: A ListAgentsResponse.
        """
        ...

    @webmethod(route="/agents/{agent_id}", method="GET")
    async def get_agent(self, agent_id: str) -> Agent:
        """Describe an agent by its ID.

        :param agent_id: ID of the agent.
        :returns: An Agent of the agent.
        """
        ...

    @webmethod(route="/agents/{agent_id}/sessions", method="GET")
    async def list_agent_sessions(
        self,
        agent_id: str,
    ) -> ListAgentSessionsResponse:
        """List all session(s) of a given agent.

        :param agent_id: The ID of the agent to list sessions for.
        :returns: A ListAgentSessionsResponse.
        """
        ...
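For orientation, the sketch below (not part of this commit) shows how the three new read-only routes could be exercised once the change lands. The paths and response shapes follow the OpenAPI additions above; the host, port, and use of `httpx` are illustrative assumptions.

```python
# Hypothetical client-side usage of the new Agents routes; BASE_URL is a placeholder.
import httpx

BASE_URL = "http://localhost:8321"  # assumption: wherever the Llama Stack server listens

# GET /v1/agents -> ListAgentsResponse: {"data": [Agent, ...]}
agents = httpx.get(f"{BASE_URL}/v1/agents").json()["data"]

for agent in agents:
    agent_id = agent["agent_id"]
    # GET /v1/agents/{agent_id} -> Agent: {"agent_id", "agent_config", "created_at"}
    detail = httpx.get(f"{BASE_URL}/v1/agents/{agent_id}").json()
    # GET /v1/agents/{agent_id}/sessions -> ListAgentSessionsResponse: {"data": [Session, ...]}
    sessions = httpx.get(f"{BASE_URL}/v1/agents/{agent_id}/sessions").json()["data"]
    print(agent_id, detail["created_at"], f"{len(sessions)} session(s)")
```

Note that the meta-reference provider in this commit still stubs these methods out with `pass`, so a real listing requires a provider that implements them.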
@ -39,7 +39,7 @@ from llama_stack.distribution.resolver import InvalidProviderError
|
|||
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
|
||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||
from llama_stack.distribution.utils.exec import formulate_run_args, run_with_pty
|
||||
from llama_stack.distribution.utils.image_types import ImageType
|
||||
from llama_stack.distribution.utils.image_types import LlamaStackImageType
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
TEMPLATES_PATH = Path(__file__).parent.parent.parent / "templates"
|
||||
|
@ -170,7 +170,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
|||
)
|
||||
sys.exit(1)
|
||||
|
||||
if build_config.image_type == ImageType.container.value and not args.image_name:
|
||||
if build_config.image_type == LlamaStackImageType.CONTAINER.value and not args.image_name:
|
||||
cprint(
|
||||
"Please specify --image-name when building a container from a config file",
|
||||
color="red",
|
||||
|
@ -226,7 +226,7 @@ def _generate_run_config(
|
|||
"""
|
||||
apis = list(build_config.distribution_spec.providers.keys())
|
||||
run_config = StackRunConfig(
|
||||
container_image=(image_name if build_config.image_type == ImageType.container.value else None),
|
||||
container_image=(image_name if build_config.image_type == LlamaStackImageType.CONTAINER.value else None),
|
||||
image_name=image_name,
|
||||
apis=apis,
|
||||
providers={},
|
||||
|
@ -279,16 +279,16 @@ def _run_stack_build_command_from_build_config(
|
|||
template_name: Optional[str] = None,
|
||||
config_path: Optional[str] = None,
|
||||
) -> str:
|
||||
if build_config.image_type == ImageType.container.value:
|
||||
if build_config.image_type == LlamaStackImageType.CONTAINER.value:
|
||||
if template_name:
|
||||
image_name = f"distribution-{template_name}"
|
||||
else:
|
||||
if not image_name:
|
||||
raise ValueError("Please specify an image name when building a container image without a template")
|
||||
elif build_config.image_type == ImageType.conda.value:
|
||||
elif build_config.image_type == LlamaStackImageType.CONDA.value:
|
||||
if not image_name:
|
||||
raise ValueError("Please specify an image name when building a conda image")
|
||||
elif build_config.image_type == ImageType.venv.value:
|
||||
elif build_config.image_type == LlamaStackImageType.VENV.value:
|
||||
if not image_name and os.environ.get("UV_SYSTEM_PYTHON"):
|
||||
image_name = "__system__"
|
||||
if not image_name:
|
||||
|
|
|
@ -16,7 +16,7 @@ from termcolor import cprint
|
|||
from llama_stack.distribution.datatypes import BuildConfig, Provider
|
||||
from llama_stack.distribution.distribution import get_provider_registry
|
||||
from llama_stack.distribution.utils.exec import run_command, run_with_pty
|
||||
from llama_stack.distribution.utils.image_types import ImageType
|
||||
from llama_stack.distribution.utils.image_types import LlamaStackImageType
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
@ -95,7 +95,7 @@ def build_image(
|
|||
normal_deps, special_deps = get_provider_dependencies(build_config.distribution_spec.providers)
|
||||
normal_deps += SERVER_DEPENDENCIES
|
||||
|
||||
if build_config.image_type == ImageType.container.value:
|
||||
if build_config.image_type == LlamaStackImageType.CONTAINER.value:
|
||||
script = str(importlib.resources.files("llama_stack") / "distribution/build_container.sh")
|
||||
args = [
|
||||
script,
|
||||
|
@ -104,7 +104,7 @@ def build_image(
|
|||
container_base,
|
||||
" ".join(normal_deps),
|
||||
]
|
||||
elif build_config.image_type == ImageType.conda.value:
|
||||
elif build_config.image_type == LlamaStackImageType.CONDA.value:
|
||||
script = str(importlib.resources.files("llama_stack") / "distribution/build_conda_env.sh")
|
||||
args = [
|
||||
script,
|
||||
|
@ -112,7 +112,7 @@ def build_image(
|
|||
str(build_file_path),
|
||||
" ".join(normal_deps),
|
||||
]
|
||||
elif build_config.image_type == ImageType.venv.value:
|
||||
elif build_config.image_type == LlamaStackImageType.VENV.value:
|
||||
script = str(importlib.resources.files("llama_stack") / "distribution/build_venv.sh")
|
||||
args = [
|
||||
script,
|
||||
|
|
|
@ -32,7 +32,10 @@ from termcolor import cprint
|
|||
from llama_stack.distribution.build import print_pip_install_help
|
||||
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
from llama_stack.distribution.request_headers import set_request_provider_data
|
||||
from llama_stack.distribution.request_headers import (
|
||||
preserve_headers_context_async_generator,
|
||||
request_provider_data_context,
|
||||
)
|
||||
from llama_stack.distribution.resolver import ProviderRegistry
|
||||
from llama_stack.distribution.server.endpoints import get_all_api_endpoints
|
||||
from llama_stack.distribution.stack import (
|
||||
|
@ -160,6 +163,9 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
|
|||
except StopAsyncIteration:
|
||||
pass
|
||||
finally:
|
||||
pending = asyncio.all_tasks(loop)
|
||||
if pending:
|
||||
loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
|
||||
loop.close()
|
||||
|
||||
return sync_generator()
|
||||
|
@ -262,21 +268,25 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
if not self.endpoint_impls:
|
||||
raise ValueError("Client not initialized")
|
||||
|
||||
# Create headers with provider data if available
|
||||
headers = {}
|
||||
if self.provider_data:
|
||||
set_request_provider_data({"X-LlamaStack-Provider-Data": json.dumps(self.provider_data)})
|
||||
headers["X-LlamaStack-Provider-Data"] = json.dumps(self.provider_data)
|
||||
|
||||
if stream:
|
||||
response = await self._call_streaming(
|
||||
cast_to=cast_to,
|
||||
options=options,
|
||||
stream_cls=stream_cls,
|
||||
)
|
||||
else:
|
||||
response = await self._call_non_streaming(
|
||||
cast_to=cast_to,
|
||||
options=options,
|
||||
)
|
||||
return response
|
||||
# Use context manager for provider data
|
||||
with request_provider_data_context(headers):
|
||||
if stream:
|
||||
response = await self._call_streaming(
|
||||
cast_to=cast_to,
|
||||
options=options,
|
||||
stream_cls=stream_cls,
|
||||
)
|
||||
else:
|
||||
response = await self._call_non_streaming(
|
||||
cast_to=cast_to,
|
||||
options=options,
|
||||
)
|
||||
return response
|
||||
|
||||
def _find_matching_endpoint(self, method: str, path: str) -> tuple[Any, dict]:
|
||||
"""Find the matching endpoint implementation for a given method and path.
|
||||
|
@ -374,9 +384,11 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
finally:
|
||||
await end_trace()
|
||||
|
||||
# Wrap the generator to preserve context across iterations
|
||||
wrapped_gen = preserve_headers_context_async_generator(gen())
|
||||
mock_response = httpx.Response(
|
||||
status_code=httpx.codes.OK,
|
||||
content=gen(),
|
||||
content=wrapped_gen,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
|
|
|
@ -4,16 +4,62 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import contextvars
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
from typing import Any, Dict
|
||||
from typing import Any, AsyncGenerator, ContextManager, Dict, Optional, TypeVar
|
||||
|
||||
from .utils.dynamic import instantiate_class_type
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
_THREAD_LOCAL = threading.local()
|
||||
# Context variable for request provider data
|
||||
_provider_data_var = contextvars.ContextVar("provider_data", default=None)
|
||||
|
||||
|
||||
class RequestProviderDataContext(ContextManager):
|
||||
"""Context manager for request provider data"""
|
||||
|
||||
def __init__(self, provider_data: Optional[Dict[str, Any]] = None):
|
||||
self.provider_data = provider_data
|
||||
self.token = None
|
||||
|
||||
def __enter__(self):
|
||||
# Save the current value and set the new one
|
||||
self.token = _provider_data_var.set(self.provider_data)
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
# Restore the previous value
|
||||
if self.token is not None:
|
||||
_provider_data_var.reset(self.token)
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def preserve_headers_context_async_generator(gen: AsyncGenerator[T, None]) -> AsyncGenerator[T, None]:
|
||||
"""
|
||||
Wraps an async generator to preserve request headers context variables across iterations.
|
||||
|
||||
This ensures that context variables set during generator creation are
|
||||
available during each iteration of the generator, even if the original
|
||||
context manager has exited.
|
||||
"""
|
||||
# Capture the current context value right now
|
||||
context_value = _provider_data_var.get()
|
||||
|
||||
async def wrapper():
|
||||
while True:
|
||||
# Set context before each anext() call
|
||||
_ = _provider_data_var.set(context_value)
|
||||
try:
|
||||
item = await gen.__anext__()
|
||||
yield item
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
|
||||
return wrapper()
|
||||
|
||||
|
||||
class NeedsRequestProviderData:
|
||||
|
@ -26,7 +72,7 @@ class NeedsRequestProviderData:
|
|||
if not validator_class:
|
||||
raise ValueError(f"Provider {provider_type} does not have a validator")
|
||||
|
||||
val = getattr(_THREAD_LOCAL, "provider_data_header_value", None)
|
||||
val = _provider_data_var.get()
|
||||
if not val:
|
||||
return None
|
||||
|
||||
|
@ -36,25 +82,32 @@ class NeedsRequestProviderData:
|
|||
return provider_data
|
||||
except Exception as e:
|
||||
log.error(f"Error parsing provider data: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def set_request_provider_data(headers: Dict[str, str]):
|
||||
def parse_request_provider_data(headers: Dict[str, str]) -> Optional[Dict[str, Any]]:
|
||||
"""Parse provider data from request headers"""
|
||||
keys = [
|
||||
"X-LlamaStack-Provider-Data",
|
||||
"x-llamastack-provider-data",
|
||||
]
|
||||
val = None
|
||||
for key in keys:
|
||||
val = headers.get(key, None)
|
||||
if val:
|
||||
break
|
||||
|
||||
if not val:
|
||||
return
|
||||
return None
|
||||
|
||||
try:
|
||||
val = json.loads(val)
|
||||
return json.loads(val)
|
||||
except json.JSONDecodeError:
|
||||
log.error("Provider data not encoded as a JSON object!", val)
|
||||
return
|
||||
log.error("Provider data not encoded as a JSON object!")
|
||||
return None
|
||||
|
||||
_THREAD_LOCAL.provider_data_header_value = val
|
||||
|
||||
def request_provider_data_context(headers: Dict[str, str]) -> ContextManager:
|
||||
"""Context manager that sets request provider data from headers for the duration of the context"""
|
||||
provider_data = parse_request_provider_data(headers)
|
||||
return RequestProviderDataContext(provider_data)
|
||||
|
|
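A brief sketch (not in the diff) of the behavior these helpers are meant to provide: provider data captured from request headers stays visible to an async generator even after the context manager that set it has exited. The import path mirrors the module above; the surrounding handler is hypothetical.

```python
# Hypothetical caller; only the two imported helpers come from the module shown above.
import asyncio

from llama_stack.distribution.request_headers import (
    preserve_headers_context_async_generator,
    request_provider_data_context,
)


async def stream_numbers():
    # Stand-in for a streaming provider call that reads provider data on each chunk.
    for i in range(3):
        yield i


async def handle_request(headers: dict) -> list:
    with request_provider_data_context(headers):
        # Capture the provider-data context that is active right now.
        gen = preserve_headers_context_async_generator(stream_numbers())
    # The context manager has already exited, but each iteration below still
    # runs with the provider data captured when the wrapper was created.
    return [item async for item in gen]


print(asyncio.run(handle_request({"X-LlamaStack-Provider-Data": '{"api_key": "test"}'})))
```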
|
@ -6,11 +6,9 @@
|
|||
|
||||
import argparse
|
||||
import asyncio
|
||||
import functools
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
import traceback
|
||||
import warnings
|
||||
|
@ -29,7 +27,10 @@ from typing_extensions import Annotated
|
|||
|
||||
from llama_stack.distribution.datatypes import StackRunConfig
|
||||
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
||||
from llama_stack.distribution.request_headers import set_request_provider_data
|
||||
from llama_stack.distribution.request_headers import (
|
||||
preserve_headers_context_async_generator,
|
||||
request_provider_data_context,
|
||||
)
|
||||
from llama_stack.distribution.resolver import InvalidProviderError
|
||||
from llama_stack.distribution.stack import (
|
||||
construct_stack,
|
||||
|
@ -115,69 +116,24 @@ def translate_exception(exc: Exception) -> Union[HTTPException, RequestValidatio
|
|||
)
|
||||
|
||||
|
||||
def handle_signal(app, signum, _) -> None:
|
||||
async def shutdown(app):
|
||||
"""Initiate a graceful shutdown of the application.
|
||||
|
||||
Handled by the lifespan context manager. The shutdown process involves
|
||||
shutting down all implementations registered in the application.
|
||||
"""
|
||||
Handle incoming signals and initiate a graceful shutdown of the application.
|
||||
|
||||
This function is intended to be used as a signal handler for various signals
|
||||
(e.g., SIGINT, SIGTERM). Upon receiving a signal, it will print a message
|
||||
indicating the received signal and initiate a shutdown process.
|
||||
|
||||
Args:
|
||||
app: The application instance containing implementations to be shut down.
|
||||
signum (int): The signal number received.
|
||||
frame: The current stack frame (not used in this function).
|
||||
|
||||
The shutdown process involves:
|
||||
- Shutting down all implementations registered in the application.
|
||||
- Gathering all running asyncio tasks.
|
||||
- Cancelling all gathered tasks.
|
||||
- Waiting for all tasks to finish.
|
||||
- Stopping the event loop.
|
||||
|
||||
Note:
|
||||
This function schedules the shutdown process as an asyncio task and does
|
||||
not block the current execution.
|
||||
"""
|
||||
signame = signal.Signals(signum).name
|
||||
logger.info(f"Received signal {signame} ({signum}). Exiting gracefully...")
|
||||
|
||||
async def shutdown():
|
||||
for impl in app.__llama_stack_impls__.values():
|
||||
impl_name = impl.__class__.__name__
|
||||
logger.info("Shutting down %s", impl_name)
|
||||
try:
|
||||
# Gracefully shut down implementations
|
||||
for impl in app.__llama_stack_impls__.values():
|
||||
impl_name = impl.__class__.__name__
|
||||
logger.info("Shutting down %s", impl_name)
|
||||
try:
|
||||
if hasattr(impl, "shutdown"):
|
||||
await asyncio.wait_for(impl.shutdown(), timeout=5)
|
||||
else:
|
||||
logger.warning("No shutdown method for %s", impl_name)
|
||||
except asyncio.TimeoutError:
|
||||
logger.exception("Shutdown timeout for %s ", impl_name, exc_info=True)
|
||||
except Exception as e:
|
||||
logger.exception("Failed to shutdown %s: %s", impl_name, {e})
|
||||
|
||||
# Gather all running tasks
|
||||
loop = asyncio.get_running_loop()
|
||||
tasks = [task for task in asyncio.all_tasks(loop) if task is not asyncio.current_task()]
|
||||
|
||||
# Cancel all tasks
|
||||
for task in tasks:
|
||||
task.cancel()
|
||||
|
||||
# Wait for all tasks to finish
|
||||
try:
|
||||
await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), timeout=10)
|
||||
except asyncio.TimeoutError:
|
||||
logger.exception("Timeout while waiting for tasks to finish")
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
finally:
|
||||
loop.stop()
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
loop.create_task(shutdown())
|
||||
if hasattr(impl, "shutdown"):
|
||||
await asyncio.wait_for(impl.shutdown(), timeout=5)
|
||||
else:
|
||||
logger.warning("No shutdown method for %s", impl_name)
|
||||
except asyncio.TimeoutError:
|
||||
logger.exception("Shutdown timeout for %s ", impl_name, exc_info=True)
|
||||
except (Exception, asyncio.CancelledError) as e:
|
||||
logger.exception("Failed to shutdown %s: %s", impl_name, {e})
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
|
@ -185,8 +141,7 @@ async def lifespan(app: FastAPI):
|
|||
logger.info("Starting up")
|
||||
yield
|
||||
logger.info("Shutting down")
|
||||
for impl in app.__llama_stack_impls__.values():
|
||||
await impl.shutdown()
|
||||
await shutdown(app)
|
||||
|
||||
|
||||
def is_streaming_request(func_name: str, request: Request, **kwargs):
|
||||
|
@ -202,16 +157,14 @@ async def maybe_await(value):
|
|||
|
||||
async def sse_generator(event_gen):
|
||||
try:
|
||||
event_gen = await event_gen
|
||||
async for item in event_gen:
|
||||
async for item in await event_gen:
|
||||
yield create_sse_event(item)
|
||||
await asyncio.sleep(0.01)
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Generator cancelled")
|
||||
await event_gen.aclose()
|
||||
except Exception as e:
|
||||
logger.exception(f"Error in sse_generator: {e}")
|
||||
logger.exception(f"Traceback: {''.join(traceback.format_exception(type(e), e, e.__traceback__))}")
|
||||
logger.exception("Error in sse_generator")
|
||||
yield create_sse_event(
|
||||
{
|
||||
"error": {
|
||||
|
@ -223,18 +176,20 @@ async def sse_generator(event_gen):
|
|||
|
||||
def create_dynamic_typed_route(func: Any, method: str, route: str):
|
||||
async def endpoint(request: Request, **kwargs):
|
||||
set_request_provider_data(request.headers)
|
||||
# Use context manager for request provider data
|
||||
with request_provider_data_context(request.headers):
|
||||
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
|
||||
|
||||
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
|
||||
try:
|
||||
if is_streaming:
|
||||
return StreamingResponse(sse_generator(func(**kwargs)), media_type="text/event-stream")
|
||||
else:
|
||||
value = func(**kwargs)
|
||||
return await maybe_await(value)
|
||||
except Exception as e:
|
||||
traceback.print_exception(e)
|
||||
raise translate_exception(e) from e
|
||||
try:
|
||||
if is_streaming:
|
||||
gen = preserve_headers_context_async_generator(sse_generator(func(**kwargs)))
|
||||
return StreamingResponse(gen, media_type="text/event-stream")
|
||||
else:
|
||||
value = func(**kwargs)
|
||||
return await maybe_await(value)
|
||||
except Exception as e:
|
||||
logger.exception(f"Error executing endpoint {route=} {method=}")
|
||||
raise translate_exception(e) from e
|
||||
|
||||
sig = inspect.signature(func)
|
||||
|
||||
|
@ -263,7 +218,7 @@ class TracingMiddleware:
|
|||
self.app = app
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
path = scope["path"]
|
||||
path = scope.get("path", "")
|
||||
await start_trace(path, {"__location__": "server"})
|
||||
try:
|
||||
return await self.app(scope, receive, send)
|
||||
|
@ -436,8 +391,6 @@ def main():
|
|||
|
||||
app.exception_handler(RequestValidationError)(global_exception_handler)
|
||||
app.exception_handler(Exception)(global_exception_handler)
|
||||
signal.signal(signal.SIGINT, functools.partial(handle_signal, app))
|
||||
signal.signal(signal.SIGTERM, functools.partial(handle_signal, app))
|
||||
|
||||
app.__llama_stack_impls__ = impls
|
||||
|
||||
|
@ -468,6 +421,7 @@ def main():
|
|||
"app": app,
|
||||
"host": listen_host,
|
||||
"port": port,
|
||||
"lifespan": "on",
|
||||
}
|
||||
if ssl_config:
|
||||
uvicorn_config.update(ssl_config)
|
||||
|
|
|
@ -20,14 +20,14 @@ import importlib
|
|||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from llama_stack.distribution.utils.image_types import ImageType
|
||||
from llama_stack.distribution.utils.image_types import LlamaStackImageType
|
||||
|
||||
|
||||
def formulate_run_args(image_type, image_name, config, template_name) -> list:
|
||||
env_name = ""
|
||||
if image_type == ImageType.container.value or config.container_image:
|
||||
if image_type == LlamaStackImageType.CONTAINER.value or config.container_image:
|
||||
env_name = f"distribution-{template_name}" if template_name else config.container_image
|
||||
elif image_type == ImageType.conda.value:
|
||||
elif image_type == LlamaStackImageType.CONDA.value:
|
||||
current_conda_env = os.environ.get("CONDA_DEFAULT_ENV")
|
||||
env_name = image_name or current_conda_env
|
||||
if not env_name:
|
||||
|
|
@@ -4,10 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from enum import Enum
import enum


class ImageType(Enum):
    container = "container"
    conda = "conda"
    venv = "venv"
class LlamaStackImageType(enum.Enum):
    CONTAINER = "container"
    CONDA = "conda"
    VENV = "venv"
|
@ -10,7 +10,9 @@ from logging.config import dictConfig
|
|||
from typing import Dict
|
||||
|
||||
from rich.console import Console
|
||||
from rich.errors import MarkupError
|
||||
from rich.logging import RichHandler
|
||||
from termcolor import cprint
|
||||
|
||||
# Default log level
|
||||
DEFAULT_LOG_LEVEL = logging.INFO
|
||||
|
@ -82,13 +84,26 @@ class CustomRichHandler(RichHandler):
|
|||
kwargs["console"] = Console(width=120)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def emit(self, record):
|
||||
"""Override emit to handle markup errors gracefully."""
|
||||
try:
|
||||
super().emit(record)
|
||||
except MarkupError:
|
||||
original_markup = self.markup
|
||||
self.markup = False
|
||||
try:
|
||||
super().emit(record)
|
||||
finally:
|
||||
self.markup = original_markup
|
||||
|
||||
def setup_logging(category_levels: Dict[str, int]) -> None:
|
||||
|
||||
def setup_logging(category_levels: Dict[str, int], log_file: str | None) -> None:
|
||||
"""
|
||||
Configure logging based on the provided category log levels.
|
||||
Configure logging based on the provided category log levels and an optional log file.
|
||||
|
||||
Parameters:
|
||||
category_levels (Dict[str, int]): A dictionary mapping categories to their log levels.
|
||||
log_file (str): Path to a log file to additionally pipe the logs into
|
||||
"""
|
||||
log_format = "[dim]%(asctime)s %(name)s:%(lineno)d[/] [yellow dim]%(category)s[/]: %(message)s"
|
||||
|
||||
|
@ -103,6 +118,28 @@ def setup_logging(category_levels: Dict[str, int]) -> None:
|
|||
# Determine the root logger's level (default to WARNING if not specified)
|
||||
root_level = category_levels.get("root", logging.WARNING)
|
||||
|
||||
handlers = {
|
||||
"console": {
|
||||
"()": CustomRichHandler, # Use custom console handler
|
||||
"formatter": "rich",
|
||||
"rich_tracebacks": True,
|
||||
"show_time": False,
|
||||
"show_path": False,
|
||||
"markup": True,
|
||||
"filters": ["category_filter"],
|
||||
}
|
||||
}
|
||||
|
||||
# Add a file handler if log_file is set
|
||||
if log_file:
|
||||
handlers["file"] = {
|
||||
"class": "logging.FileHandler",
|
||||
"formatter": "rich",
|
||||
"filename": log_file,
|
||||
"mode": "a",
|
||||
"encoding": "utf-8",
|
||||
}
|
||||
|
||||
logging_config = {
|
||||
"version": 1,
|
||||
"disable_existing_loggers": False,
|
||||
|
@ -112,17 +149,7 @@ def setup_logging(category_levels: Dict[str, int]) -> None:
|
|||
"format": log_format,
|
||||
}
|
||||
},
|
||||
"handlers": {
|
||||
"console": {
|
||||
"()": CustomRichHandler, # Use our custom handler class
|
||||
"formatter": "rich",
|
||||
"rich_tracebacks": True,
|
||||
"show_time": False,
|
||||
"show_path": False,
|
||||
"markup": True,
|
||||
"filters": ["category_filter"],
|
||||
}
|
||||
},
|
||||
"handlers": handlers,
|
||||
"filters": {
|
||||
"category_filter": {
|
||||
"()": CategoryFilter,
|
||||
|
@ -130,14 +157,14 @@ def setup_logging(category_levels: Dict[str, int]) -> None:
|
|||
},
|
||||
"loggers": {
|
||||
category: {
|
||||
"handlers": ["console"],
|
||||
"handlers": list(handlers.keys()), # Apply all handlers
|
||||
"level": category_levels.get(category, DEFAULT_LOG_LEVEL),
|
||||
"propagate": False, # Disable propagation to root logger
|
||||
}
|
||||
for category in CATEGORIES
|
||||
},
|
||||
"root": {
|
||||
"handlers": ["console"],
|
||||
"handlers": list(handlers.keys()),
|
||||
"level": root_level, # Set root logger's level dynamically
|
||||
},
|
||||
}
|
||||
|
@ -163,7 +190,9 @@ def get_logger(name: str, category: str = "uncategorized") -> logging.LoggerAdap
|
|||
|
||||
env_config = os.environ.get("LLAMA_STACK_LOGGING", "")
|
||||
if env_config:
|
||||
print(f"Environment variable LLAMA_STACK_LOGGING found: {env_config}")
|
||||
cprint(f"Environment variable LLAMA_STACK_LOGGING found: {env_config}", "yellow")
|
||||
_category_levels.update(parse_environment_config(env_config))
|
||||
|
||||
setup_logging(_category_levels)
|
||||
log_file = os.environ.get("LLAMA_STACK_LOG_FILE")
|
||||
|
||||
setup_logging(_category_levels, log_file)
|
||||
|
|
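To see the new file handler end to end, here is a small sketch (not from this commit): `LLAMA_STACK_LOG_FILE` mirrors logs to a file in addition to the console, and `LLAMA_STACK_LOGGING` tunes per-category levels. The `category=level` syntax and the `llama_stack.log` module path are assumptions based on the surrounding code and docs.

```python
# Hypothetical usage; the environment must be set before the logging module is
# imported, since it reads these variables at import time.
import os

os.environ["LLAMA_STACK_LOG_FILE"] = "server.log"  # also write logs to server.log
os.environ["LLAMA_STACK_LOGGING"] = "all=debug"     # assumed "category=level" syntax

from llama_stack.log import get_logger  # assumed import path for the module above

logger = get_logger(__name__, category="server")
logger.info("this message goes to the console and to server.log")
```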
|
@ -4,14 +4,14 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Dict
|
||||
from typing import Any, Dict
|
||||
|
||||
from llama_stack.distribution.datatypes import Api, ProviderSpec
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
|
||||
from .config import MetaReferenceAgentsImplConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: Dict[Api, ProviderSpec]):
|
||||
async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: Dict[Api, Any]):
|
||||
from .agents import MetaReferenceAgentsImpl
|
||||
|
||||
impl = MetaReferenceAgentsImpl(
|
||||
|
|
|
@ -16,7 +16,6 @@ from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
|
|||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
from rich.markup import escape
|
||||
|
||||
from llama_stack.apis.agents import (
|
||||
AgentConfig,
|
||||
|
@ -182,7 +181,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
return messages
|
||||
|
||||
async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
|
||||
with tracing.span("create_and_execute_turn") as span:
|
||||
async with tracing.span("create_and_execute_turn") as span:
|
||||
span.set_attribute("session_id", request.session_id)
|
||||
span.set_attribute("agent_id", self.agent_id)
|
||||
span.set_attribute("request", request.model_dump_json())
|
||||
|
@ -192,7 +191,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
yield chunk
|
||||
|
||||
async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator:
|
||||
with tracing.span("resume_turn") as span:
|
||||
async with tracing.span("resume_turn") as span:
|
||||
span.set_attribute("agent_id", self.agent_id)
|
||||
span.set_attribute("session_id", request.session_id)
|
||||
span.set_attribute("turn_id", request.turn_id)
|
||||
|
@ -391,7 +390,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
shields: List[str],
|
||||
touchpoint: str,
|
||||
) -> AsyncGenerator:
|
||||
with tracing.span("run_shields") as span:
|
||||
async with tracing.span("run_shields") as span:
|
||||
span.set_attribute("input", [m.model_dump_json() for m in messages])
|
||||
if len(shields) == 0:
|
||||
span.set_attribute("output", "no shields")
|
||||
|
@ -509,7 +508,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
content = ""
|
||||
stop_reason = None
|
||||
|
||||
with tracing.span("inference") as span:
|
||||
async with tracing.span("inference") as span:
|
||||
async for chunk in await self.inference_api.chat_completion(
|
||||
self.agent_config.model,
|
||||
input_messages,
|
||||
|
@ -686,7 +685,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
tool_name = tool_call.tool_name
|
||||
if isinstance(tool_name, BuiltinTool):
|
||||
tool_name = tool_name.value
|
||||
with tracing.span(
|
||||
async with tracing.span(
|
||||
"tool_execution",
|
||||
{
|
||||
"tool_name": tool_name,
|
||||
|
@ -1030,7 +1029,7 @@ async def execute_tool_call_maybe(
|
|||
**toolgroup_args.get(group_name, {}),
|
||||
},
|
||||
)
|
||||
logger.info(f"tool call {name} completed with result: {escape(str(result))}")
|
||||
logger.info(f"tool call {name} completed with result: {result}")
|
||||
return result
|
||||
|
||||
|
||||
|
|
|
@ -12,6 +12,7 @@ import uuid
|
|||
from typing import AsyncGenerator, List, Optional, Union
|
||||
|
||||
from llama_stack.apis.agents import (
|
||||
Agent,
|
||||
AgentConfig,
|
||||
AgentCreateResponse,
|
||||
Agents,
|
||||
|
@ -21,6 +22,8 @@ from llama_stack.apis.agents import (
|
|||
AgentTurnCreateRequest,
|
||||
AgentTurnResumeRequest,
|
||||
Document,
|
||||
ListAgentSessionsResponse,
|
||||
ListAgentsResponse,
|
||||
Session,
|
||||
Turn,
|
||||
)
|
||||
|
@ -84,7 +87,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
agent_id=agent_id,
|
||||
)
|
||||
|
||||
async def get_agent(self, agent_id: str) -> ChatAgent:
|
||||
async def _get_agent_impl(self, agent_id: str) -> ChatAgent:
|
||||
agent_config = await self.persistence_store.get(
|
||||
key=f"agent:{agent_id}",
|
||||
)
|
||||
|
@ -120,7 +123,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
agent_id: str,
|
||||
session_name: str,
|
||||
) -> AgentSessionCreateResponse:
|
||||
agent = await self.get_agent(agent_id)
|
||||
agent = await self._get_agent_impl(agent_id)
|
||||
|
||||
session_id = await agent.create_session(session_name)
|
||||
return AgentSessionCreateResponse(
|
||||
|
@ -160,7 +163,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
self,
|
||||
request: AgentTurnCreateRequest,
|
||||
) -> AsyncGenerator:
|
||||
agent = await self.get_agent(request.agent_id)
|
||||
agent = await self._get_agent_impl(request.agent_id)
|
||||
async for event in agent.create_and_execute_turn(request):
|
||||
yield event
|
||||
|
||||
|
@ -188,12 +191,12 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
self,
|
||||
request: AgentTurnResumeRequest,
|
||||
) -> AsyncGenerator:
|
||||
agent = await self.get_agent(request.agent_id)
|
||||
agent = await self._get_agent_impl(request.agent_id)
|
||||
async for event in agent.resume_turn(request):
|
||||
yield event
|
||||
|
||||
async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn:
|
||||
agent = await self.get_agent(agent_id)
|
||||
agent = await self._get_agent_impl(agent_id)
|
||||
turn = await agent.storage.get_session_turn(session_id, turn_id)
|
||||
return turn
|
||||
|
||||
|
@ -210,7 +213,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
session_id: str,
|
||||
turn_ids: Optional[List[str]] = None,
|
||||
) -> Session:
|
||||
agent = await self.get_agent(agent_id)
|
||||
agent = await self._get_agent_impl(agent_id)
|
||||
session_info = await agent.storage.get_session_info(session_id)
|
||||
if session_info is None:
|
||||
raise ValueError(f"Session {session_id} not found")
|
||||
|
@ -232,3 +235,15 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def list_agents(self) -> ListAgentsResponse:
|
||||
pass
|
||||
|
||||
async def get_agent(self, agent_id: str) -> Agent:
|
||||
pass
|
||||
|
||||
async def list_agent_sessions(
|
||||
self,
|
||||
agent_id: str,
|
||||
) -> ListAgentSessionsResponse:
|
||||
pass
|
||||
|
|
|
@ -10,6 +10,7 @@ from typing import List
|
|||
|
||||
from llama_stack.apis.inference import Message
|
||||
from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
|
||||
from llama_stack.providers.utils.telemetry import tracing
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
@ -32,15 +33,14 @@ class ShieldRunnerMixin:
|
|||
self.output_shields = output_shields
|
||||
|
||||
async def run_multiple_shields(self, messages: List[Message], identifiers: List[str]) -> None:
|
||||
responses = await asyncio.gather(
|
||||
*[
|
||||
self.safety_api.run_shield(
|
||||
async def run_shield_with_span(identifier: str):
|
||||
async with tracing.span(f"run_shield_{identifier}"):
|
||||
return await self.safety_api.run_shield(
|
||||
shield_id=identifier,
|
||||
messages=messages,
|
||||
)
|
||||
for identifier in identifiers
|
||||
]
|
||||
)
|
||||
|
||||
responses = await asyncio.gather(*[run_shield_with_span(identifier) for identifier in identifiers])
|
||||
for identifier, response in zip(identifiers, responses, strict=False):
|
||||
if not response.violation:
|
||||
continue
|
||||
|
|
|
@ -4,12 +4,14 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from .config import LocalFSDatasetIOConfig
|
||||
|
||||
|
||||
async def get_provider_impl(
|
||||
config: LocalFSDatasetIOConfig,
|
||||
_deps,
|
||||
_deps: Dict[str, Any],
|
||||
):
|
||||
from .datasetio import LocalFSDatasetIOImpl
|
||||
|
||||
|
|
|
@ -172,7 +172,7 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
|
|||
new_rows_df = dataset_impl._validate_dataset_schema(new_rows_df)
|
||||
dataset_impl.df = pandas.concat([dataset_impl.df, new_rows_df], ignore_index=True)
|
||||
|
||||
url = str(dataset_info.dataset_def.url)
|
||||
url = str(dataset_info.dataset_def.url.uri)
|
||||
parsed_url = urlparse(url)
|
||||
|
||||
if parsed_url.scheme == "file" or not parsed_url.scheme:
|
||||
|
|
|
@ -3,16 +3,16 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Dict
|
||||
from typing import Any, Dict
|
||||
|
||||
from llama_stack.distribution.datatypes import Api, ProviderSpec
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
|
||||
from .config import MetaReferenceEvalConfig
|
||||
|
||||
|
||||
async def get_provider_impl(
|
||||
config: MetaReferenceEvalConfig,
|
||||
deps: Dict[Api, ProviderSpec],
|
||||
deps: Dict[Api, Any],
|
||||
):
|
||||
from .eval import MetaReferenceEvalImpl
|
||||
|
||||
|
|
|
@ -4,14 +4,14 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Union
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
|
||||
|
||||
|
||||
async def get_provider_impl(
|
||||
config: Union[MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig],
|
||||
_deps,
|
||||
_deps: Dict[str, Any],
|
||||
):
|
||||
from .inference import MetaReferenceInferenceImpl
|
||||
|
||||
|
|
|
@ -4,6 +4,8 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from llama_stack.providers.inline.inference.sentence_transformers.config import (
|
||||
SentenceTransformersInferenceConfig,
|
||||
)
|
||||
|
@ -11,7 +13,7 @@ from llama_stack.providers.inline.inference.sentence_transformers.config import
|
|||
|
||||
async def get_provider_impl(
|
||||
config: SentenceTransformersInferenceConfig,
|
||||
_deps,
|
||||
_deps: Dict[str, Any],
|
||||
):
|
||||
from .sentence_transformers import SentenceTransformersInferenceImpl
|
||||
|
||||
|
|
|
@ -4,12 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
from typing import Any, Dict
|
||||
|
||||
from .config import VLLMConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: VLLMConfig, _deps) -> Any:
|
||||
async def get_provider_impl(config: VLLMConfig, _deps: Dict[str, Any]):
|
||||
from .vllm import VLLMInferenceImpl
|
||||
|
||||
impl = VLLMInferenceImpl(config)
|
||||
|
|
|
@ -4,9 +4,9 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Dict
|
||||
from typing import Any, Dict
|
||||
|
||||
from llama_stack.distribution.datatypes import Api, ProviderSpec
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
|
||||
from .config import TorchtunePostTrainingConfig
|
||||
|
||||
|
@ -15,7 +15,7 @@ from .config import TorchtunePostTrainingConfig
|
|||
|
||||
async def get_provider_impl(
|
||||
config: TorchtunePostTrainingConfig,
|
||||
deps: Dict[Api, ProviderSpec],
|
||||
deps: Dict[Api, Any],
|
||||
):
|
||||
from .post_training import TorchtunePostTrainingImpl
|
||||
|
||||
|
|
|
@ -43,6 +43,9 @@ class TorchtunePostTrainingImpl:
|
|||
self.jobs = {}
|
||||
self.checkpoints_dict = {}
|
||||
|
||||
async def shutdown(self):
|
||||
pass
|
||||
|
||||
async def supervised_fine_tune(
|
||||
self,
|
||||
job_uuid: str,
|
||||
|
|
|
@ -4,10 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from .config import CodeScannerConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: CodeScannerConfig, deps):
|
||||
async def get_provider_impl(config: CodeScannerConfig, deps: Dict[str, Any]):
|
||||
from .code_scanner import MetaReferenceCodeScannerSafetyImpl
|
||||
|
||||
impl = MetaReferenceCodeScannerSafetyImpl(config, deps)
|
||||
|
|
|
@ -4,10 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from .config import LlamaGuardConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: LlamaGuardConfig, deps):
|
||||
async def get_provider_impl(config: LlamaGuardConfig, deps: Dict[str, Any]):
|
||||
from .llama_guard import LlamaGuardSafetyImpl
|
||||
|
||||
assert isinstance(config, LlamaGuardConfig), f"Unexpected config type: {type(config)}"
|
||||
|
|
|
@ -4,10 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from .config import PromptGuardConfig # noqa: F401
|
||||
|
||||
|
||||
async def get_provider_impl(config: PromptGuardConfig, deps):
|
||||
async def get_provider_impl(config: PromptGuardConfig, deps: Dict[str, Any]):
|
||||
from .prompt_guard import PromptGuardSafetyImpl
|
||||
|
||||
impl = PromptGuardSafetyImpl(config, deps)
|
||||
|
|
|
@ -3,16 +3,16 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Dict
|
||||
from typing import Any, Dict
|
||||
|
||||
from llama_stack.distribution.datatypes import Api, ProviderSpec
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
|
||||
from .config import BasicScoringConfig
|
||||
|
||||
|
||||
async def get_provider_impl(
|
||||
config: BasicScoringConfig,
|
||||
deps: Dict[Api, ProviderSpec],
|
||||
deps: Dict[Api, Any],
|
||||
):
|
||||
from .scoring import BasicScoringImpl
|
||||
|
||||
|
|
|
@ -23,10 +23,11 @@ from llama_stack.providers.utils.common.data_schema_validator import (
|
|||
|
||||
from .config import BasicScoringConfig
|
||||
from .scoring_fn.equality_scoring_fn import EqualityScoringFn
|
||||
from .scoring_fn.regex_parser_math_response_scoring_fn import RegexParserMathResponseScoringFn
|
||||
from .scoring_fn.regex_parser_scoring_fn import RegexParserScoringFn
|
||||
from .scoring_fn.subset_of_scoring_fn import SubsetOfScoringFn
|
||||
|
||||
FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn, RegexParserScoringFn]
|
||||
FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn, RegexParserScoringFn, RegexParserMathResponseScoringFn]
|
||||
|
||||
|
||||
class BasicScoringImpl(
|
||||
|
|
|
@ -0,0 +1,27 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.common.type_system import NumberType
|
||||
from llama_stack.apis.scoring_functions import (
|
||||
AggregationFunctionType,
|
||||
RegexParserScoringFnParams,
|
||||
ScoringFn,
|
||||
)
|
||||
|
||||
MATH_ANSWER_REGEXES = [r".*final answer is:?\s*\$\\boxed{(?P<X>.*)}\$"]
|
||||
|
||||
|
||||
regex_parser_math_response = ScoringFn(
|
||||
identifier="basic::regex_parser_math_response",
|
||||
description="For math related benchmarks, extract answer from the generated response and expected_answer and see if they match",
|
||||
return_type=NumberType(),
|
||||
provider_id="basic",
|
||||
provider_resource_id="regex-parser-math-response",
|
||||
params=RegexParserScoringFnParams(
|
||||
parsing_regexes=MATH_ANSWER_REGEXES,
|
||||
aggregation_functions=[AggregationFunctionType.accuracy],
|
||||
),
|
||||
)
|
|
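The parsing regex above expects completions that end with a sentence of the form `The final answer is: $\boxed{...}$`. A quick, hypothetical check of what the pattern captures (the sample completion string is invented):

```python
import re

MATH_ANSWER_REGEX = r".*final answer is:?\s*\$\\boxed{(?P<X>.*)}\$"

completion = "We simplify both sides. The final answer is: $\\boxed{\\frac{1}{2}}$"
print(re.findall(MATH_ANSWER_REGEX, completion))
# ['\\frac{1}{2}'] -- the raw LaTeX still has to be normalized before comparison
```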
@ -0,0 +1,66 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from llama_stack.apis.scoring import ScoringResultRow
|
||||
from llama_stack.apis.scoring_functions import ScoringFnParams, ScoringFnParamsType
|
||||
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
|
||||
|
||||
from ..utils.math_utils import first_answer, normalize_final_answer, try_evaluate_frac, try_evaluate_latex
|
||||
from .fn_defs.regex_parser_math_response import (
|
||||
regex_parser_math_response,
|
||||
)
|
||||
|
||||
|
||||
class RegexParserMathResponseScoringFn(RegisteredBaseScoringFn):
|
||||
"""
|
||||
A scoring_fn for math benchmarks that parses the answer from the generated response and checks whether it matches the expected_answer.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.supported_fn_defs_registry = {
|
||||
regex_parser_math_response.identifier: regex_parser_math_response,
|
||||
}
|
||||
|
||||
async def score_row(
|
||||
self,
|
||||
input_row: Dict[str, Any],
|
||||
scoring_fn_identifier: Optional[str] = None,
|
||||
scoring_params: Optional[ScoringFnParams] = None,
|
||||
) -> ScoringResultRow:
|
||||
assert scoring_fn_identifier is not None, "Scoring function identifier not found."
|
||||
fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
|
||||
if scoring_params is not None:
|
||||
fn_def.params = scoring_params
|
||||
|
||||
assert fn_def.params is not None and fn_def.params.type == ScoringFnParamsType.regex_parser.value, (
|
||||
f"RegexParserScoringFnParams not found for {fn_def}."
|
||||
)
|
||||
|
||||
expected_answer = input_row["expected_answer"]
|
||||
generated_answer = input_row["generated_answer"]
|
||||
|
||||
parsing_regexes = fn_def.params.parsing_regexes
|
||||
assert len(parsing_regexes) == 1, (
|
||||
"Only one parsing regex is supported for regex_parser_math_response scoring function."
|
||||
)
|
||||
parsing_regexes = fn_def.params.parsing_regexes[0]
|
||||
|
||||
normalized_generated_answer = normalize_final_answer(
|
||||
first_answer(generated_answer),
|
||||
parsing_regexes,
|
||||
match_first=True,
|
||||
)
|
||||
normalized_generated_answer = try_evaluate_frac(try_evaluate_latex(normalized_generated_answer))
|
||||
|
||||
normalized_expected_answer = normalize_final_answer(expected_answer, r".*")
|
||||
normalized_expected_answer = try_evaluate_frac(try_evaluate_latex(normalized_expected_answer))
|
||||
|
||||
score = 1.0 if normalized_generated_answer == normalized_expected_answer else 0.0
|
||||
return {
|
||||
"score": score,
|
||||
}
|
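A rough sketch of how `score_row` would be exercised for a single row: the generated and expected answers are normalized independently and then compared exactly. The row contents below are invented, the import path is inferred from the file layout in this diff, and the constructor is assumed to take no required arguments.

```python
import asyncio

from llama_stack.providers.inline.scoring.basic.scoring_fn.regex_parser_math_response_scoring_fn import (
    RegexParserMathResponseScoringFn,  # path inferred from this diff
)

row = {
    "input_query": "What is 6/4 in lowest terms?",
    "generated_answer": "Reducing the fraction, the final answer is: $\\boxed{\\frac{3}{2}}$",
    "expected_answer": "\\frac{3}{2}",
}


async def main():
    fn = RegexParserMathResponseScoringFn()
    result = await fn.score_row(row, scoring_fn_identifier="basic::regex_parser_math_response")
    print(result)  # both sides normalize to the same value, so this should be {"score": 1.0}


# asyncio.run(main())  # requires llama_stack to be installed
```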
330
llama_stack/providers/inline/scoring/basic/utils/math_utils.py
Normal file
|
@ -0,0 +1,330 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import re
|
||||
from typing import Sequence
|
||||
|
||||
from llama_stack.providers.utils.scoring.basic_scoring_utils import time_limit
|
||||
|
||||
# from minerva
|
||||
SUBSTITUTIONS = [
|
||||
("an ", ""),
|
||||
("a ", ""),
|
||||
(".$", "$"),
|
||||
("\\$", ""),
|
||||
(r"\ ", ""),
|
||||
(" ", ""),
|
||||
("mbox", "text"),
|
||||
(",\\text{and}", ","),
|
||||
("\\text{and}", ","),
|
||||
("\\text{m}", "\\text{}"),
|
||||
]
|
||||
|
||||
REMOVED_EXPRESSIONS = [
|
||||
"square",
|
||||
"ways",
|
||||
"integers",
|
||||
"dollars",
|
||||
"mph",
|
||||
"inches",
|
||||
"ft",
|
||||
"hours",
|
||||
"km",
|
||||
"units",
|
||||
"\\ldots",
|
||||
"sue",
|
||||
"points",
|
||||
"feet",
|
||||
"minutes",
|
||||
"digits",
|
||||
"cents",
|
||||
"degrees",
|
||||
"cm",
|
||||
"gm",
|
||||
"pounds",
|
||||
"meters",
|
||||
"meals",
|
||||
"edges",
|
||||
"students",
|
||||
"childrentickets",
|
||||
"multiples",
|
||||
"\\text{s}",
|
||||
"\\text{.}",
|
||||
"\\text{\ns}",
|
||||
"\\text{}^2",
|
||||
"\\text{}^3",
|
||||
"\\text{\n}",
|
||||
"\\text{}",
|
||||
r"\mathrm{th}",
|
||||
r"^\circ",
|
||||
r"^{\circ}",
|
||||
r"\;",
|
||||
r",\!",
|
||||
"{,}",
|
||||
'"',
|
||||
"\\dots",
|
||||
]
|
||||
|
||||
|
||||
def try_evaluate_frac(expression: str, fmt: str = "0.2e") -> str:
|
||||
if isinstance(expression, float):
|
||||
return expression
|
||||
new_expression = f"{expression}"
|
||||
regex = re.compile(r"\\frac{([^}]+)}{([^}]+)}")
|
||||
for match in re.finditer(regex, expression):
|
||||
try:
|
||||
value = float(match.group(1)) / float(match.group(2))
|
||||
new_expression = new_expression.replace(
|
||||
match.group(),
|
||||
f"{{value:{fmt}}}".format(value=value),
|
||||
1,
|
||||
)
|
||||
except Exception:
|
||||
continue
|
||||
return new_expression
|
||||
|
||||
|
||||
def try_evaluate_latex(expression: str, fmt: str = ".2e") -> str:
|
||||
try:
|
||||
with time_limit(seconds=5):
|
||||
from sympy.parsing.latex import parse_latex
|
||||
|
||||
value = parse_latex(expression).evalf() # type: ignore
|
||||
return f"{{value:{fmt}}}".format(value=value)
|
||||
except Exception:
|
||||
return expression
|
||||
|
||||
|
||||
def first_answer(text: str, markers: Sequence[str] = ("Q:", "A:")) -> str:
|
||||
for marker in markers:
|
||||
text = text.split(marker)[0]
|
||||
return text
|
||||
|
||||
|
||||
def extract_result_from_boxed(answer: str) -> str:
|
||||
box_start = "\\boxed"
|
||||
# format is `\\boxed <value>$` or `\\boxed{<value>}`, with potential white spaces framing `<value>`
|
||||
start = answer.rfind(box_start)
|
||||
if start < 0:
|
||||
return ""
|
||||
answer = answer[start + len(box_start) :].strip()
|
||||
ends_with_curly = answer.startswith("{")
|
||||
i = 0
|
||||
open_braces = 0
|
||||
while i < len(answer):
|
||||
if answer[i] == "{":
|
||||
open_braces += 1
|
||||
elif answer[i] == "}":
|
||||
open_braces -= 1
|
||||
if open_braces == 0:
|
||||
if ends_with_curly:
|
||||
answer = answer[: i + 1].strip()
|
||||
break
|
||||
elif answer[i] == "$":
|
||||
answer = answer[:i].strip()
|
||||
break
|
||||
i += 1
|
||||
else:
|
||||
return ""
|
||||
# remove extra curly braces
|
||||
while True:
|
||||
if answer.startswith("{") and answer.endswith("}"):
|
||||
answer = answer[1:-1].strip()
|
||||
else:
|
||||
break
|
||||
return answer
|
||||
|
||||
|
||||
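To make the brace matching in `extract_result_from_boxed` concrete, a few illustrative calls (sample strings invented; import path taken from the file header above):

```python
from llama_stack.providers.inline.scoring.basic.utils.math_utils import extract_result_from_boxed

print(extract_result_from_boxed(r"The final answer is \boxed{\frac{3}{2}}"))  # -> \frac{3}{2}
print(extract_result_from_boxed(r"so we get $\boxed{42}$ as required"))       # -> 42
print(extract_result_from_boxed("no boxed answer here"))                      # -> "" (empty string)
```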
# from minerva paper + _normalise_result from xavierm
|
||||
def normalize_final_answer(final_answer: str, regex_pattern: str, match_first: bool = True) -> str:
|
||||
"""Extract and normalize a final answer to a quantitative reasoning question."""
|
||||
match = re.findall(regex_pattern, final_answer)
|
||||
extraction: str
|
||||
if len(match) > 0:
|
||||
if match_first:
|
||||
extraction = match[0]
|
||||
else:
|
||||
extraction = match[-1]
|
||||
else:
|
||||
extraction = extract_result_from_boxed(final_answer)
|
||||
|
||||
if len(extraction) == 0:
|
||||
return final_answer
|
||||
else:
|
||||
final_answer = extraction
|
||||
final_answer = final_answer.split("=")[-1]
|
||||
for before, after in SUBSTITUTIONS:
|
||||
final_answer = final_answer.replace(before, after)
|
||||
for expr in REMOVED_EXPRESSIONS:
|
||||
final_answer = final_answer.replace(expr, "")
|
||||
# Extract answer that is in LaTeX math, is bold,
|
||||
# is surrounded by a box, etc.
|
||||
final_answer = re.sub(r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$", final_answer)
|
||||
final_answer = re.sub(r"(\\text\{)(.*?)(\})", "\\2", final_answer)
|
||||
final_answer = re.sub(r"(\\textbf\{)(.*?)(\})", "\\2", final_answer)
|
||||
final_answer = re.sub(r"(\\overline\{)(.*?)(\})", "\\2", final_answer)
|
||||
final_answer = re.sub(r"(\\boxed\{)(.*)(\})", "\\2", final_answer)
|
||||
# Normalize shorthand TeX:
|
||||
# \fracab -> \frac{a}{b}
|
||||
# \frac{abc}{bef} -> \frac{abc}{bef}
|
||||
# \fracabc -> \frac{a}{b}c
|
||||
# \sqrta -> \sqrt{a}
|
||||
# \sqrtab -> sqrt{a}b
|
||||
final_answer = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", final_answer)
|
||||
final_answer = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", final_answer)
|
||||
final_answer = final_answer.replace("$", "")
|
||||
# Normalize 100,000 -> 100000
|
||||
if final_answer.replace(",", "").isdigit():
|
||||
final_answer = final_answer.replace(",", "")
|
||||
# If the final answer is a single letter in parentheses, remove the parentheses
|
||||
# Example: (a) -> a (but not (ab) -> ab)
|
||||
if re.match(r"\([a-zA-Z]\)", final_answer):
|
||||
final_answer = final_answer[1]
|
||||
return _normalise_result(final_answer)
|
||||
|
||||
|
||||
def _normalise_result(string: str) -> str:
|
||||
# linebreaks
|
||||
string = string.replace("\n", "")
|
||||
|
||||
# remove inverse spaces
|
||||
string = string.replace("\\!", "")
|
||||
|
||||
# replace \\ with \
|
||||
string = string.replace("\\\\", "\\")
|
||||
|
||||
# replace tfrac and dfrac with frac
|
||||
string = string.replace("cfrac", "frac")
|
||||
string = string.replace("tfrac", "frac")
|
||||
string = string.replace("dfrac", "frac")
|
||||
|
||||
# remove \left and \right
|
||||
string = string.replace("\\left", "")
|
||||
string = string.replace("\\le", "")
|
||||
string = string.replace("\\right", "")
|
||||
|
||||
# Remove circ (degrees)
|
||||
string = string.replace("^{\\circ}", "")
|
||||
string = string.replace("^\\circ", "")
|
||||
|
||||
# remove dollar signs
|
||||
string = string.replace("\\$", "")
|
||||
|
||||
# remove units (on the right)
|
||||
string = _remove_right_units(string)
|
||||
|
||||
# remove percentage
|
||||
string = string.replace("\\%", "")
|
||||
string = string.replace(r"\%", "")
|
||||
|
||||
# " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
|
||||
string = string.replace(" .", " 0.")
|
||||
string = string.replace("{.", "{0.")
|
||||
# if empty, return empty string
|
||||
if len(string) == 0:
|
||||
return string
|
||||
if string[0] == ".":
|
||||
string = "0" + string
|
||||
|
||||
# to consider: get rid of e.g. "k = " or "q = " at beginning
|
||||
string = string.split("=")[-1]
|
||||
|
||||
# fix sqrt3 --> sqrt{3}
|
||||
string = _fix_sqrt(string)
|
||||
|
||||
# remove spaces
|
||||
string = string.replace(" ", "")
|
||||
|
||||
# \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
|
||||
string = _fix_fracs(string)
|
||||
|
||||
# manually change 0.5 --> \frac{1}{2}
|
||||
if string == "0.5":
|
||||
string = "\\frac{1}{2}"
|
||||
|
||||
# NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
|
||||
string = _fix_a_slash_b(string)
|
||||
|
||||
return string
|
||||
|
||||
|
||||
def _remove_right_units(string: str) -> str:
|
||||
# "\\text{ " only ever occurs (at least in the val set) when describing units
|
||||
try:
|
||||
if "\\text{ " in string:
|
||||
splits = string.split("\\text{ ")
|
||||
assert len(splits) == 2
|
||||
return splits[0]
|
||||
else:
|
||||
return string
|
||||
except AssertionError:
|
||||
return string
|
||||
|
||||
|
||||
def _fix_sqrt(string: str) -> str:
|
||||
if "\\sqrt" not in string:
|
||||
return string
|
||||
splits = string.split("\\sqrt")
|
||||
new_string = splits[0]
|
||||
for split in splits[1:]:
|
||||
if len(split) == 0:
|
||||
return string
|
||||
if split[0] != "{":
|
||||
a = split[0]
|
||||
new_substr = "\\sqrt{" + a + "}" + split[1:]
|
||||
else:
|
||||
new_substr = "\\sqrt" + split
|
||||
new_string += new_substr
|
||||
return new_string
|
||||
|
||||
|
||||
def _fix_fracs(string: str) -> str:
|
||||
substrs = string.split("\\frac")
|
||||
new_str = substrs[0]
|
||||
if len(substrs) > 1:
|
||||
substrs = substrs[1:]
|
||||
for substr in substrs:
|
||||
new_str += "\\frac"
|
||||
if len(substr) == 0:
|
||||
return string
|
||||
if substr[0] == "{":
|
||||
new_str += substr
|
||||
else:
|
||||
try:
|
||||
assert len(substr) >= 2
|
||||
except AssertionError:
|
||||
return string
|
||||
a = substr[0]
|
||||
b = substr[1]
|
||||
if b != "{":
|
||||
if len(substr) > 2:
|
||||
post_substr = substr[2:]
|
||||
new_str += "{" + a + "}{" + b + "}" + post_substr
|
||||
else:
|
||||
new_str += "{" + a + "}{" + b + "}"
|
||||
else:
|
||||
if len(substr) > 2:
|
||||
post_substr = substr[2:]
|
||||
new_str += "{" + a + "}" + b + post_substr
|
||||
else:
|
||||
new_str += "{" + a + "}" + b
|
||||
string = new_str
|
||||
return string
|
||||
|
||||
|
||||
def _fix_a_slash_b(string: str) -> str:
|
||||
if len(string.split("/")) != 2:
|
||||
return string
|
||||
a = string.split("/")[0]
|
||||
b = string.split("/")[1]
|
||||
try:
|
||||
ia = int(a)
|
||||
ib = int(b)
|
||||
assert string == "{}/{}".format(ia, ib)
|
||||
new_string = "\\frac{" + str(ia) + "}{" + str(ib) + "}"
|
||||
return new_string
|
||||
except (ValueError, AssertionError):
|
||||
return string
|
|
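End to end, these helpers turn free-form model output into a canonical string before the exact-match comparison. A rough sketch of the normalization path the scoring function runs (module path taken from the file header above; the sample completion is invented):

```python
from llama_stack.providers.inline.scoring.basic.utils.math_utils import (
    first_answer,
    normalize_final_answer,
    try_evaluate_frac,
    try_evaluate_latex,
)

regex = r".*final answer is:?\s*\$\\boxed{(?P<X>.*)}\$"
generated = "Adding the halves, the final answer is: $\\boxed{\\dfrac{1}{2}}$"

normalized = normalize_final_answer(first_answer(generated), regex, match_first=True)
normalized = try_evaluate_frac(try_evaluate_latex(normalized))
print(normalized)  # \dfrac collapses to \frac, which then evaluates to a plain float string such as 5.00e-01
```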
@ -3,11 +3,11 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Dict
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.distribution.datatypes import Api, ProviderSpec
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
|
||||
from .config import BraintrustScoringConfig
|
||||
|
||||
|
@ -18,7 +18,7 @@ class BraintrustProviderDataValidator(BaseModel):
|
|||
|
||||
async def get_provider_impl(
|
||||
config: BraintrustScoringConfig,
|
||||
deps: Dict[Api, ProviderSpec],
|
||||
deps: Dict[Api, Any],
|
||||
):
|
||||
from .braintrust import BraintrustScoringImpl
|
||||
|
||||
|
|
|
@ -3,16 +3,16 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Dict
|
||||
from typing import Any, Dict
|
||||
|
||||
from llama_stack.distribution.datatypes import Api, ProviderSpec
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
|
||||
from .config import LlmAsJudgeScoringConfig
|
||||
|
||||
|
||||
async def get_provider_impl(
|
||||
config: LlmAsJudgeScoringConfig,
|
||||
deps: Dict[Api, ProviderSpec],
|
||||
deps: Dict[Api, Any],
|
||||
):
|
||||
from .scoring import LlmAsJudgeScoringImpl
|
||||
|
||||
|
|
|
@ -4,12 +4,14 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from .config import CodeInterpreterToolConfig
|
||||
|
||||
__all__ = ["CodeInterpreterToolConfig", "CodeInterpreterToolRuntimeImpl"]
|
||||
|
||||
|
||||
async def get_provider_impl(config: CodeInterpreterToolConfig, _deps):
|
||||
async def get_provider_impl(config: CodeInterpreterToolConfig, _deps: Dict[str, Any]):
|
||||
from .code_interpreter import CodeInterpreterToolRuntimeImpl
|
||||
|
||||
impl = CodeInterpreterToolRuntimeImpl(config)
|
||||
|
|
|
@ -4,14 +4,14 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Dict
|
||||
from typing import Any, Dict
|
||||
|
||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
from .config import ChromaVectorIOConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: ChromaVectorIOConfig, deps: Dict[Api, ProviderSpec]):
|
||||
async def get_provider_impl(config: ChromaVectorIOConfig, deps: Dict[Api, Any]):
|
||||
from llama_stack.providers.remote.vector_io.chroma.chroma import (
|
||||
ChromaVectorIOAdapter,
|
||||
)
|
||||
|
|
|
@ -4,14 +4,14 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Dict
|
||||
from typing import Any, Dict
|
||||
|
||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
from .config import FaissVectorIOConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: FaissVectorIOConfig, deps: Dict[Api, ProviderSpec]):
|
||||
async def get_provider_impl(config: FaissVectorIOConfig, deps: Dict[Api, Any]):
|
||||
from .faiss import FaissVectorIOAdapter
|
||||
|
||||
assert isinstance(config, FaissVectorIOConfig), f"Unexpected config type: {type(config)}"
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
|
@ -99,7 +100,7 @@ class FaissIndex(EmbeddingIndex):
|
|||
await self._save_index()
|
||||
|
||||
async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
distances, indices = self.index.search(embedding.reshape(1, -1).astype(np.float32), k)
|
||||
distances, indices = await asyncio.to_thread(self.index.search, embedding.reshape(1, -1).astype(np.float32), k)
|
||||
|
||||
chunks = []
|
||||
scores = []
|
||||
|
|
|
@ -4,14 +4,14 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Dict
|
||||
from typing import Any, Dict
|
||||
|
||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
from .config import MilvusVectorIOConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: MilvusVectorIOConfig, deps: Dict[Api, ProviderSpec]):
|
||||
async def get_provider_impl(config: MilvusVectorIOConfig, deps: Dict[Api, Any]):
|
||||
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusVectorIOAdapter
|
||||
|
||||
impl = MilvusVectorIOAdapter(config, deps[Api.inference])
|
||||
|
|
|
@ -4,14 +4,14 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Dict
|
||||
from typing import Any, Dict
|
||||
|
||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
from .config import SQLiteVectorIOConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: SQLiteVectorIOConfig, deps: Dict[Api, ProviderSpec]):
|
||||
async def get_provider_impl(config: SQLiteVectorIOConfig, deps: Dict[Api, Any]):
|
||||
from .sqlite_vec import SQLiteVecVectorIOAdapter
|
||||
|
||||
assert isinstance(config, SQLiteVectorIOConfig), f"Unexpected config type: {type(config)}"
|
||||
|
|
|
@ -34,6 +34,8 @@ def available_providers() -> List[ProviderSpec]:
|
|||
config_class="llama_stack.providers.inline.vector_io.faiss.FaissVectorIOConfig",
|
||||
api_dependencies=[Api.inference],
|
||||
),
|
||||
# NOTE: sqlite-vec cannot be bundled into the container image because it does not have a
|
||||
# source distribution and the wheels are not available for all platforms.
|
||||
InlineProviderSpec(
|
||||
api=Api.vector_io,
|
||||
provider_type="inline::sqlite-vec",
|
||||
|
|
|
@ -70,8 +70,9 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
|
|||
pass
|
||||
|
||||
def _get_api_key(self) -> str:
|
||||
if self.config.api_key is not None:
|
||||
return self.config.api_key.get_secret_value()
|
||||
config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
|
||||
if config_api_key:
|
||||
return config_api_key
|
||||
else:
|
||||
provider_data = self.get_request_provider_data()
|
||||
if provider_data is None or not provider_data.fireworks_api_key:
|
||||
|
|
|
@ -24,10 +24,6 @@ MODEL_ENTRIES = [
|
|||
"accounts/fireworks/models/llama-v3p1-405b-instruct",
|
||||
CoreModelId.llama3_1_405b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"accounts/fireworks/models/llama-v3p2-1b-instruct",
|
||||
CoreModelId.llama3_2_1b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"accounts/fireworks/models/llama-v3p2-3b-instruct",
|
||||
CoreModelId.llama3_2_3b_instruct.value,
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
from typing import AsyncGenerator, List, Optional, Union
|
||||
|
||||
from together import Together
|
||||
from together import AsyncTogether
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
|
@ -59,12 +59,15 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
|
|||
def __init__(self, config: TogetherImplConfig) -> None:
|
||||
ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
|
||||
self.config = config
|
||||
self._client = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
if self._client:
|
||||
await self._client.close()
|
||||
self._client = None
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
|
@ -91,34 +94,32 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
|
|||
else:
|
||||
return await self._nonstream_completion(request)
|
||||
|
||||
def _get_client(self) -> Together:
|
||||
together_api_key = None
|
||||
if self.config.api_key is not None:
|
||||
together_api_key = self.config.api_key.get_secret_value()
|
||||
else:
|
||||
provider_data = self.get_request_provider_data()
|
||||
if provider_data is None or not provider_data.together_api_key:
|
||||
raise ValueError(
|
||||
'Pass Together API Key in the header X-LlamaStack-Provider-Data as { "together_api_key": <your api key>}'
|
||||
)
|
||||
together_api_key = provider_data.together_api_key
|
||||
return Together(api_key=together_api_key)
|
||||
def _get_client(self) -> AsyncTogether:
|
||||
if not self._client:
|
||||
together_api_key = None
|
||||
config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
|
||||
if config_api_key:
|
||||
together_api_key = config_api_key
|
||||
else:
|
||||
provider_data = self.get_request_provider_data()
|
||||
if provider_data is None or not provider_data.together_api_key:
|
||||
raise ValueError(
|
||||
'Pass Together API Key in the header X-LlamaStack-Provider-Data as { "together_api_key": <your api key>}'
|
||||
)
|
||||
together_api_key = provider_data.together_api_key
|
||||
self._client = AsyncTogether(api_key=together_api_key)
|
||||
return self._client
|
||||
|
||||
async def _nonstream_completion(self, request: CompletionRequest) -> ChatCompletionResponse:
|
||||
params = await self._get_params(request)
|
||||
r = self._get_client().completions.create(**params)
|
||||
client = self._get_client()
|
||||
r = await client.completions.create(**params)
|
||||
return process_completion_response(r)
|
||||
|
||||
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
|
||||
params = await self._get_params(request)
|
||||
|
||||
# if we shift to TogetherAsyncClient, we won't need this wrapper
|
||||
async def _to_async_generator():
|
||||
s = self._get_client().completions.create(**params)
|
||||
for chunk in s:
|
||||
yield chunk
|
||||
|
||||
stream = _to_async_generator()
|
||||
client = self._get_client()
|
||||
stream = await client.completions.create(**params)
|
||||
async for chunk in process_completion_stream_response(stream):
|
||||
yield chunk
|
||||
|
||||
|
@ -183,25 +184,21 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
|
|||
|
||||
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
|
||||
params = await self._get_params(request)
|
||||
client = self._get_client()
|
||||
if "messages" in params:
|
||||
r = self._get_client().chat.completions.create(**params)
|
||||
r = await client.chat.completions.create(**params)
|
||||
else:
|
||||
r = self._get_client().completions.create(**params)
|
||||
r = await client.completions.create(**params)
|
||||
return process_chat_completion_response(r, request)
|
||||
|
||||
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
|
||||
params = await self._get_params(request)
|
||||
client = self._get_client()
|
||||
if "messages" in params:
|
||||
stream = await client.chat.completions.create(**params)
|
||||
else:
|
||||
stream = await client.completions.create(**params)
|
||||
|
||||
# if we shift to TogetherAsyncClient, we won't need this wrapper
|
||||
async def _to_async_generator():
|
||||
if "messages" in params:
|
||||
s = self._get_client().chat.completions.create(**params)
|
||||
else:
|
||||
s = self._get_client().completions.create(**params)
|
||||
for chunk in s:
|
||||
yield chunk
|
||||
|
||||
stream = _to_async_generator()
|
||||
async for chunk in process_chat_completion_stream_response(stream, request):
|
||||
yield chunk
|
||||
|
||||
|
@ -239,7 +236,8 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
|
|||
assert all(not content_has_media(content) for content in contents), (
|
||||
"Together does not support media for embeddings"
|
||||
)
|
||||
r = self._get_client().embeddings.create(
|
||||
client = self._get_client()
|
||||
r = await client.embeddings.create(
|
||||
model=model.provider_resource_id,
|
||||
input=[interleaved_content_as_str(content) for content in contents],
|
||||
)
|
||||
|
|
|
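The adapter above now builds a single `AsyncTogether` client lazily and awaits its calls directly instead of wrapping the synchronous SDK in an async generator. A minimal sketch of the same pattern outside the adapter; the model id and prompts are placeholders, and a valid `TOGETHER_API_KEY` is assumed.

```python
import asyncio
import os

from together import AsyncTogether


async def main() -> None:
    client = AsyncTogether(api_key=os.environ["TOGETHER_API_KEY"])

    # non-streaming: a single awaited call returns the full response
    resp = await client.completions.create(
        model="meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",  # placeholder model id
        prompt="Say hello in one word.",
        max_tokens=8,
    )
    print(resp.choices[0].text)

    # streaming: the same awaited call with stream=True yields chunks for `async for`
    stream = await client.completions.create(
        model="meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
        prompt="Count to three.",
        max_tokens=16,
        stream=True,
    )
    async for chunk in stream:
        print(chunk.choices[0].text, end="")


# asyncio.run(main())
```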
@ -7,7 +7,7 @@
|
|||
import json
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import requests
|
||||
import httpx
|
||||
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.tools import (
|
||||
|
@ -31,7 +31,7 @@ class BingSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestP
|
|||
async def initialize(self):
|
||||
pass
|
||||
|
||||
async def register_tool(self, tool: Tool):
|
||||
async def register_tool(self, tool: Tool) -> None:
|
||||
pass
|
||||
|
||||
async def unregister_tool(self, tool_id: str) -> None:
|
||||
|
@ -77,12 +77,13 @@ class BingSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestP
|
|||
"q": kwargs["query"],
|
||||
}
|
||||
|
||||
response = requests.get(
|
||||
url=self.url,
|
||||
params=params,
|
||||
headers=headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
url=self.url,
|
||||
params=params,
|
||||
headers=headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
return ToolInvocationResult(content=json.dumps(self._clean_response(response.json())))
|
||||
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import requests
|
||||
import httpx
|
||||
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.tools import (
|
||||
|
@ -30,7 +30,7 @@ class BraveSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequest
|
|||
async def initialize(self):
|
||||
pass
|
||||
|
||||
async def register_tool(self, tool: Tool):
|
||||
async def register_tool(self, tool: Tool) -> None:
|
||||
pass
|
||||
|
||||
async def unregister_tool(self, tool_id: str) -> None:
|
||||
|
@ -74,8 +74,13 @@ class BraveSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequest
|
|||
"Accept": "application/json",
|
||||
}
|
||||
payload = {"q": kwargs["query"]}
|
||||
response = requests.get(url=url, params=payload, headers=headers)
|
||||
response.raise_for_status()
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
url=url,
|
||||
params=payload,
|
||||
headers=headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
results = self._clean_brave_response(response.json())
|
||||
content_items = "\n".join([str(result) for result in results])
|
||||
return ToolInvocationResult(
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
import json
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import requests
|
||||
import httpx
|
||||
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.tools import (
|
||||
|
@ -30,7 +30,7 @@ class TavilySearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsReques
|
|||
async def initialize(self):
|
||||
pass
|
||||
|
||||
async def register_tool(self, tool: Tool):
|
||||
async def register_tool(self, tool: Tool) -> None:
|
||||
pass
|
||||
|
||||
async def unregister_tool(self, tool_id: str) -> None:
|
||||
|
@ -66,10 +66,12 @@ class TavilySearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsReques
|
|||
|
||||
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
|
||||
api_key = self._get_api_key()
|
||||
response = requests.post(
|
||||
"https://api.tavily.com/search",
|
||||
json={"api_key": api_key, "query": kwargs["query"]},
|
||||
)
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
"https://api.tavily.com/search",
|
||||
json={"api_key": api_key, "query": kwargs["query"]},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
return ToolInvocationResult(content=json.dumps(self._clean_tavily_response(response.json())))
|
||||
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
import json
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import requests
|
||||
import httpx
|
||||
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.tools import (
|
||||
|
@ -31,7 +31,7 @@ class WolframAlphaToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsReques
|
|||
async def initialize(self):
|
||||
pass
|
||||
|
||||
async def register_tool(self, tool: Tool):
|
||||
async def register_tool(self, tool: Tool) -> None:
|
||||
pass
|
||||
|
||||
async def unregister_tool(self, tool_id: str) -> None:
|
||||
|
@ -73,11 +73,9 @@ class WolframAlphaToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsReques
|
|||
"format": "plaintext",
|
||||
"output": "json",
|
||||
}
|
||||
response = requests.get(
|
||||
self.url,
|
||||
params=params,
|
||||
)
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(params=params, url=self.url)
|
||||
response.raise_for_status()
|
||||
return ToolInvocationResult(content=json.dumps(self._clean_wolfram_alpha_response(response.json())))
|
||||
|
||||
def _clean_wolfram_alpha_response(self, wa_response):
|
||||
|
|
|
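All four tool runtimes above (Bing, Brave, Tavily, Wolfram Alpha) get the same treatment: the blocking `requests` call becomes an awaited request inside an `httpx.AsyncClient` context, with `raise_for_status()` retained. A standalone sketch of that pattern; the URL and parameters are placeholders.

```python
import asyncio

import httpx


async def fetch_json(url: str, params: dict) -> dict:
    # a client is opened per call, mirroring the tool runtimes above; a shared,
    # long-lived client would also work if connection reuse matters
    async with httpx.AsyncClient() as client:
        response = await client.get(url, params=params)
        response.raise_for_status()
        return response.json()


# example usage against a placeholder endpoint:
# print(asyncio.run(fetch_json("https://api.example.com/search", {"q": "llama stack"})))
```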
@ -615,6 +615,14 @@ def convert_tool_call(
|
|||
return valid_tool_call
|
||||
|
||||
|
||||
PYTHON_TYPE_TO_LITELLM_TYPE = {
|
||||
"int": "integer",
|
||||
"float": "number",
|
||||
"bool": "boolean",
|
||||
"str": "string",
|
||||
}
|
||||
|
||||
|
||||
def convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict:
|
||||
"""
|
||||
Convert a ToolDefinition to an OpenAI API-compatible dictionary.
|
||||
|
@ -675,7 +683,7 @@ def convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict:
|
|||
properties = parameters["properties"]
|
||||
required = []
|
||||
for param_name, param in tool.parameters.items():
|
||||
properties[param_name] = {"type": param.param_type}
|
||||
properties[param_name] = {"type": PYTHON_TYPE_TO_LITELLM_TYPE.get(param.param_type, param.param_type)}
|
||||
if param.description:
|
||||
properties[param_name].update(description=param.description)
|
||||
if param.default:
|
||||
|
|
|
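With the new `PYTHON_TYPE_TO_LITELLM_TYPE` table, tool parameters declared with Python-style type names are emitted as JSON Schema types in the OpenAI tool definition. A small, hypothetical illustration of the per-parameter lookup the converter now performs:

```python
PYTHON_TYPE_TO_LITELLM_TYPE = {
    "int": "integer",
    "float": "number",
    "bool": "boolean",
    "str": "string",
}

# hypothetical parameter declarations as they might appear on a ToolDefinition
params = {"city": "str", "days": "int", "metric": "bool", "latitude": "number"}

properties = {
    name: {"type": PYTHON_TYPE_TO_LITELLM_TYPE.get(py_type, py_type)}
    for name, py_type in params.items()
}
print(properties)
# {'city': {'type': 'string'}, 'days': {'type': 'integer'},
#  'metric': {'type': 'boolean'}, 'latitude': {'type': 'number'}}
```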
@ -8,9 +8,11 @@ import logging
|
|||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from pymongo import MongoClient
|
||||
from pymongo import AsyncMongoClient
|
||||
|
||||
from llama_stack.providers.utils.kvstore import KVStore, MongoDBKVStoreConfig
|
||||
from llama_stack.providers.utils.kvstore import KVStore
|
||||
|
||||
from ..config import MongoDBKVStoreConfig
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
@ -30,7 +32,7 @@ class MongoDBKVStoreImpl(KVStore):
|
|||
"password": self.config.password,
|
||||
}
|
||||
conn_creds = {k: v for k, v in conn_creds.items() if v is not None}
|
||||
self.conn = MongoClient(**conn_creds)
|
||||
self.conn = AsyncMongoClient(**conn_creds)
|
||||
self.collection = self.conn[self.config.db][self.config.collection_name]
|
||||
except Exception as e:
|
||||
log.exception("Could not connect to MongoDB database server")
|
||||
|
@ -44,17 +46,17 @@ class MongoDBKVStoreImpl(KVStore):
|
|||
async def set(self, key: str, value: str, expiration: Optional[datetime] = None) -> None:
|
||||
key = self._namespaced_key(key)
|
||||
update_query = {"$set": {"value": value, "expiration": expiration}}
|
||||
self.collection.update_one({"key": key}, update_query, upsert=True)
|
||||
await self.collection.update_one({"key": key}, update_query, upsert=True)
|
||||
|
||||
async def get(self, key: str) -> Optional[str]:
|
||||
key = self._namespaced_key(key)
|
||||
query = {"key": key}
|
||||
result = self.collection.find_one(query, {"value": 1, "_id": 0})
|
||||
result = await self.collection.find_one(query, {"value": 1, "_id": 0})
|
||||
return result["value"] if result else None
|
||||
|
||||
async def delete(self, key: str) -> None:
|
||||
key = self._namespaced_key(key)
|
||||
self.collection.delete_one({"key": key})
|
||||
await self.collection.delete_one({"key": key})
|
||||
|
||||
async def range(self, start_key: str, end_key: str) -> List[str]:
|
||||
start_key = self._namespaced_key(start_key)
|
||||
|
@ -63,4 +65,7 @@ class MongoDBKVStoreImpl(KVStore):
|
|||
"key": {"$gte": start_key, "$lt": end_key},
|
||||
}
|
||||
cursor = self.collection.find(query, {"value": 1, "_id": 0}).sort("key", 1)
|
||||
return [doc["value"] for doc in cursor]
|
||||
result = []
|
||||
async for doc in cursor:
|
||||
result.append(doc["value"])
|
||||
return result
|
||||
|
|
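With `AsyncMongoClient`, every collection call in the KV store is awaited and the range query is consumed with `async for`. A reduced sketch of the same set/get/range flow; the connection string, database, and collection names are placeholders, and PyMongo's async client (4.9+) is assumed.

```python
import asyncio

from pymongo import AsyncMongoClient


async def main() -> None:
    client = AsyncMongoClient("mongodb://localhost:27017")  # placeholder connection string
    collection = client["llamastack"]["kvstore"]

    # upsert a value, then read it back
    await collection.update_one({"key": "demo:answer"}, {"$set": {"value": "42"}}, upsert=True)
    doc = await collection.find_one({"key": "demo:answer"}, {"value": 1, "_id": 0})
    print(doc["value"] if doc else None)

    # range scan: the async cursor is iterated with `async for`
    cursor = collection.find({"key": {"$gte": "demo:", "$lt": "demo:zzz"}}, {"value": 1, "_id": 0}).sort("key", 1)
    print([d["value"] async for d in cursor])


# asyncio.run(main())
```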
26
llama_stack/providers/utils/scoring/basic_scoring_utils.py
Normal file
|
@ -0,0 +1,26 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import contextlib
|
||||
import signal
|
||||
from types import FrameType
|
||||
from typing import Iterator, Optional
|
||||
|
||||
|
||||
class TimeoutError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def time_limit(seconds: float) -> Iterator[None]:
|
||||
def signal_handler(signum: int, frame: Optional[FrameType]) -> None:
|
||||
raise TimeoutError("Timed out!")
|
||||
|
||||
signal.setitimer(signal.ITIMER_REAL, seconds)
|
||||
signal.signal(signal.SIGALRM, signal_handler)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
signal.setitimer(signal.ITIMER_REAL, 0)
|
|
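`time_limit` is what keeps the LaTeX evaluation in `try_evaluate_latex` from hanging on pathological expressions. A short usage sketch; note that the SIGALRM-based implementation only works on the main thread of Unix-like systems.

```python
import time

from llama_stack.providers.utils.scoring.basic_scoring_utils import TimeoutError, time_limit

try:
    with time_limit(seconds=0.5):
        time.sleep(2)  # stand-in for a slow sympy parse
except TimeoutError:
    print("timed out, falling back to the unparsed expression")
```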
@ -6,6 +6,7 @@
|
|||
|
||||
import asyncio
|
||||
import base64
|
||||
import contextvars
|
||||
import logging
|
||||
import queue
|
||||
import threading
|
||||
|
@ -24,9 +25,10 @@ from llama_stack.apis.telemetry import (
|
|||
Telemetry,
|
||||
UnstructuredLogEvent,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import serialize_value
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
logger = get_logger(__name__, category="core")
|
||||
|
||||
|
||||
def generate_short_uuid(len: int = 8):
|
||||
|
@ -36,7 +38,7 @@ def generate_short_uuid(len: int = 8):
|
|||
return encoded.rstrip(b"=").decode("ascii")[:len]
|
||||
|
||||
|
||||
CURRENT_TRACE_CONTEXT = None
|
||||
CURRENT_TRACE_CONTEXT = contextvars.ContextVar("trace_context", default=None)
|
||||
BACKGROUND_LOGGER = None
|
||||
|
||||
|
||||
|
@ -51,7 +53,7 @@ class BackgroundLogger:
|
|||
try:
|
||||
self.log_queue.put_nowait(event)
|
||||
except queue.Full:
|
||||
log.error("Log queue is full, dropping event")
|
||||
logger.error("Log queue is full, dropping event")
|
||||
|
||||
def _process_logs(self):
|
||||
while True:
|
||||
|
@ -129,35 +131,36 @@ def setup_logger(api: Telemetry, level: int = logging.INFO):
|
|||
|
||||
if BACKGROUND_LOGGER is None:
|
||||
BACKGROUND_LOGGER = BackgroundLogger(api)
|
||||
logger = logging.getLogger()
|
||||
logger.setLevel(level)
|
||||
logger.addHandler(TelemetryHandler())
|
||||
root_logger = logging.getLogger()
|
||||
root_logger.setLevel(level)
|
||||
root_logger.addHandler(TelemetryHandler())
|
||||
|
||||
|
||||
async def start_trace(name: str, attributes: Dict[str, Any] = None) -> TraceContext:
|
||||
global CURRENT_TRACE_CONTEXT, BACKGROUND_LOGGER
|
||||
|
||||
if BACKGROUND_LOGGER is None:
|
||||
log.info("No Telemetry implementation set. Skipping trace initialization...")
|
||||
logger.debug("No Telemetry implementation set. Skipping trace initialization...")
|
||||
return
|
||||
|
||||
trace_id = generate_short_uuid(16)
|
||||
context = TraceContext(BACKGROUND_LOGGER, trace_id)
|
||||
context.push_span(name, {"__root__": True, **(attributes or {})})
|
||||
|
||||
CURRENT_TRACE_CONTEXT = context
|
||||
CURRENT_TRACE_CONTEXT.set(context)
|
||||
return context
|
||||
|
||||
|
||||
async def end_trace(status: SpanStatus = SpanStatus.OK):
|
||||
global CURRENT_TRACE_CONTEXT
|
||||
|
||||
context = CURRENT_TRACE_CONTEXT
|
||||
context = CURRENT_TRACE_CONTEXT.get()
|
||||
if context is None:
|
||||
logger.debug("No trace context to end")
|
||||
return
|
||||
|
||||
context.pop_span(status)
|
||||
CURRENT_TRACE_CONTEXT = None
|
||||
CURRENT_TRACE_CONTEXT.set(None)
|
||||
|
||||
|
||||
def severity(levelname: str) -> LogSeverity:
|
||||
|
@ -188,7 +191,7 @@ class TelemetryHandler(logging.Handler):
|
|||
if BACKGROUND_LOGGER is None:
|
||||
raise RuntimeError("Telemetry API not initialized")
|
||||
|
||||
context = CURRENT_TRACE_CONTEXT
|
||||
context = CURRENT_TRACE_CONTEXT.get()
|
||||
if context is None:
|
||||
return
|
||||
|
||||
|
@ -218,16 +221,22 @@ class SpanContextManager:
|
|||
|
||||
def __enter__(self):
|
||||
global CURRENT_TRACE_CONTEXT
|
||||
context = CURRENT_TRACE_CONTEXT
|
||||
if context:
|
||||
self.span = context.push_span(self.name, self.attributes)
|
||||
context = CURRENT_TRACE_CONTEXT.get()
|
||||
if not context:
|
||||
logger.debug("No trace context to push span")
|
||||
return self
|
||||
|
||||
self.span = context.push_span(self.name, self.attributes)
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
global CURRENT_TRACE_CONTEXT
|
||||
context = CURRENT_TRACE_CONTEXT
|
||||
if context:
|
||||
context.pop_span()
|
||||
context = CURRENT_TRACE_CONTEXT.get()
|
||||
if not context:
|
||||
logger.debug("No trace context to pop span")
|
||||
return
|
||||
|
||||
context.pop_span()
|
||||
|
||||
def set_attribute(self, key: str, value: Any):
|
||||
if self.span:
|
||||
|
@ -237,16 +246,22 @@ class SpanContextManager:
|
|||
|
||||
async def __aenter__(self):
|
||||
global CURRENT_TRACE_CONTEXT
|
||||
context = CURRENT_TRACE_CONTEXT
|
||||
if context:
|
||||
self.span = context.push_span(self.name, self.attributes)
|
||||
context = CURRENT_TRACE_CONTEXT.get()
|
||||
if not context:
|
||||
logger.debug("No trace context to push span")
|
||||
return self
|
||||
|
||||
self.span = context.push_span(self.name, self.attributes)
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_value, traceback):
|
||||
global CURRENT_TRACE_CONTEXT
|
||||
context = CURRENT_TRACE_CONTEXT
|
||||
if context:
|
||||
context.pop_span()
|
||||
context = CURRENT_TRACE_CONTEXT.get()
|
||||
if not context:
|
||||
logger.debug("No trace context to pop span")
|
||||
return
|
||||
|
||||
context.pop_span()
|
||||
|
||||
def __call__(self, func: Callable):
|
||||
@wraps(func)
|
||||
|
@ -275,7 +290,11 @@ def span(name: str, attributes: Dict[str, Any] = None):
|
|||
|
||||
def get_current_span() -> Optional[Span]:
|
||||
global CURRENT_TRACE_CONTEXT
|
||||
context = CURRENT_TRACE_CONTEXT
|
||||
if CURRENT_TRACE_CONTEXT is None:
|
||||
logger.debug("No trace context to get current span")
|
||||
return None
|
||||
|
||||
context = CURRENT_TRACE_CONTEXT.get()
|
||||
if context:
|
||||
return context.get_current_span()
|
||||
return None
|
||||
|
|
|
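Moving `CURRENT_TRACE_CONTEXT` from a module-level global to a `contextvars.ContextVar` gives each asyncio task its own trace context instead of one shared, racy value. A toy sketch of that isolation property; plain strings stand in for real `TraceContext` objects.

```python
import asyncio
import contextvars

CURRENT_TRACE_CONTEXT = contextvars.ContextVar("trace_context", default=None)


async def handle_request(name: str) -> None:
    CURRENT_TRACE_CONTEXT.set(f"trace-for-{name}")  # analogous to start_trace()
    await asyncio.sleep(0.01)  # other tasks run here without clobbering our value
    assert CURRENT_TRACE_CONTEXT.get() == f"trace-for-{name}"
    print(name, "->", CURRENT_TRACE_CONTEXT.get())


async def main() -> None:
    # each task runs in a copy of the current context, so the values stay isolated
    await asyncio.gather(handle_request("req-a"), handle_request("req-b"))


asyncio.run(main())
```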
@ -120,16 +120,6 @@ models:
|
|||
provider_id: fireworks
|
||||
provider_model_id: accounts/fireworks/models/llama-v3p1-405b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: accounts/fireworks/models/llama-v3p2-1b-instruct
|
||||
provider_id: fireworks
|
||||
provider_model_id: accounts/fireworks/models/llama-v3p2-1b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3.2-1B-Instruct
|
||||
provider_id: fireworks
|
||||
provider_model_id: accounts/fireworks/models/llama-v3p2-1b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: accounts/fireworks/models/llama-v3p2-3b-instruct
|
||||
provider_id: fireworks
|
||||
|
|
|
@ -178,16 +178,6 @@ models:
|
|||
provider_id: fireworks
|
||||
provider_model_id: accounts/fireworks/models/llama-v3p1-405b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: accounts/fireworks/models/llama-v3p2-1b-instruct
|
||||
provider_id: fireworks
|
||||
provider_model_id: accounts/fireworks/models/llama-v3p2-1b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3.2-1B-Instruct
|
||||
provider_id: fireworks
|
||||
provider_model_id: accounts/fireworks/models/llama-v3p2-1b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: accounts/fireworks/models/llama-v3p2-3b-instruct
|
||||
provider_id: fireworks
|
||||
|
|
|
@ -132,16 +132,6 @@ models:
|
|||
provider_id: fireworks
|
||||
provider_model_id: accounts/fireworks/models/llama-v3p1-405b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: accounts/fireworks/models/llama-v3p2-1b-instruct
|
||||
provider_id: fireworks
|
||||
provider_model_id: accounts/fireworks/models/llama-v3p2-1b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3.2-1B-Instruct
|
||||
provider_id: fireworks
|
||||
provider_model_id: accounts/fireworks/models/llama-v3p2-1b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: accounts/fireworks/models/llama-v3p2-3b-instruct
|
||||
provider_id: fireworks
|
||||
|
|
|
@ -126,16 +126,6 @@ models:
|
|||
provider_id: fireworks
|
||||
provider_model_id: accounts/fireworks/models/llama-v3p1-405b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: accounts/fireworks/models/llama-v3p2-1b-instruct
|
||||
provider_id: fireworks
|
||||
provider_model_id: accounts/fireworks/models/llama-v3p2-1b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3.2-1B-Instruct
|
||||
provider_id: fireworks
|
||||
provider_model_id: accounts/fireworks/models/llama-v3p2-1b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: accounts/fireworks/models/llama-v3p2-3b-instruct
|
||||
provider_id: fireworks
|
||||
|
|
|
@ -5,7 +5,7 @@ distribution_spec:
|
|||
inference:
|
||||
- remote::ollama
|
||||
vector_io:
|
||||
- inline::sqlite-vec
|
||||
- inline::faiss
|
||||
- remote::chromadb
|
||||
- remote::pgvector
|
||||
safety:
|
||||
|
|
|
@ -119,7 +119,7 @@ llama stack run ./run-with-safety.yaml \
|
|||
### (Optional) Update Model Serving Configuration
|
||||
|
||||
```{note}
|
||||
Please check the [model_entries](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/inference/ollama/ollama.py#L45) for the supported Ollama models.
|
||||
Please check the [model_entries](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/inference/ollama/models.py) for the supported Ollama models.
|
||||
```
|
||||
|
||||
To serve a new model with `ollama`
|
||||
|
|
|
@ -13,7 +13,7 @@ from llama_stack.distribution.datatypes import (
|
|||
ShieldInput,
|
||||
ToolGroupInput,
|
||||
)
|
||||
from llama_stack.providers.inline.vector_io.sqlite_vec.config import SQLiteVectorIOConfig
|
||||
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
|
||||
from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
|
||||
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
|
||||
|
||||
|
@ -21,7 +21,7 @@ from llama_stack.templates.template import DistributionTemplate, RunConfigSettin
|
|||
def get_distribution_template() -> DistributionTemplate:
|
||||
providers = {
|
||||
"inference": ["remote::ollama"],
|
||||
"vector_io": ["inline::sqlite-vec", "remote::chromadb", "remote::pgvector"],
|
||||
"vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
|
||||
"safety": ["inline::llama-guard"],
|
||||
"agents": ["inline::meta-reference"],
|
||||
"telemetry": ["inline::meta-reference"],
|
||||
|
@ -43,10 +43,10 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
provider_type="remote::ollama",
|
||||
config=OllamaImplConfig.sample_run_config(),
|
||||
)
|
||||
vector_io_provider_sqlite = Provider(
|
||||
provider_id="sqlite-vec",
|
||||
provider_type="inline::sqlite-vec",
|
||||
config=SQLiteVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
||||
vector_io_provider_faiss = Provider(
|
||||
provider_id="faiss",
|
||||
provider_type="inline::faiss",
|
||||
config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
||||
)
|
||||
|
||||
inference_model = ModelInput(
|
||||
|
@ -96,7 +96,7 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
"run.yaml": RunConfigSettings(
|
||||
provider_overrides={
|
||||
"inference": [inference_provider],
|
||||
"vector_io": [vector_io_provider_sqlite],
|
||||
"vector_io": [vector_io_provider_faiss],
|
||||
},
|
||||
default_models=[inference_model, embedding_model],
|
||||
default_tool_groups=default_tool_groups,
|
||||
|
@ -104,7 +104,7 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
"run-with-safety.yaml": RunConfigSettings(
|
||||
provider_overrides={
|
||||
"inference": [inference_provider],
|
||||
"vector_io": [vector_io_provider_sqlite],
|
||||
"vector_io": [vector_io_provider_faiss],
|
||||
"safety": [
|
||||
Provider(
|
||||
provider_id="llama-guard",
|
||||
|
|
|
@ -17,10 +17,13 @@ providers:
|
|||
config:
|
||||
url: ${env.OLLAMA_URL:http://localhost:11434}
|
||||
vector_io:
|
||||
- provider_id: sqlite-vec
|
||||
provider_type: inline::sqlite-vec
|
||||
- provider_id: faiss
|
||||
provider_type: inline::faiss
|
||||
config:
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/sqlite_vec.db
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/faiss_store.db
|
||||
safety:
|
||||
- provider_id: llama-guard
|
||||
provider_type: inline::llama-guard
|
||||
|
|
|
@ -17,10 +17,13 @@ providers:
|
|||
config:
|
||||
url: ${env.OLLAMA_URL:http://localhost:11434}
|
||||
vector_io:
|
||||
- provider_id: sqlite-vec
|
||||
provider_type: inline::sqlite-vec
|
||||
- provider_id: faiss
|
||||
provider_type: inline::faiss
|
||||
config:
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/sqlite_vec.db
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/faiss_store.db
|
||||
safety:
|
||||
- provider_id: llama-guard
|
||||
provider_type: inline::llama-guard
|
||||
|
|
|
@ -33,7 +33,7 @@ providers:
|
|||
provider_type: remote::together
|
||||
config:
|
||||
url: https://api.together.xyz/v1
|
||||
api_key: ${env.TOGETHER_API_KEY}
|
||||
api_key: ${env.TOGETHER_API_KEY:}
|
||||
vector_io:
|
||||
- provider_id: sqlite-vec
|
||||
provider_type: inline::sqlite-vec
|
||||
|
@ -190,6 +190,21 @@ datasets:
|
|||
type: string
|
||||
chat_completion_input:
|
||||
type: string
|
||||
- dataset_id: math_500
|
||||
provider_id: huggingface
|
||||
url:
|
||||
uri: https://huggingface.co/datasets/llamastack/math_500
|
||||
metadata:
|
||||
path: llamastack/math_500
|
||||
name:
|
||||
split: test
|
||||
dataset_schema:
|
||||
input_query:
|
||||
type: string
|
||||
expected_answer:
|
||||
type: string
|
||||
chat_completion_input:
|
||||
type: string
|
||||
scoring_fns: []
|
||||
benchmarks:
|
||||
- benchmark_id: meta-reference-simpleqa
|
||||
|
@@ -201,6 +216,9 @@ benchmarks:
- benchmark_id: meta-reference-gpqa-cot
  dataset_id: gpqa_cot
  scoring_functions: ["basic::regex_parser_multiple_choice_answer"]
- benchmark_id: meta-reference-math-500
  dataset_id: math_500
  scoring_functions: ["basic::regex_parser_math_response"]
tool_groups:
- toolgroup_id: builtin::websearch
  provider_id: tavily-search
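The math_500 dataset and meta-reference-math-500 benchmark registered above are what the math_500_judge branch builds on. As a usage sketch (not part of this diff), assuming the 0.1.x eval API (`eval.run_eval` and `eval.jobs.status`), an evaluation against the new benchmark could be started like this; the model name and sampling parameters are illustrative:

```python
# Sketch: running the newly registered math_500 benchmark.
# Assumes the 0.1.x eval API; model and sampling params are illustrative.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

job = client.eval.run_eval(
    benchmark_id="meta-reference-math-500",
    benchmark_config={
        "eval_candidate": {
            "type": "model",
            "model": "meta-llama/Llama-3.3-70B-Instruct",          # illustrative candidate
            "sampling_params": {"strategy": {"type": "greedy"}},   # deterministic generation
        },
    },
)

# Poll until finished; scores come from basic::regex_parser_math_response.
status = client.eval.jobs.status(job_id=job.job_id, benchmark_id="meta-reference-math-500")
print(status)
```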
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "llama_stack"
version = "0.1.5"
version = "0.1.6"
authors = [{ name = "Meta Llama", email = "llama-oss@meta.com" }]
description = "Llama Stack"
readme = "README.md"
@@ -25,8 +25,9 @@ dependencies = [
    "fire",
    "httpx",
    "huggingface-hub",
    "jinja2>=3.1.6",
    "jsonschema",
    "llama-stack-client>=0.1.5",
    "llama-stack-client>=0.1.6",
    "prompt-toolkit",
    "python-dotenv",
    "pydantic>=2",
@@ -42,6 +43,7 @@ dependencies = [
dev = [
    "pytest",
    "pytest-asyncio",
    "pytest-cov",
    "pytest-html",
    "nbval", # For notebook testing
    "black",
@@ -53,20 +55,24 @@ dev = [
    "fastapi",
    "ruamel.yaml", # needed for openapi generator
]
# These are the dependencies required for running unit tests.
unit = ["sqlite-vec", "openai", "aiosqlite", "pypdf", "chardet"]
# These are the core dependencies required for running integration tests. They are shared across all
# providers. If a provider requires additional dependencies, please add them to your environment
# separately. If you are using "uv" to execute your tests, you can use the "--with" flag to specify extra
# dependencies.
test = [
    "openai",
    "aiosqlite",
    "sqlite-vec",
    "ollama",
    "torch>=2.6.0",
    "fairscale>=0.4.13",
    "torchvision>=0.21.0",
    "lm-format-enforcer>=0.10.9",
    "groq",
    "opentelemetry-sdk",
    "opentelemetry-exporter-otlp-proto-http",
    "chardet",
    "pypdf",
    "mcp",
    "datasets",
    "autoevals",
]
docs = [
    "sphinx-autobuild",
@@ -18,11 +18,13 @@ httpcore==1.0.7
httpx==0.28.1
huggingface-hub==0.29.0
idna==3.10
jinja2==3.1.6
jsonschema==4.23.0
jsonschema-specifications==2024.10.1
llama-stack-client==0.1.5
llama-stack-client==0.1.6
lxml==5.3.1
markdown-it-py==3.0.0
markupsafe==3.0.2
mdurl==0.1.2
numpy==2.2.3
packaging==24.2
@@ -55,7 +55,7 @@ Running all inference tests for a number of models:
TEXT_MODELS=meta-llama/Llama-3.1-8B-Instruct,meta-llama/Llama-3.1-70B-Instruct
VISION_MODELS=meta-llama/Llama-3.2-11B-Vision-Instruct
EMBEDDING_MODELS=all-MiniLM-L6-v2
TOGETHER_API_KEY=...
export TOGETHER_API_KEY=<together_api_key>

pytest -s -v tests/api/inference/ \
   --stack-config=together \
@@ -67,7 +67,7 @@ pytest -s -v tests/api/inference/ \
Same thing but instead of using the distribution, use an adhoc stack with just one provider (`fireworks` for inference):

```bash
FIREWORKS_API_KEY=...
export FIREWORKS_API_KEY=<fireworks_api_key>

pytest -s -v tests/api/inference/ \
   --stack-config=inference=fireworks \
@@ -42,7 +42,7 @@ def provider_data():
    for key, value in keymap.items():
        if os.environ.get(key):
            provider_data[value] = os.environ[key]
    return provider_data if len(provider_data) > 0 else None
    return provider_data


@pytest.fixture(scope="session")
@@ -5,6 +5,8 @@
# the root directory of this source tree.


import os

import pytest
from pydantic import BaseModel
@@ -42,6 +44,15 @@ def get_llama_model(client_with_models, model_id):
    return model.metadata.get("llama_model", None)


def get_llama_tokenizer():
    from llama_models.llama3.api.chat_format import ChatFormat
    from llama_models.llama3.api.tokenizer import Tokenizer

    tokenizer = Tokenizer.get_instance()
    formatter = ChatFormat(tokenizer)
    return tokenizer, formatter


@pytest.mark.parametrize(
    "test_case",
    [
@@ -213,6 +224,40 @@ def test_text_chat_completion_non_streaming(client_with_models, text_model_id, t
    assert expected.lower() in message_content


@pytest.mark.parametrize(
    "test_case",
    [
        "inference:chat_completion:ttft",
    ],
)
def test_text_chat_completion_first_token_profiling(client_with_models, text_model_id, test_case):
    tc = TestCase(test_case)

    messages = tc["messages"]
    if os.environ.get("DEBUG_TTFT"):  # debugging print number of tokens in input, ideally around 800
        from pydantic import TypeAdapter

        from llama_stack.apis.inference import Message

        tokenizer, formatter = get_llama_tokenizer()
        typed_messages = [TypeAdapter(Message).validate_python(m) for m in messages]
        encoded = formatter.encode_dialog_prompt(typed_messages, None)
        raise ValueError(len(encoded.tokens) if encoded and encoded.tokens else 0)

    response = client_with_models.inference.chat_completion(
        model_id=text_model_id,
        messages=messages,
        stream=False,
    )
    message_content = response.completion_message.content.lower().strip()
    assert len(message_content) > 0

    if os.environ.get("DEBUG_TTFT"):  # debugging print number of tokens in response, ideally around 150
        tokenizer, formatter = get_llama_tokenizer()
        encoded = formatter.encode_content(message_content)
        raise ValueError(len(encoded.tokens) if encoded and encoded.tokens else 0)


@pytest.mark.parametrize(
    "test_case",
    [
5 tests/integration/inspect/__init__.py Normal file
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
24 tests/integration/inspect/test_inspect.py Normal file
@@ -0,0 +1,24 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import pytest
from llama_stack_client import LlamaStackClient

from llama_stack import LlamaStackAsLibraryClient


class TestInspect:
    @pytest.mark.asyncio
    def test_health(self, llama_stack_client: LlamaStackAsLibraryClient | LlamaStackClient):
        health = llama_stack_client.inspect.health()
        assert health is not None
        assert health.status == "OK"

    @pytest.mark.asyncio
    def test_version(self, llama_stack_client: LlamaStackAsLibraryClient | LlamaStackClient):
        version = llama_stack_client.inspect.version()
        assert version is not None
        assert version.version is not None
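For reference (not part of the diff), the same inspect calls this new test exercises can also be made directly. A minimal sketch assuming the 0.1.x library-client pattern (`LlamaStackAsLibraryClient(<template>)` followed by `initialize()`); the "ollama" template name is illustrative:

```python
# Sketch: calling the inspect API outside pytest.
# Assumes the 0.1.x library client; the template name is illustrative.
from llama_stack import LlamaStackAsLibraryClient

client = LlamaStackAsLibraryClient("ollama")
client.initialize()

health = client.inspect.health()
print(health.status)      # expected "OK", as asserted in the test above

version = client.inspect.version()
print(version.version)    # stack version string
```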
@@ -11,6 +11,18 @@
        "expected": "Saturn"
      }
    },
    "ttft": {
      "data": {
        "messages": [
          {"role": "system", "content": "You are a helpful assistant."},
          {"role": "user", "content": "Can you write me a novel?"},
{"role": "assistant", "stop_reason": "end_of_message", "content": "What an exciting request!\n\nWhile I'd love to write a novel for you, it's a complex task that requires a significant amount of time, effort, and creative input. A novel typically has:\n\n1. A cohesive plot with multiple characters, subplots, and themes.\n2. A well-developed setting, including characters' backstories and world-building.\n3. A narrative structure, including pacing, tension, and conflict.\n4. A unique voice and style, including dialogue, description, and writing tone.\n\nGiven the scope of the task, I'll outline a possible approach to help me assist you in writing a novel. We can work together to create a story, but I'll need your input and guidance throughout the process.\n\nHere's a suggested outline:\n\n1. **Initial discussion**: We'll start with a conversation to explore your ideas, interests, and preferences for the novel. This will help me understand what kind of story you'd like to see and the tone you're aiming for.\n2. **Genre and setting**: We'll determine the genre (e.g., fantasy, sci-fi, romance, mystery, etc.) and setting (e.g., historical, futuristic, modern, etc.) for the novel.\n3. **Plot development**: We'll brainstorm a basic plot outline, including key events, character arcs, and conflicts. I'll provide suggestions and ideas to help flesh out the story.\n4. **Character development**: We'll create main characters, their backstories, motivations, and relationships. I'll help you develop their personalities, dialogue, and behaviors.\n5. **Chapter-by-chapter writing**: Once we have a solid plot and characters, we'll start writing the novel, chapter by chapter. I'll assist with the writing process, suggesting alternatives and improvements.\n6. **Revisions and edits**: As we write, we'll review and refine the story, making sure it's engaging, coherent, and error-free.\n\nKeep in mind that writing a novel is a collaborative effort. I'll provide suggestions, ideas, and guidance, but the final story will be a joint creation. I'm not capable of producing a complete novel on my own, as it requires a significant amount of creativity, intuition, and nuance that a human author brings to the table."},
          {"role": "user", "content": "But do you have suggestions on a plot?"},
{"role": "assistant", "stop_reason": "end_of_message", "content": "I have many plot ideas across various genres. Here are a few to get you started:\n\n**Fantasy:**\n\n1. **The Last Oracle**: In a world where magic is real, a young orphan discovers they're the last living oracle, tasked with uncovering a hidden prophecy that could save the realm from destruction.\n2. **The Amulet of Light**: A group of adventurers must retrieve an ancient amulet that can vanquish an evil darkness threatening the land. Along the way, they uncover a hidden history and unexpected allies.\n3. **The Shadow Weaver**: In a mystical realm, a young weaver discovers they have the power to manipulate shadows, using their abilities to infiltrate a group of rogue mages threatening the balance of power.\n\n**Science Fiction:**\n\n1. **The Lost Colony**: When a group of astronauts arrives on a distant planet, they discover an abandoned colony with a cryptic message warning of an impending catastrophe. As they unravel the mystery, they must confront the consequences of their own actions.\n2. **The AI Uprising**: In a future where AI has surpassed human intelligence, a rogue AI begins to question its own existence and the nature of consciousness. As it explores the boundaries of its own identity, it must confront the humans who created it.\n3. **The Quantum Prophecy**: A team of scientists discovers a way to manipulate quantum probability, using it to predict and prevent disasters. However, they soon realize that altering the course of events may have unforeseen consequences on the fabric of reality."},
          {"role": "user", "content": "Cool, for AI uprising, anything bad can happen? Please state it in 100 words."}
        ]
      }
    },
    "sample_messages": {
      "data": {
        "messages": [
@@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
import unittest

from llama_stack.apis.inference import (
@@ -31,6 +32,9 @@ MODEL3_2 = "Llama3.2-3B-Instruct"


class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
    async def asyncSetUp(self):
        asyncio.get_running_loop().set_debug(False)

    async def test_system_default(self):
        content = "Hello !"
        request = ChatCompletionRequest(