forked from phoenix-oss/llama-stack-mirror
Merge branch 'main' into pr1573
This commit is contained in commit 0e2a13da9c
39 changed files with 5311 additions and 407 deletions
.github/workflows/integration-tests.yml (new file, 80 lines, vendored)

@@ -0,0 +1,80 @@
name: Integration tests

on:
  pull_request:
  push:
    branches: [main]

jobs:
  ollama:
    runs-on: ubuntu-latest

    steps:
      - name: Checkout repository
        uses: actions/checkout@v4

      - name: Install uv
        uses: astral-sh/setup-uv@v5
        with:
          python-version: "3.10"

      - name: Install Ollama
        run: |
          curl -fsSL https://ollama.com/install.sh | sh

      - name: Pull Ollama image
        run: |
          ollama pull llama3.2:3b-instruct-fp16

      - name: Start Ollama in background
        run: |
          nohup ollama run llama3.2:3b-instruct-fp16 > ollama.log 2>&1 &

      - name: Set Up Environment and Install Dependencies
        run: |
          uv sync --extra dev --extra test
          uv pip install ollama faiss-cpu
          uv pip install -e .

      - name: Wait for Ollama to start
        run: |
          echo "Waiting for Ollama..."
          for i in {1..30}; do
            if curl -s http://localhost:11434 | grep -q "Ollama is running"; then
              echo "Ollama is running!"
              exit 0
            fi
            sleep 1
          done
          echo "Ollama failed to start"
          ollama ps
          cat ollama.log
          exit 1

      - name: Start Llama Stack server in background
        env:
          INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct"
        run: |
          source .venv/bin/activate
          # TODO: use "llama stack run"
          nohup uv run python -m llama_stack.distribution.server.server --yaml-config ./llama_stack/templates/ollama/run.yaml > server.log 2>&1 &

      - name: Wait for Llama Stack server to be ready
        run: |
          echo "Waiting for Llama Stack server..."
          for i in {1..30}; do
            if curl -s http://localhost:8321/v1/health | grep -q "OK"; then
              echo "Llama Stack server is up!"
              exit 0
            fi
            sleep 1
          done
          echo "Llama Stack server failed to start"
          cat server.log
          exit 1

      - name: Run Inference Integration Tests
        env:
          INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct"
        run: |
          uv run pytest -v tests/integration/inference --stack-config=ollama --text-model="meta-llama/Llama-3.2-3B-Instruct" --embedding-model=all-MiniLM-L6-v2
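The readiness checks above poll plain HTTP endpoints, so the same wait logic is easy to reproduce outside of CI. Below is a minimal Python sketch of the same polling loop (standard library only); the URLs and the 30-second budget mirror the workflow, everything else is illustrative.

```python
import time
import urllib.error
import urllib.request


def wait_for(url: str, needle: str, timeout_s: int = 30) -> bool:
    """Poll `url` once per second until its body contains `needle` or we time out."""
    for _ in range(timeout_s):
        try:
            with urllib.request.urlopen(url, timeout=2) as resp:
                if needle in resp.read().decode("utf-8", errors="replace"):
                    return True
        except (urllib.error.URLError, OSError):
            pass  # server not accepting connections yet
        time.sleep(1)
    return False


# Same checks the workflow performs with curl + grep:
assert wait_for("http://localhost:11434", "Ollama is running")
assert wait_for("http://localhost:8321/v1/health", "OK")
```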
.github/workflows/providers-build.yml (2 changes, vendored)

@@ -51,7 +51,7 @@ jobs:
          python-version: '3.10'

      - name: Install uv
-       uses: astral-sh/setup-uv@v4
+       uses: astral-sh/setup-uv@v5
        with:
          python-version: "3.10"
.github/workflows/stale_bot.yml (new file, 45 lines, vendored)

@@ -0,0 +1,45 @@
name: Close stale issues and PRs

on:
  schedule:
    - cron: '0 0 * * *' # every day at midnight

env:
  LC_ALL: en_US.UTF-8

defaults:
  run:
    shell: bash

permissions:
  contents: read

jobs:
  stale:
    permissions:
      issues: write
      pull-requests: write
    runs-on: ubuntu-latest
    steps:
      - name: Stale Action
        uses: actions/stale@v9
        with:
          stale-issue-label: 'stale'
          stale-issue-message: >
            This issue has been automatically marked as stale because it has not had activity within 60 days.
            It will be automatically closed if no further activity occurs within 30 days.
          close-issue-message: >
            This issue has been automatically closed due to inactivity.
            Please feel free to reopen if you feel it is still relevant!
          days-before-issue-stale: 60
          days-before-issue-close: 30
          stale-pr-label: 'stale'
          stale-pr-message: >
            This pull request has been automatically marked as stale because it has not had activity within 60 days.
            It will be automatically closed if no further activity occurs within 30 days.
          close-pr-message: >
            This pull request has been automatically closed due to inactivity.
            Please feel free to reopen if you intend to continue working on it!
          days-before-pr-stale: 60
          days-before-pr-close: 30
          operations-per-run: 300
.github/workflows/unit-tests.yml (2 changes, vendored)

@@ -33,7 +33,7 @@ jobs:

      - name: Run unit tests
        run: |
-         uv run --python ${{ matrix.python }} --with-editable . --with-editable ".[dev]" --with-editable ".[unit]" pytest --cov=llama_stack -s -v tests/unit/ --junitxml=pytest-report-${{ matrix.python }}.xml --cov-report=html:htmlcov-${{ matrix.python }}
+         PYTHON_VERSION=${{ matrix.python }} ./scripts/unit-tests.sh --cov=llama_stack --junitxml=pytest-report-${{ matrix.python }}.xml --cov-report=html:htmlcov-${{ matrix.python }}

      - name: Upload test results
        if: always()
@@ -8,6 +8,7 @@ repos:
    rev: v5.0.0  # Latest stable version
    hooks:
      - id: check-merge-conflict
        args: ['--assume-in-merge']
      - id: trailing-whitespace
        exclude: '\.py$'  # Exclude Python files as Ruff already handles them
      - id: check-added-large-files

@@ -82,6 +83,17 @@ repos:
        require_serial: true
        files: ^llama_stack/templates/.*$|^llama_stack/providers/.*/inference/.*/models\.py$

  - repo: local
    hooks:
      - id: openapi-codegen
        name: API Spec Codegen
        additional_dependencies:
          - uv==0.6.2
        entry: sh -c 'uv run --with ".[dev]" ./docs/openapi_generator/run_openapi_generator.sh > /dev/null 2>&1'
        language: python
        pass_filenames: false
        require_serial: true

ci:
  autofix_commit_msg: 🎨 [pre-commit.ci] Auto format from pre-commit.com hooks
  autoupdate_commit_msg: ⬆ [pre-commit.ci] pre-commit autoupdate
@@ -108,6 +108,22 @@ uv run pre-commit run --all-files

> [!CAUTION]
> Before pushing your changes, make sure that the pre-commit hooks have passed successfully.

## Running unit tests

You can run the unit tests by running:

```bash
source .venv/bin/activate
./scripts/unit-tests.sh
```

If you'd like to run for a non-default version of Python (currently 3.10), pass `PYTHON_VERSION` variable as follows:

```bash
source .venv/bin/activate
PYTHON_VERSION=3.13 ./scripts/unit-tests.sh
```

## Adding a new dependency to the project

To add a new dependency to the project, you can use the `uv` command. For example, to add `foo` to the project, you can run:
@@ -51,6 +51,10 @@ Here is a list of the various API providers and available distributions that can
| PG Vector | Single Node | | | ✅ | | |
| PyTorch ExecuTorch | On-device iOS | ✅ | ✅ | | | |
| vLLM | Hosted and Single Node | | ✅ | | | |
| OpenAI | Hosted | | ✅ | | | |
| Anthropic | Hosted | | ✅ | | | |
| Gemini | Hosted | | ✅ | | | |

### Distributions
docs/_static/llama-stack-spec.html (117 changes, vendored)

@@ -2092,6 +2092,48 @@
            }
        }
    },
    "/v1/providers/{provider_id}": {
        "get": {
            "responses": {
                "200": {
                    "description": "OK",
                    "content": {
                        "application/json": {
                            "schema": {
                                "$ref": "#/components/schemas/ProviderInfo"
                            }
                        }
                    }
                },
                "400": {
                    "$ref": "#/components/responses/BadRequest400"
                },
                "429": {
                    "$ref": "#/components/responses/TooManyRequests429"
                },
                "500": {
                    "$ref": "#/components/responses/InternalServerError500"
                },
                "default": {
                    "$ref": "#/components/responses/DefaultError"
                }
            },
            "tags": [
                "Providers"
            ],
            "description": "",
            "parameters": [
                {
                    "name": "provider_id",
                    "in": "path",
                    "required": true,
                    "schema": {
                        "type": "string"
                    }
                }
            ]
        }
    },
    "/v1/tool-runtime/invoke": {
        "post": {
            "responses": {

@@ -2660,7 +2702,7 @@
            }
        }
    },
-   "/v1/inspect/providers": {
+   "/v1/providers": {
        "get": {
            "responses": {
                "200": {

@@ -7965,6 +8007,53 @@
            ],
            "title": "InsertChunksRequest"
        },
        "ProviderInfo": {
            "type": "object",
            "properties": {
                "api": {
                    "type": "string"
                },
                "provider_id": {
                    "type": "string"
                },
                "provider_type": {
                    "type": "string"
                },
                "config": {
                    "type": "object",
                    "additionalProperties": {
                        "oneOf": [
                            {
                                "type": "null"
                            },
                            {
                                "type": "boolean"
                            },
                            {
                                "type": "number"
                            },
                            {
                                "type": "string"
                            },
                            {
                                "type": "array"
                            },
                            {
                                "type": "object"
                            }
                        ]
                    }
                }
            },
            "additionalProperties": false,
            "required": [
                "api",
                "provider_id",
                "provider_type",
                "config"
            ],
            "title": "ProviderInfo"
        },
        "InvokeToolRequest": {
            "type": "object",
            "properties": {

@@ -8226,27 +8315,6 @@
            ],
            "title": "ListModelsResponse"
        },
-       "ProviderInfo": {
-           "type": "object",
-           "properties": {
-               "api": {
-                   "type": "string"
-               },
-               "provider_id": {
-                   "type": "string"
-               },
-               "provider_type": {
-                   "type": "string"
-               }
-           },
-           "additionalProperties": false,
-           "required": [
-               "api",
-               "provider_id",
-               "provider_type"
-           ],
-           "title": "ProviderInfo"
-       },
        "ListProvidersResponse": {
            "type": "object",
            "properties": {

@@ -10246,6 +10314,10 @@
        {
            "name": "PostTraining (Coming Soon)"
        },
        {
            "name": "Providers",
            "x-displayName": "Providers API for inspecting, listing, and modifying providers and their configurations."
        },
        {
            "name": "Safety"
        },

@@ -10292,6 +10364,7 @@
        "Inspect",
        "Models",
        "PostTraining (Coming Soon)",
        "Providers",
        "Safety",
        "Scoring",
        "ScoringFunctions",
docs/_static/llama-stack-spec.yaml (75 changes, vendored)

@@ -1400,6 +1400,34 @@ paths:
          schema:
            $ref: '#/components/schemas/InsertChunksRequest'
        required: true
  /v1/providers/{provider_id}:
    get:
      responses:
        '200':
          description: OK
          content:
            application/json:
              schema:
                $ref: '#/components/schemas/ProviderInfo'
        '400':
          $ref: '#/components/responses/BadRequest400'
        '429':
          $ref: >-
            #/components/responses/TooManyRequests429
        '500':
          $ref: >-
            #/components/responses/InternalServerError500
        default:
          $ref: '#/components/responses/DefaultError'
      tags:
        - Providers
      description: ''
      parameters:
        - name: provider_id
          in: path
          required: true
          schema:
            type: string
  /v1/tool-runtime/invoke:
    post:
      responses:

@@ -1792,7 +1820,7 @@ paths:
          schema:
            $ref: '#/components/schemas/RegisterModelRequest'
        required: true
- /v1/inspect/providers:
+ /v1/providers:
    get:
      responses:
        '200':

@@ -5450,6 +5478,32 @@ components:
        - vector_db_id
        - chunks
      title: InsertChunksRequest
    ProviderInfo:
      type: object
      properties:
        api:
          type: string
        provider_id:
          type: string
        provider_type:
          type: string
        config:
          type: object
          additionalProperties:
            oneOf:
              - type: 'null'
              - type: boolean
              - type: number
              - type: string
              - type: array
              - type: object
      additionalProperties: false
      required:
        - api
        - provider_id
        - provider_type
        - config
      title: ProviderInfo
    InvokeToolRequest:
      type: object
      properties:

@@ -5613,21 +5667,6 @@ components:
      required:
        - data
      title: ListModelsResponse
-   ProviderInfo:
-     type: object
-     properties:
-       api:
-         type: string
-       provider_id:
-         type: string
-       provider_type:
-         type: string
-     additionalProperties: false
-     required:
-       - api
-       - provider_id
-       - provider_type
-     title: ProviderInfo
    ListProvidersResponse:
      type: object
      properties:

@@ -6921,6 +6960,9 @@ tags:
  - name: Inspect
  - name: Models
  - name: PostTraining (Coming Soon)
  - name: Providers
    x-displayName: >-
      Providers API for inspecting, listing, and modifying providers and their configurations.
  - name: Safety
  - name: Scoring
  - name: ScoringFunctions

@@ -6945,6 +6987,7 @@ x-tagGroups:
      - Inspect
      - Models
      - PostTraining (Coming Soon)
      - Providers
      - Safety
      - Scoring
      - ScoringFunctions
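Taken together, the schema changes above say that `GET /v1/providers/{provider_id}` returns a `ProviderInfo` object whose `config` map may hold nulls, scalars, arrays, or nested objects. A hypothetical decoded response could look like the sketch below; the provider id and config values are illustrative, not taken from any shipped template.

```python
# Hypothetical decoded JSON body from GET /v1/providers/ollama
provider_info = {
    "api": "inference",
    "provider_id": "ollama",              # illustrative id
    "provider_type": "remote::ollama",    # illustrative type string
    "config": {
        "url": "http://localhost:11434",  # illustrative config entry
    },
}

# Every field the ProviderInfo schema marks as required is present.
assert {"api", "provider_id", "provider_type", "config"} <= provider_info.keys()
```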
@@ -61,6 +61,10 @@ A number of "adapters" are available for some popular Inference and Vector Store
| Groq | Hosted |
| SambaNova | Hosted |
| PyTorch ExecuTorch | On-device iOS, Android |
| OpenAI | Hosted |
| Anthropic | Hosted |
| Gemini | Hosted |

**Vector IO API**
| **Provider** | **Environments** |
@@ -14,6 +14,7 @@ from llama_stack.schema_utils import json_schema_type

@json_schema_type
class Api(Enum):
    providers = "providers"
    inference = "inference"
    safety = "safety"
    agents = "agents"
@@ -11,13 +11,6 @@ from pydantic import BaseModel
from llama_stack.schema_utils import json_schema_type, webmethod


-@json_schema_type
-class ProviderInfo(BaseModel):
-    api: str
-    provider_id: str
-    provider_type: str
-
-
@json_schema_type
class RouteInfo(BaseModel):
    route: str

@@ -32,14 +25,21 @@ class HealthInfo(BaseModel):


@json_schema_type
class VersionInfo(BaseModel):
    version: str
class ProviderInfo(BaseModel):
    api: str
    provider_id: str
    provider_type: str


class ListProvidersResponse(BaseModel):
    data: List[ProviderInfo]


@json_schema_type
class VersionInfo(BaseModel):
    version: str


class ListRoutesResponse(BaseModel):
    data: List[RouteInfo]
llama_stack/apis/providers/__init__.py (new file, 7 lines)

@@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .providers import *  # noqa: F401 F403
llama_stack/apis/providers/providers.py (new file, 36 lines)

@@ -0,0 +1,36 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any, Dict, List, Protocol, runtime_checkable

from pydantic import BaseModel

from llama_stack.schema_utils import json_schema_type, webmethod


@json_schema_type
class ProviderInfo(BaseModel):
    api: str
    provider_id: str
    provider_type: str
    config: Dict[str, Any]


class ListProvidersResponse(BaseModel):
    data: List[ProviderInfo]


@runtime_checkable
class Providers(Protocol):
    """
    Providers API for inspecting, listing, and modifying providers and their configurations.
    """

    @webmethod(route="/providers", method="GET")
    async def list_providers(self) -> ListProvidersResponse: ...

    @webmethod(route="/providers/{provider_id}", method="GET")
    async def inspect_provider(self, provider_id: str) -> ProviderInfo: ...
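For orientation, here is a small client-side sketch of the new API. The `providers.list()` call mirrors the integration test added later in this diff; the raw `GET /v1/providers/{provider_id}` request uses the route from the OpenAPI spec above. The base URL and the "ollama" id are assumptions (the port matches the CI workflow), and the exact shape of the client's return objects is assumed to expose the `ProviderInfo` fields.

```python
import urllib.request

from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")  # port assumed from the CI workflow

# List every configured provider (same call the new integration test makes).
for p in client.providers.list():
    print(p.api, p.provider_id, p.provider_type)

# The per-provider route can also be hit directly; the path comes from the spec above,
# and "ollama" is an illustrative provider id.
with urllib.request.urlopen("http://localhost:8321/v1/providers/ollama") as resp:
    print(resp.read().decode())
```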
@@ -10,7 +10,7 @@ import json
import os
import shutil
from dataclasses import dataclass
-from datetime import datetime
+from datetime import datetime, timezone
from functools import partial
from pathlib import Path
from typing import Dict, List, Optional

@@ -404,7 +404,7 @@ def _download_from_manifest(manifest_file: str, max_concurrent_downloads: int):
        d = json.load(f)
        manifest = Manifest(**d)

-       if datetime.now() > manifest.expires_on:
+       if datetime.now(timezone.utc) > manifest.expires_on:
            raise ValueError(f"Manifest URLs have expired on {manifest.expires_on}")

        console = Console()
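This file is the first of many below that swap `datetime.now()` for `datetime.now(timezone.utc)`. The practical difference is that naive and timezone-aware datetimes cannot be compared, and naive timestamps shift meaning with the host's local timezone. A minimal illustration (not code from this repository):

```python
from datetime import datetime, timezone

aware = datetime.now(timezone.utc)  # carries tzinfo, names an unambiguous instant
naive = datetime.now()              # local wall-clock time, no tzinfo

try:
    naive > aware
except TypeError as exc:
    # Mixing the two styles fails at runtime; the DTZ ruff rules added to
    # pyproject.toml later in this diff flag such call sites ahead of time.
    print(exc)  # can't compare offset-naive and offset-aware datetimes
```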
@@ -41,8 +41,8 @@ class ModelPromptFormat(Subcommand):
            "-m",
            "--model-name",
            type=str,
            default="llama3_1",
            help="Model Family (llama3_1, llama3_X, etc.)",
            help="Example: Llama3.1-8B or Llama3.2-11B-Vision, etc\n"
            "(Run `llama model list` to see a list of valid model names)",
        )
        self.parser.add_argument(
            "-l",

@@ -60,7 +60,6 @@ class ModelPromptFormat(Subcommand):
        ]

        model_list = [m.value for m in supported_model_ids]
        model_str = "\n".join(model_list)

        if args.list:
            headers = ["Model(s)"]

@@ -81,10 +80,16 @@ class ModelPromptFormat(Subcommand):
        try:
            model_id = CoreModelId(args.model_name)
        except ValueError:
-           self.parser.error(f"{args.model_name} is not a valid Model. Choose one from --\n{model_str}")
+           self.parser.error(
+               f"{args.model_name} is not a valid Model. Choose one from the list of valid models. "
+               f"Run `llama model list` to see the valid model names."
+           )

        if model_id not in supported_model_ids:
-           self.parser.error(f"{model_id} is not a valid Model. Choose one from --\n {model_str}")
+           self.parser.error(
+               f"{model_id} is not a valid Model. Choose one from the list of valid models. "
+               f"Run `llama model list` to see the valid model names."
+           )

        llama_3_1_file = ROOT_DIR / "models" / "llama" / "llama3_1" / "prompt_format.md"
        llama_3_2_text_file = ROOT_DIR / "models" / "llama" / "llama3_2" / "text_prompt_format.md"
@@ -62,7 +62,7 @@ def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec
    if config.apis:
        apis_to_serve = config.apis
    else:
-       apis_to_serve = [a.value for a in Api if a not in (Api.telemetry, Api.inspect)]
+       apis_to_serve = [a.value for a in Api if a not in (Api.telemetry, Api.inspect, Api.providers)]

    for api_str in apis_to_serve:
        api = Api(api_str)
@@ -56,7 +56,7 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:

def providable_apis() -> List[Api]:
    routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()}
-   return [api for api in Api if api not in routing_table_apis and api != Api.inspect]
+   return [api for api in Api if api not in routing_table_apis and api != Api.inspect and api != Api.providers]


def get_provider_registry() -> Dict[Api, Dict[str, ProviderSpec]]:
llama_stack/distribution/providers.py (new file, 59 lines)

@@ -0,0 +1,59 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.


from pydantic import BaseModel

from llama_stack.apis.providers import ListProvidersResponse, ProviderInfo, Providers

from .datatypes import StackRunConfig
from .stack import redact_sensitive_fields


class ProviderImplConfig(BaseModel):
    run_config: StackRunConfig


async def get_provider_impl(config, deps):
    impl = ProviderImpl(config, deps)
    await impl.initialize()
    return impl


class ProviderImpl(Providers):
    def __init__(self, config, deps):
        self.config = config
        self.deps = deps

    async def initialize(self) -> None:
        pass

    async def list_providers(self) -> ListProvidersResponse:
        run_config = self.config.run_config
        safe_config = StackRunConfig(**redact_sensitive_fields(run_config.model_dump()))
        ret = []
        for api, providers in safe_config.providers.items():
            ret.extend(
                [
                    ProviderInfo(
                        api=api,
                        provider_id=p.provider_id,
                        provider_type=p.provider_type,
                        config=p.config,
                    )
                    for p in providers
                ]
            )

        return ListProvidersResponse(data=ret)

    async def inspect_provider(self, provider_id: str) -> ProviderInfo:
        all_providers = await self.list_providers()
        for p in all_providers.data:
            if p.provider_id == provider_id:
                return p

        raise ValueError(f"Provider {provider_id} not found")
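Two behaviors of `ProviderImpl` are worth noting: provider configs pass through `redact_sensitive_fields` before being returned, so secrets in the run config are masked rather than echoed back, and the implementation simply flattens the run config's `{api: [providers]}` mapping into `ProviderInfo` rows. A minimal sketch of that flattening, using hand-written stand-in entries (the ids, types, and config values below are illustrative, not taken from a real run.yaml):

```python
from types import SimpleNamespace

from llama_stack.apis.providers import ListProvidersResponse, ProviderInfo

# Stand-in for run_config.providers; it only needs the fields ProviderImpl reads.
providers_by_api = {
    "inference": [
        SimpleNamespace(provider_id="ollama", provider_type="remote::ollama",
                        config={"url": "http://localhost:11434"}),
    ],
    "vector_io": [
        SimpleNamespace(provider_id="faiss", provider_type="inline::faiss", config={}),
    ],
}

# The same flattening that ProviderImpl.list_providers performs.
rows = [
    ProviderInfo(api=api, provider_id=p.provider_id, provider_type=p.provider_type, config=p.config)
    for api, plist in providers_by_api.items()
    for p in plist
]
response = ListProvidersResponse(data=rows)
assert {p.provider_id for p in response.data} == {"ollama", "faiss"}
```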
@@ -16,6 +16,7 @@ from llama_stack.apis.inference import Inference
from llama_stack.apis.inspect import Inspect
from llama_stack.apis.models import Models
from llama_stack.apis.post_training import PostTraining
from llama_stack.apis.providers import Providers as ProvidersAPI
from llama_stack.apis.safety import Safety
from llama_stack.apis.scoring import Scoring
from llama_stack.apis.scoring_functions import ScoringFunctions

@@ -59,6 +60,7 @@ class InvalidProviderError(Exception):

def api_protocol_map() -> Dict[Api, Any]:
    return {
        Api.providers: ProvidersAPI,
        Api.agents: Agents,
        Api.inference: Inference,
        Api.inspect: Inspect,

@@ -247,6 +249,25 @@ def sort_providers_by_deps(
        )
    )

    sorted_providers.append(
        (
            "providers",
            ProviderWithSpec(
                provider_id="__builtin__",
                provider_type="__builtin__",
                config={"run_config": run_config.model_dump()},
                spec=InlineProviderSpec(
                    api=Api.providers,
                    provider_type="__builtin__",
                    config_class="llama_stack.distribution.providers.ProviderImplConfig",
                    module="llama_stack.distribution.providers",
                    api_dependencies=apis,
                    deps__=[x.value for x in apis],
                ),
            ),
        )
    )

    logger.debug(f"Resolved {len(sorted_providers)} providers")
    for api_str, provider in sorted_providers:
        logger.debug(f" {api_str} => {provider.provider_id}")
@@ -368,6 +368,7 @@ def main():
        apis_to_serve.add(inf.routing_table_api.value)

    apis_to_serve.add("inspect")
    apis_to_serve.add("providers")
    for api_str in apis_to_serve:
        api = Api(api_str)
@@ -23,6 +23,7 @@ from llama_stack.apis.inference import Inference
from llama_stack.apis.inspect import Inspect
from llama_stack.apis.models import Models
from llama_stack.apis.post_training import PostTraining
from llama_stack.apis.providers import Providers
from llama_stack.apis.safety import Safety
from llama_stack.apis.scoring import Scoring
from llama_stack.apis.scoring_functions import ScoringFunctions

@@ -44,6 +45,7 @@ logger = get_logger(name=__name__, category="core")


class LlamaStack(
    Providers,
    VectorDBs,
    Inference,
    BatchInference,
@@ -34,7 +34,9 @@ class SystemDefaultGenerator(PromptTemplateGeneratorBase):
        )
        return PromptTemplate(
            template_str.lstrip("\n"),
-           {"today": datetime.now().strftime("%d %B %Y")},
+           {
+               "today": datetime.now().strftime("%d %B %Y")  # noqa: DTZ005 - we don't care about timezones here since we are displaying the date
+           },
        )

    def data_examples(self) -> List[Any]:
@@ -11,7 +11,7 @@ import re
import secrets
import string
import uuid
-from datetime import datetime
+from datetime import datetime, timezone
from typing import AsyncGenerator, List, Optional, Union
from urllib.parse import urlparse

@@ -239,7 +239,7 @@ class ChatAgent(ShieldRunnerMixin):
        in_progress_tool_call_step = await self.storage.get_in_progress_tool_call_step(
            request.session_id, request.turn_id
        )
-       now = datetime.now().astimezone().isoformat()
+       now = datetime.now(timezone.utc).isoformat()
        tool_execution_step = ToolExecutionStep(
            step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())),
            turn_id=request.turn_id,

@@ -264,7 +264,7 @@ class ChatAgent(ShieldRunnerMixin):
            start_time = last_turn.started_at
        else:
            messages.extend(request.messages)
-           start_time = datetime.now().astimezone().isoformat()
+           start_time = datetime.now(timezone.utc).isoformat()
            input_messages = request.messages

        output_message = None

@@ -295,7 +295,7 @@ class ChatAgent(ShieldRunnerMixin):
            input_messages=input_messages,
            output_message=output_message,
            started_at=start_time,
-           completed_at=datetime.now().astimezone().isoformat(),
+           completed_at=datetime.now(timezone.utc).isoformat(),
            steps=steps,
        )
        await self.storage.add_turn_to_session(request.session_id, turn)

@@ -386,7 +386,7 @@ class ChatAgent(ShieldRunnerMixin):
            return

        step_id = str(uuid.uuid4())
-       shield_call_start_time = datetime.now().astimezone().isoformat()
+       shield_call_start_time = datetime.now(timezone.utc).isoformat()
        try:
            yield AgentTurnResponseStreamChunk(
                event=AgentTurnResponseEvent(

@@ -410,7 +410,7 @@ class ChatAgent(ShieldRunnerMixin):
                        turn_id=turn_id,
                        violation=e.violation,
                        started_at=shield_call_start_time,
-                       completed_at=datetime.now().astimezone().isoformat(),
+                       completed_at=datetime.now(timezone.utc).isoformat(),
                    ),
                )
            )

@@ -433,7 +433,7 @@ class ChatAgent(ShieldRunnerMixin):
                    turn_id=turn_id,
                    violation=None,
                    started_at=shield_call_start_time,
-                   completed_at=datetime.now().astimezone().isoformat(),
+                   completed_at=datetime.now(timezone.utc).isoformat(),
                ),
            )
        )

@@ -472,7 +472,7 @@ class ChatAgent(ShieldRunnerMixin):
            client_tools[tool.name] = tool
        while True:
            step_id = str(uuid.uuid4())
-           inference_start_time = datetime.now().astimezone().isoformat()
+           inference_start_time = datetime.now(timezone.utc).isoformat()
            yield AgentTurnResponseStreamChunk(
                event=AgentTurnResponseEvent(
                    payload=AgentTurnResponseStepStartPayload(

@@ -582,7 +582,7 @@ class ChatAgent(ShieldRunnerMixin):
                        turn_id=turn_id,
                        model_response=copy.deepcopy(message),
                        started_at=inference_start_time,
-                       completed_at=datetime.now().astimezone().isoformat(),
+                       completed_at=datetime.now(timezone.utc).isoformat(),
                    ),
                )
            )

@@ -653,7 +653,7 @@ class ChatAgent(ShieldRunnerMixin):
                            turn_id=turn_id,
                            tool_calls=[tool_call],
                            tool_responses=[],
-                           started_at=datetime.now().astimezone().isoformat(),
+                           started_at=datetime.now(timezone.utc).isoformat(),
                        ),
                    )
                yield message

@@ -670,7 +670,7 @@ class ChatAgent(ShieldRunnerMixin):
                    "input": message.model_dump_json(),
                },
            ) as span:
-               tool_execution_start_time = datetime.now().astimezone().isoformat()
+               tool_execution_start_time = datetime.now(timezone.utc).isoformat()
                tool_call = message.tool_calls[0]
                tool_result = await self.execute_tool_call_maybe(
                    session_id,

@@ -708,7 +708,7 @@ class ChatAgent(ShieldRunnerMixin):
                            )
                        ],
                        started_at=tool_execution_start_time,
-                       completed_at=datetime.now().astimezone().isoformat(),
+                       completed_at=datetime.now(timezone.utc).isoformat(),
                    ),
                )
            )
@@ -7,7 +7,7 @@
import json
import logging
import uuid
-from datetime import datetime
+from datetime import datetime, timezone
from typing import List, Optional

from pydantic import BaseModel

@@ -36,7 +36,7 @@ class AgentPersistence:
        session_info = AgentSessionInfo(
            session_id=session_id,
            session_name=name,
-           started_at=datetime.now(),
+           started_at=datetime.now(timezone.utc),
        )
        await self.kvstore.set(
            key=f"session:{self.agent_id}:{session_id}",
@@ -3,7 +3,7 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
-from datetime import datetime
+from datetime import datetime, timezone
from typing import Any, Dict, Optional

from llama_stack.apis.datasetio import DatasetIO

@@ -64,7 +64,7 @@ class TorchtunePostTrainingImpl:
        job_status_response = PostTrainingJobStatusResponse(
            job_uuid=job_uuid,
            status=JobStatus.scheduled,
-           scheduled_at=datetime.now(),
+           scheduled_at=datetime.now(timezone.utc),
        )
        self.jobs[job_uuid] = job_status_response

@@ -84,7 +84,7 @@ class TorchtunePostTrainingImpl:
            )

            job_status_response.status = JobStatus.in_progress
-           job_status_response.started_at = datetime.now()
+           job_status_response.started_at = datetime.now(timezone.utc)

            await recipe.setup()
            resources_allocated, checkpoints = await recipe.train()

@@ -93,7 +93,7 @@ class TorchtunePostTrainingImpl:
            job_status_response.resources_allocated = resources_allocated
            job_status_response.checkpoints = checkpoints
            job_status_response.status = JobStatus.completed
-           job_status_response.completed_at = datetime.now()
+           job_status_response.completed_at = datetime.now(timezone.utc)

        except Exception:
            job_status_response.status = JobStatus.failed
@@ -8,7 +8,7 @@ import gc
import logging
import os
import time
-from datetime import datetime
+from datetime import datetime, timezone
from functools import partial
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

@@ -532,7 +532,7 @@ class LoraFinetuningSingleDevice:
            checkpoint_path = await self.save_checkpoint(epoch=curr_epoch)
            checkpoint = Checkpoint(
                identifier=f"{self.model_id}-sft-{curr_epoch}",
-               created_at=datetime.now(),
+               created_at=datetime.now(timezone.utc),
                epoch=curr_epoch,
                post_training_job_id=self.job_uuid,
                path=checkpoint_path,
@@ -5,7 +5,7 @@
# the root directory of this source tree.

import json
-from datetime import datetime
+from datetime import datetime, timezone

from opentelemetry.sdk.trace import ReadableSpan
from opentelemetry.sdk.trace.export import SpanProcessor

@@ -34,7 +34,7 @@ class ConsoleSpanProcessor(SpanProcessor):
        if span.attributes and span.attributes.get("__autotraced__"):
            return

-       timestamp = datetime.utcfromtimestamp(span.start_time / 1e9).strftime("%H:%M:%S.%f")[:-3]
+       timestamp = datetime.fromtimestamp(span.start_time / 1e9, tz=timezone.utc).strftime("%H:%M:%S.%f")[:-3]

        print(
            f"{COLORS['dim']}{timestamp}{COLORS['reset']} "

@@ -46,7 +46,7 @@ class ConsoleSpanProcessor(SpanProcessor):
        if span.attributes and span.attributes.get("__autotraced__"):
            return

-       timestamp = datetime.utcfromtimestamp(span.end_time / 1e9).strftime("%H:%M:%S.%f")[:-3]
+       timestamp = datetime.fromtimestamp(span.end_time / 1e9, tz=timezone.utc).strftime("%H:%M:%S.%f")[:-3]

        span_context = (
            f"{COLORS['dim']}{timestamp}{COLORS['reset']} "

@@ -74,7 +74,7 @@ class ConsoleSpanProcessor(SpanProcessor):
            print(f" {COLORS['dim']}{key}: {str_value}{COLORS['reset']}")

        for event in span.events:
-           event_time = datetime.utcfromtimestamp(event.timestamp / 1e9).strftime("%H:%M:%S.%f")[:-3]
+           event_time = datetime.fromtimestamp(event.timestamp / 1e9, tz=timezone.utc).strftime("%H:%M:%S.%f")[:-3]

            severity = event.attributes.get("severity", "info")
            message = event.attributes.get("message", event.name)
@@ -8,7 +8,7 @@ import json
import os
import sqlite3
import threading
-from datetime import datetime
+from datetime import datetime, timezone

from opentelemetry.sdk.trace import SpanProcessor
from opentelemetry.trace import Span

@@ -124,8 +124,8 @@ class SQLiteSpanProcessor(SpanProcessor):
                    trace_id,
                    service_name,
                    (span_id if not parent_span_id else None),
-                   datetime.fromtimestamp(span.start_time / 1e9).isoformat(),
-                   datetime.fromtimestamp(span.end_time / 1e9).isoformat(),
+                   datetime.fromtimestamp(span.start_time / 1e9, timezone.utc).isoformat(),
+                   datetime.fromtimestamp(span.end_time / 1e9, timezone.utc).isoformat(),
                ),
            )

@@ -143,8 +143,8 @@ class SQLiteSpanProcessor(SpanProcessor):
                    trace_id,
                    parent_span_id,
                    span.name,
-                   datetime.fromtimestamp(span.start_time / 1e9).isoformat(),
-                   datetime.fromtimestamp(span.end_time / 1e9).isoformat(),
+                   datetime.fromtimestamp(span.start_time / 1e9, timezone.utc).isoformat(),
+                   datetime.fromtimestamp(span.end_time / 1e9, timezone.utc).isoformat(),
                    json.dumps(dict(span.attributes)),
                    span.status.status_code.name,
                    span.kind.name,

@@ -161,7 +161,7 @@ class SQLiteSpanProcessor(SpanProcessor):
                    (
                        span_id,
                        event.name,
-                       datetime.fromtimestamp(event.timestamp / 1e9).isoformat(),
+                       datetime.fromtimestamp(event.timestamp / 1e9, timezone.utc).isoformat(),
                        json.dumps(dict(event.attributes)),
                    ),
                )
@@ -168,7 +168,7 @@ def process_matplotlib_response(response, matplotlib_dump_dir: str):
    image_paths = []
    for i, img in enumerate(images):
        # create new directory for each day to better organize data:
-       dump_dname = datetime.today().strftime("%Y-%m-%d")
+       dump_dname = datetime.today().strftime("%Y-%m-%d")  # noqa: DTZ002 - we don't care about timezones here since we are displaying the date
        dump_dpath = Path(matplotlib_dump_dir, dump_dname)
        dump_dpath.mkdir(parents=True, exist_ok=True)
        # save image into a file
@@ -11,7 +11,7 @@ import logging
import queue
import threading
import uuid
-from datetime import datetime
+from datetime import datetime, timezone
from functools import wraps
from typing import Any, Callable, Dict, List, Optional

@@ -86,7 +86,7 @@ class TraceContext:
            span_id=generate_short_uuid(),
            trace_id=self.trace_id,
            name=name,
-           start_time=datetime.now(),
+           start_time=datetime.now(timezone.utc),
            parent_span_id=current_span.span_id if current_span else None,
            attributes=attributes,
        )

@@ -203,7 +203,7 @@ class TelemetryHandler(logging.Handler):
            UnstructuredLogEvent(
                trace_id=span.trace_id,
                span_id=span.span_id,
-               timestamp=datetime.now(),
+               timestamp=datetime.now(timezone.utc),
                message=self.format(record),
                severity=severity(record.levelname),
            )
@@ -124,14 +124,15 @@ exclude = [

[tool.ruff.lint]
select = [
-   "B",  # flake8-bugbear
-   "B9",  # flake8-bugbear subset
-   "C",  # comprehensions
-   "E",  # pycodestyle
-   "F",  # Pyflakes
-   "N",  # Naming
-   "W",  # Warnings
-   "I",  # isort
+   "B",  # flake8-bugbear
+   "B9",  # flake8-bugbear subset
+   "C",  # comprehensions
+   "E",  # pycodestyle
+   "F",  # Pyflakes
+   "N",  # Naming
+   "W",  # Warnings
+   "I",  # isort
+   "DTZ",  # datetime rules
]
ignore = [
    # The following ignores are desired by the project maintainers.

@@ -145,6 +146,10 @@ ignore = [
    "C901", # Complexity of the function is too high
]

# Ignore the following errors for the following files
[tool.ruff.lint.per-file-ignores]
"tests/**/*.py" = ["DTZ"]  # Ignore datetime rules for tests

[tool.mypy]
mypy_path = ["llama_stack"]
packages = ["llama_stack"]

@@ -170,6 +175,7 @@ exclude = [
    "^llama_stack/apis/inspect/inspect\\.py$",
    "^llama_stack/apis/models/models\\.py$",
    "^llama_stack/apis/post_training/post_training\\.py$",
    "^llama_stack/apis/providers/providers\\.py$",
    "^llama_stack/apis/resource\\.py$",
    "^llama_stack/apis/safety/safety\\.py$",
    "^llama_stack/apis/scoring/scoring\\.py$",
scripts/unit-tests.sh (new file, 19 lines, executable)

@@ -0,0 +1,19 @@
#!/bin/sh

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

PYTHON_VERSION=${PYTHON_VERSION:-3.10}

command -v uv >/dev/null 2>&1 || { echo >&2 "uv is required but it's not installed. Exiting."; exit 1; }

uv python find $PYTHON_VERSION
FOUND_PYTHON=$?
if [ $FOUND_PYTHON -ne 0 ]; then
    uv python install $PYTHON_VERSION
fi

uv run --python $PYTHON_VERSION --with-editable . --with-editable ".[dev]" --with-editable ".[unit]" pytest -s -v tests/unit/ $@
@@ -52,6 +52,8 @@ def llama_stack_client_with_mocked_inference(llama_stack_client, request):

    If --record-responses is passed, it will call the real APIs and record the responses.
    """
    # TODO: will rework this to be more stable
    return llama_stack_client
    if not isinstance(llama_stack_client, LlamaStackAsLibraryClient):
        logging.warning(
            "llama_stack_client_with_mocked_inference is not supported for this client, returning original client without mocking"
File diff suppressed because one or more lines are too long
@@ -36,7 +36,7 @@ def test_image_chat_completion_non_streaming(client_with_models, vision_model_id
                    "type": "image",
                    "image": {
                        "url": {
-                           "uri": "https://raw.githubusercontent.com/meta-llama/llama-stack/main/tests/api/inference/dog.png"
+                           "uri": "https://raw.githubusercontent.com/meta-llama/llama-stack/main/tests/integration/inference/dog.png"
                        },
                    },
                },

@@ -65,7 +65,7 @@ def test_image_chat_completion_streaming(client_with_models, vision_model_id):
                    "type": "image",
                    "image": {
                        "url": {
-                           "uri": "https://raw.githubusercontent.com/meta-llama/llama-stack/main/tests/api/inference/dog.png"
+                           "uri": "https://raw.githubusercontent.com/meta-llama/llama-stack/main/tests/integration/inference/dog.png"
                        },
                    },
                },
tests/integration/providers/__init__.py (new file, 5 lines)

@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
tests/integration/providers/test_providers.py (new file, 17 lines)

@@ -0,0 +1,17 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import pytest
from llama_stack_client import LlamaStackClient

from llama_stack import LlamaStackAsLibraryClient


class TestProviders:
    @pytest.mark.asyncio
    def test_list(self, llama_stack_client: LlamaStackAsLibraryClient | LlamaStackClient):
        provider_list = llama_stack_client.providers.list()
        assert provider_list is not None