Merge branch 'main' into pr1573

Xi Yan 2025-03-14 12:18:00 -07:00
commit 0e2a13da9c
39 changed files with 5311 additions and 407 deletions

.github/workflows/integration-tests.yml (new file, 80 lines)

@ -0,0 +1,80 @@
name: Integration tests

on:
  pull_request:
  push:
    branches: [main]

jobs:
  ollama:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4

      - name: Install uv
        uses: astral-sh/setup-uv@v5
        with:
          python-version: "3.10"

      - name: Install Ollama
        run: |
          curl -fsSL https://ollama.com/install.sh | sh

      - name: Pull Ollama image
        run: |
          ollama pull llama3.2:3b-instruct-fp16

      - name: Start Ollama in background
        run: |
          nohup ollama run llama3.2:3b-instruct-fp16 > ollama.log 2>&1 &

      - name: Set Up Environment and Install Dependencies
        run: |
          uv sync --extra dev --extra test
          uv pip install ollama faiss-cpu
          uv pip install -e .

      - name: Wait for Ollama to start
        run: |
          echo "Waiting for Ollama..."
          for i in {1..30}; do
            if curl -s http://localhost:11434 | grep -q "Ollama is running"; then
              echo "Ollama is running!"
              exit 0
            fi
            sleep 1
          done
          echo "Ollama failed to start"
          ollama ps
          cat ollama.log
          exit 1

      - name: Start Llama Stack server in background
        env:
          INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct"
        run: |
          source .venv/bin/activate
          # TODO: use "llama stack run"
          nohup uv run python -m llama_stack.distribution.server.server --yaml-config ./llama_stack/templates/ollama/run.yaml > server.log 2>&1 &

      - name: Wait for Llama Stack server to be ready
        run: |
          echo "Waiting for Llama Stack server..."
          for i in {1..30}; do
            if curl -s http://localhost:8321/v1/health | grep -q "OK"; then
              echo "Llama Stack server is up!"
              exit 0
            fi
            sleep 1
          done
          echo "Llama Stack server failed to start"
          cat server.log
          exit 1

      - name: Run Inference Integration Tests
        env:
          INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct"
        run: |
          uv run pytest -v tests/integration/inference --stack-config=ollama --text-model="meta-llama/Llama-3.2-3B-Instruct" --embedding-model=all-MiniLM-L6-v2


@ -51,7 +51,7 @@ jobs:
python-version: '3.10'
- name: Install uv
uses: astral-sh/setup-uv@v4
uses: astral-sh/setup-uv@v5
with:
python-version: "3.10"

.github/workflows/stale_bot.yml (new file, 45 lines)

@ -0,0 +1,45 @@
name: Close stale issues and PRs

on:
  schedule:
    - cron: '0 0 * * *' # every day at midnight

env:
  LC_ALL: en_US.UTF-8

defaults:
  run:
    shell: bash

permissions:
  contents: read

jobs:
  stale:
    permissions:
      issues: write
      pull-requests: write
    runs-on: ubuntu-latest
    steps:
      - name: Stale Action
        uses: actions/stale@v9
        with:
          stale-issue-label: 'stale'
          stale-issue-message: >
            This issue has been automatically marked as stale because it has not had activity within 60 days.
            It will be automatically closed if no further activity occurs within 30 days.
          close-issue-message: >
            This issue has been automatically closed due to inactivity.
            Please feel free to reopen if you feel it is still relevant!
          days-before-issue-stale: 60
          days-before-issue-close: 30
          stale-pr-label: 'stale'
          stale-pr-message: >
            This pull request has been automatically marked as stale because it has not had activity within 60 days.
            It will be automatically closed if no further activity occurs within 30 days.
          close-pr-message: >
            This pull request has been automatically closed due to inactivity.
            Please feel free to reopen if you intend to continue working on it!
          days-before-pr-stale: 60
          days-before-pr-close: 30
          operations-per-run: 300


@ -33,7 +33,7 @@ jobs:
- name: Run unit tests
run: |
uv run --python ${{ matrix.python }} --with-editable . --with-editable ".[dev]" --with-editable ".[unit]" pytest --cov=llama_stack -s -v tests/unit/ --junitxml=pytest-report-${{ matrix.python }}.xml --cov-report=html:htmlcov-${{ matrix.python }}
PYTHON_VERSION=${{ matrix.python }} ./scripts/unit-tests.sh --cov=llama_stack --junitxml=pytest-report-${{ matrix.python }}.xml --cov-report=html:htmlcov-${{ matrix.python }}
- name: Upload test results
if: always()


@ -8,6 +8,7 @@ repos:
rev: v5.0.0 # Latest stable version
hooks:
- id: check-merge-conflict
args: ['--assume-in-merge']
- id: trailing-whitespace
exclude: '\.py$' # Exclude Python files as Ruff already handles them
- id: check-added-large-files
@ -82,6 +83,17 @@ repos:
require_serial: true
files: ^llama_stack/templates/.*$|^llama_stack/providers/.*/inference/.*/models\.py$
- repo: local
hooks:
- id: openapi-codegen
name: API Spec Codegen
additional_dependencies:
- uv==0.6.2
entry: sh -c 'uv run --with ".[dev]" ./docs/openapi_generator/run_openapi_generator.sh > /dev/null 2>&1'
language: python
pass_filenames: false
require_serial: true
ci:
autofix_commit_msg: 🎨 [pre-commit.ci] Auto format from pre-commit.com hooks
autoupdate_commit_msg: ⬆ [pre-commit.ci] pre-commit autoupdate


@ -108,6 +108,22 @@ uv run pre-commit run --all-files
> [!CAUTION]
> Before pushing your changes, make sure that the pre-commit hooks have passed successfully.
## Running unit tests
You can run the unit tests by running:
```bash
source .venv/bin/activate
./scripts/unit-tests.sh
```
If you'd like to run the tests with a non-default version of Python (the default is 3.10), pass the `PYTHON_VERSION` environment variable as follows:
```bash
source .venv/bin/activate
PYTHON_VERSION=3.13 ./scripts/unit-tests.sh
```
## Adding a new dependency to the project
To add a new dependency to the project, use the `uv` command. For example, to add `foo`, run:


@ -51,6 +51,10 @@ Here is a list of the various API providers and available distributions that can
| PG Vector | Single Node | | | ✅ | | |
| PyTorch ExecuTorch | On-device iOS | ✅ | ✅ | | | |
| vLLM | Hosted and Single Node | | ✅ | | | |
| OpenAI | Hosted | | ✅ | | | |
| Anthropic | Hosted | | ✅ | | | |
| Gemini | Hosted | | ✅ | | | |
### Distributions


@ -2092,6 +2092,48 @@
}
}
},
"/v1/providers/{provider_id}": {
"get": {
"responses": {
"200": {
"description": "OK",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ProviderInfo"
}
}
}
},
"400": {
"$ref": "#/components/responses/BadRequest400"
},
"429": {
"$ref": "#/components/responses/TooManyRequests429"
},
"500": {
"$ref": "#/components/responses/InternalServerError500"
},
"default": {
"$ref": "#/components/responses/DefaultError"
}
},
"tags": [
"Providers"
],
"description": "",
"parameters": [
{
"name": "provider_id",
"in": "path",
"required": true,
"schema": {
"type": "string"
}
}
]
}
},
"/v1/tool-runtime/invoke": {
"post": {
"responses": {
@ -2660,7 +2702,7 @@
}
}
},
"/v1/inspect/providers": {
"/v1/providers": {
"get": {
"responses": {
"200": {
@ -7965,6 +8007,53 @@
],
"title": "InsertChunksRequest"
},
"ProviderInfo": {
"type": "object",
"properties": {
"api": {
"type": "string"
},
"provider_id": {
"type": "string"
},
"provider_type": {
"type": "string"
},
"config": {
"type": "object",
"additionalProperties": {
"oneOf": [
{
"type": "null"
},
{
"type": "boolean"
},
{
"type": "number"
},
{
"type": "string"
},
{
"type": "array"
},
{
"type": "object"
}
]
}
}
},
"additionalProperties": false,
"required": [
"api",
"provider_id",
"provider_type",
"config"
],
"title": "ProviderInfo"
},
"InvokeToolRequest": {
"type": "object",
"properties": {
@ -8226,27 +8315,6 @@
],
"title": "ListModelsResponse"
},
"ProviderInfo": {
"type": "object",
"properties": {
"api": {
"type": "string"
},
"provider_id": {
"type": "string"
},
"provider_type": {
"type": "string"
}
},
"additionalProperties": false,
"required": [
"api",
"provider_id",
"provider_type"
],
"title": "ProviderInfo"
},
"ListProvidersResponse": {
"type": "object",
"properties": {
@ -10246,6 +10314,10 @@
{
"name": "PostTraining (Coming Soon)"
},
{
"name": "Providers",
"x-displayName": "Providers API for inspecting, listing, and modifying providers and their configurations."
},
{
"name": "Safety"
},
@ -10292,6 +10364,7 @@
"Inspect",
"Models",
"PostTraining (Coming Soon)",
"Providers",
"Safety",
"Scoring",
"ScoringFunctions",


@ -1400,6 +1400,34 @@ paths:
schema:
$ref: '#/components/schemas/InsertChunksRequest'
required: true
/v1/providers/{provider_id}:
get:
responses:
'200':
description: OK
content:
application/json:
schema:
$ref: '#/components/schemas/ProviderInfo'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- Providers
description: ''
parameters:
- name: provider_id
in: path
required: true
schema:
type: string
/v1/tool-runtime/invoke:
post:
responses:
@ -1792,7 +1820,7 @@ paths:
schema:
$ref: '#/components/schemas/RegisterModelRequest'
required: true
/v1/inspect/providers:
/v1/providers:
get:
responses:
'200':
@ -5450,6 +5478,32 @@ components:
- vector_db_id
- chunks
title: InsertChunksRequest
ProviderInfo:
type: object
properties:
api:
type: string
provider_id:
type: string
provider_type:
type: string
config:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
additionalProperties: false
required:
- api
- provider_id
- provider_type
- config
title: ProviderInfo
InvokeToolRequest:
type: object
properties:
@ -5613,21 +5667,6 @@ components:
required:
- data
title: ListModelsResponse
ProviderInfo:
type: object
properties:
api:
type: string
provider_id:
type: string
provider_type:
type: string
additionalProperties: false
required:
- api
- provider_id
- provider_type
title: ProviderInfo
ListProvidersResponse:
type: object
properties:
@ -6921,6 +6960,9 @@ tags:
- name: Inspect
- name: Models
- name: PostTraining (Coming Soon)
- name: Providers
x-displayName: >-
Providers API for inspecting, listing, and modifying providers and their configurations.
- name: Safety
- name: Scoring
- name: ScoringFunctions
@ -6945,6 +6987,7 @@ x-tagGroups:
- Inspect
- Models
- PostTraining (Coming Soon)
- Providers
- Safety
- Scoring
- ScoringFunctions


@ -61,6 +61,10 @@ A number of "adapters" are available for some popular Inference and Vector Store
| Groq | Hosted |
| SambaNova | Hosted |
| PyTorch ExecuTorch | On-device iOS, Android |
| OpenAI | Hosted |
| Anthropic | Hosted |
| Gemini | Hosted |
**Vector IO API**
| **Provider** | **Environments** |


@ -14,6 +14,7 @@ from llama_stack.schema_utils import json_schema_type
@json_schema_type
class Api(Enum):
providers = "providers"
inference = "inference"
safety = "safety"
agents = "agents"


@ -11,13 +11,6 @@ from pydantic import BaseModel
from llama_stack.schema_utils import json_schema_type, webmethod
@json_schema_type
class ProviderInfo(BaseModel):
api: str
provider_id: str
provider_type: str
@json_schema_type
class RouteInfo(BaseModel):
route: str
@ -32,14 +25,21 @@ class HealthInfo(BaseModel):
@json_schema_type
class VersionInfo(BaseModel):
version: str
class ProviderInfo(BaseModel):
api: str
provider_id: str
provider_type: str
class ListProvidersResponse(BaseModel):
data: List[ProviderInfo]
@json_schema_type
class VersionInfo(BaseModel):
version: str
class ListRoutesResponse(BaseModel):
data: List[RouteInfo]


@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .providers import * # noqa: F401 F403


@ -0,0 +1,36 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any, Dict, List, Protocol, runtime_checkable

from pydantic import BaseModel

from llama_stack.schema_utils import json_schema_type, webmethod


@json_schema_type
class ProviderInfo(BaseModel):
    api: str
    provider_id: str
    provider_type: str
    config: Dict[str, Any]


class ListProvidersResponse(BaseModel):
    data: List[ProviderInfo]


@runtime_checkable
class Providers(Protocol):
    """
    Providers API for inspecting, listing, and modifying providers and their configurations.
    """

    @webmethod(route="/providers", method="GET")
    async def list_providers(self) -> ListProvidersResponse: ...

    @webmethod(route="/providers/{provider_id}", method="GET")
    async def inspect_provider(self, provider_id: str) -> ProviderInfo: ...
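Taken together with the OpenAPI changes above, this protocol is served over HTTP as `GET /v1/providers` and `GET /v1/providers/{provider_id}`. Below is a minimal sketch of exercising those endpoints against a locally running stack; the port 8321 comes from the integration-tests workflow, while the `requests` dependency and the `ollama` provider id are illustrative assumptions, not part of this change.

```python
import requests  # assumed to be available; any HTTP client works

BASE_URL = "http://localhost:8321"  # matches the server started in the CI workflow

# GET /v1/providers -> ListProvidersResponse: {"data": [ProviderInfo, ...]}
resp = requests.get(f"{BASE_URL}/v1/providers")
resp.raise_for_status()
for provider in resp.json()["data"]:
    # Each entry carries api, provider_id, provider_type, and the (redacted) config.
    print(provider["api"], provider["provider_id"], provider["provider_type"])

# GET /v1/providers/{provider_id} -> a single ProviderInfo ("ollama" is only an example id)
detail = requests.get(f"{BASE_URL}/v1/providers/ollama")
detail.raise_for_status()
print(detail.json()["config"])
```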


@ -10,7 +10,7 @@ import json
import os
import shutil
from dataclasses import dataclass
from datetime import datetime
from datetime import datetime, timezone
from functools import partial
from pathlib import Path
from typing import Dict, List, Optional
@ -404,7 +404,7 @@ def _download_from_manifest(manifest_file: str, max_concurrent_downloads: int):
d = json.load(f)
manifest = Manifest(**d)
if datetime.now() > manifest.expires_on:
if datetime.now(timezone.utc) > manifest.expires_on:
raise ValueError(f"Manifest URLs have expired on {manifest.expires_on}")
console = Console()
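This switch from naive `datetime.now()` to `datetime.now(timezone.utc)` (repeated across the agent, telemetry, and post-training changes below) is what the new `DTZ` lint rule in `pyproject.toml` enforces; naive and aware datetimes also cannot be compared, so the old expiry check would raise at runtime if `expires_on` carried a timezone. A small illustrative snippet, not part of the diff:

```python
from datetime import datetime, timezone

expires_on = datetime(2025, 3, 14, tzinfo=timezone.utc)  # an aware timestamp, e.g. from a manifest

try:
    datetime.now() > expires_on  # naive vs aware comparison
except TypeError as exc:
    print(exc)  # can't compare offset-naive and offset-aware datetimes

# Comparing two aware datetimes is well defined:
print(datetime.now(timezone.utc) > expires_on)
```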


@ -41,8 +41,8 @@ class ModelPromptFormat(Subcommand):
"-m",
"--model-name",
type=str,
default="llama3_1",
help="Model Family (llama3_1, llama3_X, etc.)",
help="Example: Llama3.1-8B or Llama3.2-11B-Vision, etc\n"
"(Run `llama model list` to see a list of valid model names)",
)
self.parser.add_argument(
"-l",
@ -60,7 +60,6 @@ class ModelPromptFormat(Subcommand):
]
model_list = [m.value for m in supported_model_ids]
model_str = "\n".join(model_list)
if args.list:
headers = ["Model(s)"]
@ -81,10 +80,16 @@ class ModelPromptFormat(Subcommand):
try:
model_id = CoreModelId(args.model_name)
except ValueError:
self.parser.error(f"{args.model_name} is not a valid Model. Choose one from --\n{model_str}")
self.parser.error(
f"{args.model_name} is not a valid Model. Choose one from the list of valid models. "
f"Run `llama model list` to see the valid model names."
)
if model_id not in supported_model_ids:
self.parser.error(f"{model_id} is not a valid Model. Choose one from --\n {model_str}")
self.parser.error(
f"{model_id} is not a valid Model. Choose one from the list of valid models. "
f"Run `llama model list` to see the valid model names."
)
llama_3_1_file = ROOT_DIR / "models" / "llama" / "llama3_1" / "prompt_format.md"
llama_3_2_text_file = ROOT_DIR / "models" / "llama" / "llama3_2" / "text_prompt_format.md"


@ -62,7 +62,7 @@ def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec
if config.apis:
apis_to_serve = config.apis
else:
apis_to_serve = [a.value for a in Api if a not in (Api.telemetry, Api.inspect)]
apis_to_serve = [a.value for a in Api if a not in (Api.telemetry, Api.inspect, Api.providers)]
for api_str in apis_to_serve:
api = Api(api_str)


@ -56,7 +56,7 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
def providable_apis() -> List[Api]:
routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()}
return [api for api in Api if api not in routing_table_apis and api != Api.inspect]
return [api for api in Api if api not in routing_table_apis and api != Api.inspect and api != Api.providers]
def get_provider_registry() -> Dict[Api, Dict[str, ProviderSpec]]:


@ -0,0 +1,59 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from pydantic import BaseModel

from llama_stack.apis.providers import ListProvidersResponse, ProviderInfo, Providers

from .datatypes import StackRunConfig
from .stack import redact_sensitive_fields


class ProviderImplConfig(BaseModel):
    run_config: StackRunConfig


async def get_provider_impl(config, deps):
    impl = ProviderImpl(config, deps)
    await impl.initialize()
    return impl


class ProviderImpl(Providers):
    def __init__(self, config, deps):
        self.config = config
        self.deps = deps

    async def initialize(self) -> None:
        pass

    async def list_providers(self) -> ListProvidersResponse:
        run_config = self.config.run_config
        safe_config = StackRunConfig(**redact_sensitive_fields(run_config.model_dump()))
        ret = []
        for api, providers in safe_config.providers.items():
            ret.extend(
                [
                    ProviderInfo(
                        api=api,
                        provider_id=p.provider_id,
                        provider_type=p.provider_type,
                        config=p.config,
                    )
                    for p in providers
                ]
            )

        return ListProvidersResponse(data=ret)

    async def inspect_provider(self, provider_id: str) -> ProviderInfo:
        all_providers = await self.list_providers()
        for p in all_providers.data:
            if p.provider_id == provider_id:
                return p

        raise ValueError(f"Provider {provider_id} not found")


@ -16,6 +16,7 @@ from llama_stack.apis.inference import Inference
from llama_stack.apis.inspect import Inspect
from llama_stack.apis.models import Models
from llama_stack.apis.post_training import PostTraining
from llama_stack.apis.providers import Providers as ProvidersAPI
from llama_stack.apis.safety import Safety
from llama_stack.apis.scoring import Scoring
from llama_stack.apis.scoring_functions import ScoringFunctions
@ -59,6 +60,7 @@ class InvalidProviderError(Exception):
def api_protocol_map() -> Dict[Api, Any]:
return {
Api.providers: ProvidersAPI,
Api.agents: Agents,
Api.inference: Inference,
Api.inspect: Inspect,
@ -247,6 +249,25 @@ def sort_providers_by_deps(
)
)
sorted_providers.append(
(
"providers",
ProviderWithSpec(
provider_id="__builtin__",
provider_type="__builtin__",
config={"run_config": run_config.model_dump()},
spec=InlineProviderSpec(
api=Api.providers,
provider_type="__builtin__",
config_class="llama_stack.distribution.providers.ProviderImplConfig",
module="llama_stack.distribution.providers",
api_dependencies=apis,
deps__=[x.value for x in apis],
),
),
)
)
logger.debug(f"Resolved {len(sorted_providers)} providers")
for api_str, provider in sorted_providers:
logger.debug(f" {api_str} => {provider.provider_id}")


@ -368,6 +368,7 @@ def main():
apis_to_serve.add(inf.routing_table_api.value)
apis_to_serve.add("inspect")
apis_to_serve.add("providers")
for api_str in apis_to_serve:
api = Api(api_str)


@ -23,6 +23,7 @@ from llama_stack.apis.inference import Inference
from llama_stack.apis.inspect import Inspect
from llama_stack.apis.models import Models
from llama_stack.apis.post_training import PostTraining
from llama_stack.apis.providers import Providers
from llama_stack.apis.safety import Safety
from llama_stack.apis.scoring import Scoring
from llama_stack.apis.scoring_functions import ScoringFunctions
@ -44,6 +45,7 @@ logger = get_logger(name=__name__, category="core")
class LlamaStack(
Providers,
VectorDBs,
Inference,
BatchInference,


@ -34,7 +34,9 @@ class SystemDefaultGenerator(PromptTemplateGeneratorBase):
)
return PromptTemplate(
template_str.lstrip("\n"),
{"today": datetime.now().strftime("%d %B %Y")},
{
"today": datetime.now().strftime("%d %B %Y") # noqa: DTZ005 - we don't care about timezones here since we are displaying the date
},
)
def data_examples(self) -> List[Any]:


@ -11,7 +11,7 @@ import re
import secrets
import string
import uuid
from datetime import datetime
from datetime import datetime, timezone
from typing import AsyncGenerator, List, Optional, Union
from urllib.parse import urlparse
@ -239,7 +239,7 @@ class ChatAgent(ShieldRunnerMixin):
in_progress_tool_call_step = await self.storage.get_in_progress_tool_call_step(
request.session_id, request.turn_id
)
now = datetime.now().astimezone().isoformat()
now = datetime.now(timezone.utc).isoformat()
tool_execution_step = ToolExecutionStep(
step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())),
turn_id=request.turn_id,
@ -264,7 +264,7 @@ class ChatAgent(ShieldRunnerMixin):
start_time = last_turn.started_at
else:
messages.extend(request.messages)
start_time = datetime.now().astimezone().isoformat()
start_time = datetime.now(timezone.utc).isoformat()
input_messages = request.messages
output_message = None
@ -295,7 +295,7 @@ class ChatAgent(ShieldRunnerMixin):
input_messages=input_messages,
output_message=output_message,
started_at=start_time,
completed_at=datetime.now().astimezone().isoformat(),
completed_at=datetime.now(timezone.utc).isoformat(),
steps=steps,
)
await self.storage.add_turn_to_session(request.session_id, turn)
@ -386,7 +386,7 @@ class ChatAgent(ShieldRunnerMixin):
return
step_id = str(uuid.uuid4())
shield_call_start_time = datetime.now().astimezone().isoformat()
shield_call_start_time = datetime.now(timezone.utc).isoformat()
try:
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
@ -410,7 +410,7 @@ class ChatAgent(ShieldRunnerMixin):
turn_id=turn_id,
violation=e.violation,
started_at=shield_call_start_time,
completed_at=datetime.now().astimezone().isoformat(),
completed_at=datetime.now(timezone.utc).isoformat(),
),
)
)
@ -433,7 +433,7 @@ class ChatAgent(ShieldRunnerMixin):
turn_id=turn_id,
violation=None,
started_at=shield_call_start_time,
completed_at=datetime.now().astimezone().isoformat(),
completed_at=datetime.now(timezone.utc).isoformat(),
),
)
)
@ -472,7 +472,7 @@ class ChatAgent(ShieldRunnerMixin):
client_tools[tool.name] = tool
while True:
step_id = str(uuid.uuid4())
inference_start_time = datetime.now().astimezone().isoformat()
inference_start_time = datetime.now(timezone.utc).isoformat()
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepStartPayload(
@ -582,7 +582,7 @@ class ChatAgent(ShieldRunnerMixin):
turn_id=turn_id,
model_response=copy.deepcopy(message),
started_at=inference_start_time,
completed_at=datetime.now().astimezone().isoformat(),
completed_at=datetime.now(timezone.utc).isoformat(),
),
)
)
@ -653,7 +653,7 @@ class ChatAgent(ShieldRunnerMixin):
turn_id=turn_id,
tool_calls=[tool_call],
tool_responses=[],
started_at=datetime.now().astimezone().isoformat(),
started_at=datetime.now(timezone.utc).isoformat(),
),
)
yield message
@ -670,7 +670,7 @@ class ChatAgent(ShieldRunnerMixin):
"input": message.model_dump_json(),
},
) as span:
tool_execution_start_time = datetime.now().astimezone().isoformat()
tool_execution_start_time = datetime.now(timezone.utc).isoformat()
tool_call = message.tool_calls[0]
tool_result = await self.execute_tool_call_maybe(
session_id,
@ -708,7 +708,7 @@ class ChatAgent(ShieldRunnerMixin):
)
],
started_at=tool_execution_start_time,
completed_at=datetime.now().astimezone().isoformat(),
completed_at=datetime.now(timezone.utc).isoformat(),
),
)
)


@ -7,7 +7,7 @@
import json
import logging
import uuid
from datetime import datetime
from datetime import datetime, timezone
from typing import List, Optional
from pydantic import BaseModel
@ -36,7 +36,7 @@ class AgentPersistence:
session_info = AgentSessionInfo(
session_id=session_id,
session_name=name,
started_at=datetime.now(),
started_at=datetime.now(timezone.utc),
)
await self.kvstore.set(
key=f"session:{self.agent_id}:{session_id}",


@ -3,7 +3,7 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from datetime import datetime
from datetime import datetime, timezone
from typing import Any, Dict, Optional
from llama_stack.apis.datasetio import DatasetIO
@ -64,7 +64,7 @@ class TorchtunePostTrainingImpl:
job_status_response = PostTrainingJobStatusResponse(
job_uuid=job_uuid,
status=JobStatus.scheduled,
scheduled_at=datetime.now(),
scheduled_at=datetime.now(timezone.utc),
)
self.jobs[job_uuid] = job_status_response
@ -84,7 +84,7 @@ class TorchtunePostTrainingImpl:
)
job_status_response.status = JobStatus.in_progress
job_status_response.started_at = datetime.now()
job_status_response.started_at = datetime.now(timezone.utc)
await recipe.setup()
resources_allocated, checkpoints = await recipe.train()
@ -93,7 +93,7 @@ class TorchtunePostTrainingImpl:
job_status_response.resources_allocated = resources_allocated
job_status_response.checkpoints = checkpoints
job_status_response.status = JobStatus.completed
job_status_response.completed_at = datetime.now()
job_status_response.completed_at = datetime.now(timezone.utc)
except Exception:
job_status_response.status = JobStatus.failed


@ -8,7 +8,7 @@ import gc
import logging
import os
import time
from datetime import datetime
from datetime import datetime, timezone
from functools import partial
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
@ -532,7 +532,7 @@ class LoraFinetuningSingleDevice:
checkpoint_path = await self.save_checkpoint(epoch=curr_epoch)
checkpoint = Checkpoint(
identifier=f"{self.model_id}-sft-{curr_epoch}",
created_at=datetime.now(),
created_at=datetime.now(timezone.utc),
epoch=curr_epoch,
post_training_job_id=self.job_uuid,
path=checkpoint_path,


@ -5,7 +5,7 @@
# the root directory of this source tree.
import json
from datetime import datetime
from datetime import datetime, timezone
from opentelemetry.sdk.trace import ReadableSpan
from opentelemetry.sdk.trace.export import SpanProcessor
@ -34,7 +34,7 @@ class ConsoleSpanProcessor(SpanProcessor):
if span.attributes and span.attributes.get("__autotraced__"):
return
timestamp = datetime.utcfromtimestamp(span.start_time / 1e9).strftime("%H:%M:%S.%f")[:-3]
timestamp = datetime.fromtimestamp(span.start_time / 1e9, tz=timezone.utc).strftime("%H:%M:%S.%f")[:-3]
print(
f"{COLORS['dim']}{timestamp}{COLORS['reset']} "
@ -46,7 +46,7 @@ class ConsoleSpanProcessor(SpanProcessor):
if span.attributes and span.attributes.get("__autotraced__"):
return
timestamp = datetime.utcfromtimestamp(span.end_time / 1e9).strftime("%H:%M:%S.%f")[:-3]
timestamp = datetime.fromtimestamp(span.end_time / 1e9, tz=timezone.utc).strftime("%H:%M:%S.%f")[:-3]
span_context = (
f"{COLORS['dim']}{timestamp}{COLORS['reset']} "
@ -74,7 +74,7 @@ class ConsoleSpanProcessor(SpanProcessor):
print(f" {COLORS['dim']}{key}: {str_value}{COLORS['reset']}")
for event in span.events:
event_time = datetime.utcfromtimestamp(event.timestamp / 1e9).strftime("%H:%M:%S.%f")[:-3]
event_time = datetime.fromtimestamp(event.timestamp / 1e9, tz=timezone.utc).strftime("%H:%M:%S.%f")[:-3]
severity = event.attributes.get("severity", "info")
message = event.attributes.get("message", event.name)


@ -8,7 +8,7 @@ import json
import os
import sqlite3
import threading
from datetime import datetime
from datetime import datetime, timezone
from opentelemetry.sdk.trace import SpanProcessor
from opentelemetry.trace import Span
@ -124,8 +124,8 @@ class SQLiteSpanProcessor(SpanProcessor):
trace_id,
service_name,
(span_id if not parent_span_id else None),
datetime.fromtimestamp(span.start_time / 1e9).isoformat(),
datetime.fromtimestamp(span.end_time / 1e9).isoformat(),
datetime.fromtimestamp(span.start_time / 1e9, timezone.utc).isoformat(),
datetime.fromtimestamp(span.end_time / 1e9, timezone.utc).isoformat(),
),
)
@ -143,8 +143,8 @@ class SQLiteSpanProcessor(SpanProcessor):
trace_id,
parent_span_id,
span.name,
datetime.fromtimestamp(span.start_time / 1e9).isoformat(),
datetime.fromtimestamp(span.end_time / 1e9).isoformat(),
datetime.fromtimestamp(span.start_time / 1e9, timezone.utc).isoformat(),
datetime.fromtimestamp(span.end_time / 1e9, timezone.utc).isoformat(),
json.dumps(dict(span.attributes)),
span.status.status_code.name,
span.kind.name,
@ -161,7 +161,7 @@ class SQLiteSpanProcessor(SpanProcessor):
(
span_id,
event.name,
datetime.fromtimestamp(event.timestamp / 1e9).isoformat(),
datetime.fromtimestamp(event.timestamp / 1e9, timezone.utc).isoformat(),
json.dumps(dict(event.attributes)),
),
)
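OpenTelemetry spans report `start_time`, `end_time`, and event timestamps as integer nanoseconds since the epoch, so the conversion divides by 1e9; passing `timezone.utc` pins the stored ISO strings to UTC instead of the machine's local zone. For example (illustrative value only):

```python
from datetime import datetime, timezone

start_time_ns = 1_741_980_000_123_456_789  # assumed span.start_time in nanoseconds

# Old form produced a naive, local-time string that varied by machine:
#   datetime.fromtimestamp(start_time_ns / 1e9).isoformat()
# New form is explicit UTC, as written to the traces/spans tables:
print(datetime.fromtimestamp(start_time_ns / 1e9, timezone.utc).isoformat())
# e.g. 2025-03-14T19:20:00.123457+00:00
```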


@ -168,7 +168,7 @@ def process_matplotlib_response(response, matplotlib_dump_dir: str):
image_paths = []
for i, img in enumerate(images):
# create new directory for each day to better organize data:
dump_dname = datetime.today().strftime("%Y-%m-%d")
dump_dname = datetime.today().strftime("%Y-%m-%d") # noqa: DTZ002 - we don't care about timezones here since we are displaying the date
dump_dpath = Path(matplotlib_dump_dir, dump_dname)
dump_dpath.mkdir(parents=True, exist_ok=True)
# save image into a file


@ -11,7 +11,7 @@ import logging
import queue
import threading
import uuid
from datetime import datetime
from datetime import datetime, timezone
from functools import wraps
from typing import Any, Callable, Dict, List, Optional
@ -86,7 +86,7 @@ class TraceContext:
span_id=generate_short_uuid(),
trace_id=self.trace_id,
name=name,
start_time=datetime.now(),
start_time=datetime.now(timezone.utc),
parent_span_id=current_span.span_id if current_span else None,
attributes=attributes,
)
@ -203,7 +203,7 @@ class TelemetryHandler(logging.Handler):
UnstructuredLogEvent(
trace_id=span.trace_id,
span_id=span.span_id,
timestamp=datetime.now(),
timestamp=datetime.now(timezone.utc),
message=self.format(record),
severity=severity(record.levelname),
)


@ -124,14 +124,15 @@ exclude = [
[tool.ruff.lint]
select = [
"B", # flake8-bugbear
"B9", # flake8-bugbear subset
"C", # comprehensions
"E", # pycodestyle
"F", # Pyflakes
"N", # Naming
"W", # Warnings
"I", # isort
"B", # flake8-bugbear
"B9", # flake8-bugbear subset
"C", # comprehensions
"E", # pycodestyle
"F", # Pyflakes
"N", # Naming
"W", # Warnings
"I", # isort
"DTZ", # datetime rules
]
ignore = [
# The following ignores are desired by the project maintainers.
@ -145,6 +146,10 @@ ignore = [
"C901", # Complexity of the function is too high
]
# Ignore the following errors for the following files
[tool.ruff.lint.per-file-ignores]
"tests/**/*.py" = ["DTZ"] # Ignore datetime rules for tests
[tool.mypy]
mypy_path = ["llama_stack"]
packages = ["llama_stack"]
@ -170,6 +175,7 @@ exclude = [
"^llama_stack/apis/inspect/inspect\\.py$",
"^llama_stack/apis/models/models\\.py$",
"^llama_stack/apis/post_training/post_training\\.py$",
"^llama_stack/apis/providers/providers\\.py$",
"^llama_stack/apis/resource\\.py$",
"^llama_stack/apis/safety/safety\\.py$",
"^llama_stack/apis/scoring/scoring\\.py$",

scripts/unit-tests.sh (new executable file, 19 lines)

@ -0,0 +1,19 @@
#!/bin/sh

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

PYTHON_VERSION=${PYTHON_VERSION:-3.10}

command -v uv >/dev/null 2>&1 || { echo >&2 "uv is required but it's not installed. Exiting."; exit 1; }

uv python find $PYTHON_VERSION
FOUND_PYTHON=$?
if [ $FOUND_PYTHON -ne 0 ]; then
    uv python install $PYTHON_VERSION
fi

uv run --python $PYTHON_VERSION --with-editable . --with-editable ".[dev]" --with-editable ".[unit]" pytest -s -v tests/unit/ $@


@ -52,6 +52,8 @@ def llama_stack_client_with_mocked_inference(llama_stack_client, request):
If --record-responses is passed, it will call the real APIs and record the responses.
"""
# TODO: will rework this to be more stable
return llama_stack_client
if not isinstance(llama_stack_client, LlamaStackAsLibraryClient):
logging.warning(
"llama_stack_client_with_mocked_inference is not supported for this client, returning original client without mocking"

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long


@ -36,7 +36,7 @@ def test_image_chat_completion_non_streaming(client_with_models, vision_model_id
"type": "image",
"image": {
"url": {
"uri": "https://raw.githubusercontent.com/meta-llama/llama-stack/main/tests/api/inference/dog.png"
"uri": "https://raw.githubusercontent.com/meta-llama/llama-stack/main/tests/integration/inference/dog.png"
},
},
},
@ -65,7 +65,7 @@ def test_image_chat_completion_streaming(client_with_models, vision_model_id):
"type": "image",
"image": {
"url": {
"uri": "https://raw.githubusercontent.com/meta-llama/llama-stack/main/tests/api/inference/dog.png"
"uri": "https://raw.githubusercontent.com/meta-llama/llama-stack/main/tests/integration/inference/dog.png"
},
},
},


@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.


@ -0,0 +1,17 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import pytest
from llama_stack_client import LlamaStackClient

from llama_stack import LlamaStackAsLibraryClient


class TestProviders:
    @pytest.mark.asyncio
    def test_list(self, llama_stack_client: LlamaStackAsLibraryClient | LlamaStackClient):
        provider_list = llama_stack_client.providers.list()
        assert provider_list is not None
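A natural follow-up (not part of this PR) would be to assert on the `config` field that `ProviderInfo` now exposes. A hedged sketch, assuming the client returns ProviderInfo-like objects carrying the fields defined above:

```python
# Hypothetical extra test; field names come from the ProviderInfo model in this PR,
# but the exact return type of providers.list() is an assumption here.
class TestProvidersConfig:
    def test_providers_expose_config(self, llama_stack_client):
        providers = llama_stack_client.providers.list()
        assert providers is not None
        for provider in providers:
            assert provider.provider_id
            assert provider.provider_type
            assert provider.config is not None  # config is the field added by this PR
```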