Merge branch 'refs/heads/main' into preprocessors

ilya-kolchinsky 2025-03-11 20:05:52 +01:00
commit d38aea33c1
37 changed files with 493 additions and 255 deletions

@ -14,16 +14,16 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.10.16'
python-version: '3.10'
- uses: astral-sh/setup-uv@v5
with:
python-version: '3.10.16'
python-version: '3.10'
enable-cache: false
- name: Run unit tests
run: |
uv run -p 3.10.16 --with-editable . --with-editable ".[dev]" --with-editable ".[unit]" pytest --cov=llama_stack -s -v tests/unit/ --junitxml=pytest-report.xml
uv run -p 3.10 --with-editable . --with-editable ".[dev]" --with-editable ".[unit]" pytest --cov=llama_stack -s -v tests/unit/ --junitxml=pytest-report.xml
- name: Upload test results
if: always()

@ -159,8 +159,7 @@ uv run sphinx-autobuild source build/html --write-all
If you modify or add new API endpoints, update the API documentation accordingly. You can do this by running the following command:
```bash
uv sync --extra dev
uv run ./docs/openapi_generator/run_openapi_generator.sh
uv run --with ".[dev]" ./docs/openapi_generator/run_openapi_generator.sh
```
The generated API documentation will be available in `docs/_static/`. Make sure to review the changes before committing.

@ -363,6 +363,37 @@
}
},
"/v1/agents": {
"get": {
"responses": {
"200": {
"description": "A ListAgentsResponse.",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ListAgentsResponse"
}
}
}
},
"400": {
"$ref": "#/components/responses/BadRequest400"
},
"429": {
"$ref": "#/components/responses/TooManyRequests429"
},
"500": {
"$ref": "#/components/responses/InternalServerError500"
},
"default": {
"$ref": "#/components/responses/DefaultError"
}
},
"tags": [
"Agents"
],
"description": "List all agents.",
"parameters": []
},
"post": {
"responses": {
"200": {
@ -609,6 +640,47 @@
}
},
"/v1/agents/{agent_id}": {
"get": {
"responses": {
"200": {
"description": "An Agent of the agent.",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/Agent"
}
}
}
},
"400": {
"$ref": "#/components/responses/BadRequest400"
},
"429": {
"$ref": "#/components/responses/TooManyRequests429"
},
"500": {
"$ref": "#/components/responses/InternalServerError500"
},
"default": {
"$ref": "#/components/responses/DefaultError"
}
},
"tags": [
"Agents"
],
"description": "Describe an agent by its ID.",
"parameters": [
{
"name": "agent_id",
"in": "path",
"description": "ID of the agent.",
"required": true,
"schema": {
"type": "string"
}
}
]
},
"delete": {
"responses": {
"200": {
@ -2358,6 +2430,49 @@
]
}
},
"/v1/agents/{agent_id}/sessions": {
"get": {
"responses": {
"200": {
"description": "A ListAgentSessionsResponse.",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ListAgentSessionsResponse"
}
}
}
},
"400": {
"$ref": "#/components/responses/BadRequest400"
},
"429": {
"$ref": "#/components/responses/TooManyRequests429"
},
"500": {
"$ref": "#/components/responses/InternalServerError500"
},
"default": {
"$ref": "#/components/responses/DefaultError"
}
},
"tags": [
"Agents"
],
"description": "List all session(s) of a given agent.",
"parameters": [
{
"name": "agent_id",
"in": "path",
"description": "The ID of the agent to list sessions for.",
"required": true,
"schema": {
"type": "string"
}
}
]
}
},
"/v1/eval/benchmarks": {
"get": {
"responses": {
@ -6776,6 +6891,28 @@
"title": "ScoringResult",
"description": "A scoring result for a single row."
},
"Agent": {
"type": "object",
"properties": {
"agent_id": {
"type": "string"
},
"agent_config": {
"$ref": "#/components/schemas/AgentConfig"
},
"created_at": {
"type": "string",
"format": "date-time"
}
},
"additionalProperties": false,
"required": [
"agent_id",
"agent_config",
"created_at"
],
"title": "Agent"
},
"Session": {
"type": "object",
"properties": {
@ -8214,6 +8351,38 @@
],
"title": "ToolInvocationResult"
},
"ListAgentSessionsResponse": {
"type": "object",
"properties": {
"data": {
"type": "array",
"items": {
"$ref": "#/components/schemas/Session"
}
}
},
"additionalProperties": false,
"required": [
"data"
],
"title": "ListAgentSessionsResponse"
},
"ListAgentsResponse": {
"type": "object",
"properties": {
"data": {
"type": "array",
"items": {
"$ref": "#/components/schemas/Agent"
}
}
},
"additionalProperties": false,
"required": [
"data"
],
"title": "ListAgentsResponse"
},
"BucketResponse": {
"type": "object",
"properties": {

@ -238,6 +238,28 @@ paths:
$ref: '#/components/schemas/CompletionRequest'
required: true
/v1/agents:
get:
responses:
'200':
description: A ListAgentsResponse.
content:
application/json:
schema:
$ref: '#/components/schemas/ListAgentsResponse'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- Agents
description: List all agents.
parameters: []
post:
responses:
'200':
@ -410,6 +432,34 @@ paths:
$ref: '#/components/schemas/CreateUploadSessionRequest'
required: true
/v1/agents/{agent_id}:
get:
responses:
'200':
description: The requested Agent.
content:
application/json:
schema:
$ref: '#/components/schemas/Agent'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- Agents
description: Describe an agent by its ID.
parameters:
- name: agent_id
in: path
description: ID of the agent.
required: true
schema:
type: string
delete:
responses:
'200':
@ -1581,6 +1631,36 @@ paths:
required: true
schema:
type: string
/v1/agents/{agent_id}/sessions:
get:
responses:
'200':
description: A ListAgentSessionsResponse.
content:
application/json:
schema:
$ref: '#/components/schemas/ListAgentSessionsResponse'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- Agents
description: List all sessions of a given agent.
parameters:
- name: agent_id
in: path
description: >-
The ID of the agent to list sessions for.
required: true
schema:
type: string
/v1/eval/benchmarks:
get:
responses:
@ -4690,6 +4770,22 @@ components:
- aggregated_results
title: ScoringResult
description: A scoring result for a single row.
Agent:
type: object
properties:
agent_id:
type: string
agent_config:
$ref: '#/components/schemas/AgentConfig'
created_at:
type: string
format: date-time
additionalProperties: false
required:
- agent_id
- agent_config
- created_at
title: Agent
Session:
type: object
properties:
@ -5579,6 +5675,28 @@ components:
required:
- content
title: ToolInvocationResult
ListAgentSessionsResponse:
type: object
properties:
data:
type: array
items:
$ref: '#/components/schemas/Session'
additionalProperties: false
required:
- data
title: ListAgentSessionsResponse
ListAgentsResponse:
type: object
properties:
data:
type: array
items:
$ref: '#/components/schemas/Agent'
additionalProperties: false
required:
- data
title: ListAgentsResponse
BucketResponse:
type: object
properties:

@ -1,9 +1 @@
The RFC Specification (OpenAPI format) is generated from the set of API endpoints located in `llama_stack/distribution/server/endpoints.py` using the `generate.py` utility.
Please install the following packages before running the script:
```
pip install fire PyYAML
```
Then simply run `sh run_openapi_generator.sh`

@ -33,6 +33,8 @@ Can be set to any of the following log levels:
The default global log level is `info`. `all` sets the log level for all components.
A user can also set `LLAMA_STACK_LOG_FILE` which will pipe the logs to the specified path as well as to the terminal. An example would be: `export LLAMA_STACK_LOG_FILE=server.log`
### Llama Stack Build
In order to build your own distribution, we recommend you clone the `llama-stack` repository.

@ -41,7 +41,6 @@ The following models are available by default:
- `accounts/fireworks/models/llama-v3p1-8b-instruct (aliases: meta-llama/Llama-3.1-8B-Instruct)`
- `accounts/fireworks/models/llama-v3p1-70b-instruct (aliases: meta-llama/Llama-3.1-70B-Instruct)`
- `accounts/fireworks/models/llama-v3p1-405b-instruct (aliases: meta-llama/Llama-3.1-405B-Instruct-FP8)`
- `accounts/fireworks/models/llama-v3p2-1b-instruct (aliases: meta-llama/Llama-3.2-1B-Instruct)`
- `accounts/fireworks/models/llama-v3p2-3b-instruct (aliases: meta-llama/Llama-3.2-3B-Instruct)`
- `accounts/fireworks/models/llama-v3p2-11b-vision-instruct (aliases: meta-llama/Llama-3.2-11B-Vision-Instruct)`
- `accounts/fireworks/models/llama-v3p2-90b-vision-instruct (aliases: meta-llama/Llama-3.2-90B-Vision-Instruct)`

@ -1,6 +1,6 @@
# llama (server-side) CLI Reference
The `llama` CLI tool helps you setup and use the Llama Stack. It should be available on your path after installing the `llama-stack` package.
The `llama` CLI tool helps you set up and use the Llama Stack. The CLI is available on your path after installing the `llama-stack` package.
## Installation
@ -27,9 +27,9 @@ You have two ways to install Llama Stack:
## `llama` subcommands
1. `download`: `llama` cli tools supports downloading the model from Meta or Hugging Face.
2. `model`: Lists available models and their properties.
3. `stack`: Allows you to build and run a Llama Stack server. You can read more about this [here](../../distributions/building_distro).
1. `download`: Supports downloading models from Meta or Hugging Face. [Downloading models](#downloading-models)
2. `model`: Lists available models and their properties. [Understanding models](#understand-the-models)
3. `stack`: Allows you to build a stack using the `llama stack` distribution and run a Llama Stack server. You can read more about how to build a Llama Stack distribution in the [Build your own Distribution](../../distributions/building_distro) documentation.
### Sample Usage
@ -117,7 +117,7 @@ You should see a table like this:
+----------------------------------+------------------------------------------+----------------+
```
To download models, you can use the llama download command.
To download models, you can use the `llama download` command.
### Downloading from [Meta](https://llama.meta.com/llama-downloads/)
@ -191,7 +191,7 @@ You should see a table like this:
The `llama model` command helps you explore the models interface.
1. `download`: Download the model from different sources. (meta, huggingface)
2. `list`: Lists all the models available for download with hardware requirements to deploy the models.
2. `list`: Lists all the models available for download with hardware requirements for deploying the models.
3. `prompt-format`: Show llama model message formats.
4. `describe`: Describes all the properties of the model.
@ -262,13 +262,12 @@ llama model prompt-format -m Llama3.2-3B-Instruct
![alt text](../../../resources/prompt-format.png)
You will be shown a Markdown formatted description of the model interface and how prompts / messages are formatted for various scenarios.
**NOTE**: Outputs in terminal are color printed to show special tokens.
### Remove model
You can run `llama model remove` to remove unecessary model:
You can run `llama model remove` to remove an unnecessary model:
```
llama model remove -m Llama-Guard-3-8B-int8

@ -234,6 +234,23 @@ class AgentConfig(AgentConfigCommon):
response_format: Optional[ResponseFormat] = None
@json_schema_type
class Agent(BaseModel):
agent_id: str
agent_config: AgentConfig
created_at: datetime
@json_schema_type
class ListAgentsResponse(BaseModel):
data: List[Agent]
@json_schema_type
class ListAgentSessionsResponse(BaseModel):
data: List[Session]
class AgentConfigOverridablePerTurn(AgentConfigCommon):
instructions: Optional[str] = None
@ -541,3 +558,32 @@ class Agents(Protocol):
:param agent_id: The ID of the agent to delete.
"""
...
@webmethod(route="/agents", method="GET")
async def list_agents(self) -> ListAgentsResponse:
"""List all agents.
:returns: A ListAgentsResponse.
"""
...
@webmethod(route="/agents/{agent_id}", method="GET")
async def get_agent(self, agent_id: str) -> Agent:
"""Describe an agent by its ID.
:param agent_id: ID of the agent.
:returns: The requested Agent.
"""
...
@webmethod(route="/agents/{agent_id}/sessions", method="GET")
async def list_agent_sessions(
self,
agent_id: str,
) -> ListAgentSessionsResponse:
"""List all session(s) of a given agent.
:param agent_id: The ID of the agent to list sessions for.
:returns: A ListAgentSessionsResponse.
"""
...
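
Taken together, the new models give the list endpoints a typed envelope (`{"data": [...]}`) instead of a bare JSON array. A minimal sketch of how they compose; the `AgentConfig` arguments are assumptions for illustration, since the schema above only pins `agent_id`, `agent_config`, and `created_at` on `Agent`:

```python
# Construct and serialize the new response models (pydantic v2 style).
from datetime import datetime, timezone

from llama_stack.apis.agents import Agent, AgentConfig, ListAgentsResponse

agent = Agent(
    agent_id="agent-1234",
    # model/instructions are assumed AgentConfig fields, for illustration only
    agent_config=AgentConfig(
        model="Llama3.2-3B-Instruct",
        instructions="You are a helpful assistant.",
    ),
    created_at=datetime.now(timezone.utc),
)

resp = ListAgentsResponse(data=[agent])
print(resp.model_dump_json(indent=2))  # -> {"data": [{"agent_id": ...}]}
```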

@ -56,6 +56,7 @@ class StackRun(Subcommand):
"--env",
action="append",
help="Environment variables to pass to the server in KEY=VALUE format. Can be specified multiple times.",
default=[],
metavar="KEY=VALUE",
)
self.parser.add_argument(
@ -73,6 +74,7 @@ class StackRun(Subcommand):
type=str,
help="Image Type used during the build. This can be either conda or container or venv.",
choices=["conda", "container", "venv"],
default="conda",
)
def _run_stack_run_cmd(self, args: argparse.Namespace) -> None:
@ -118,40 +120,18 @@ class StackRun(Subcommand):
except AttributeError as e:
self.parser.error(f"failed to parse config file '{config_file}':\n {e}")
# If neither image type nor image name is provided, assume the server should be run directly
# using the current environment packages.
if not args.image_type and not args.image_name:
logger.info("No image type or image name provided. Assuming environment packages.")
from llama_stack.distribution.server.server import main as server_main
# Build the server args from the current args passed to the CLI
server_args = argparse.Namespace()
for arg in vars(args):
# If this is a function, avoid passing it
# "args" contains:
# func=<bound method StackRun._run_stack_run_cmd of <llama_stack.cli.stack.run.StackRun object at 0x10484b010>>
if callable(getattr(args, arg)):
continue
setattr(server_args, arg, getattr(args, arg))
# Run the server
server_main(server_args)
else:
run_args = formulate_run_args(args.image_type, args.image_name, config, template_name)
run_args.extend([str(config_file), str(args.port)])
if args.disable_ipv6:
run_args.append("--disable-ipv6")
if args.env:
for env_var in args.env:
if "=" not in env_var:
self.parser.error(f"Environment variable '{env_var}' must be in KEY=VALUE format")
return
key, value = env_var.split("=", 1) # split on first = only
if not key:
self.parser.error(f"Environment variable '{env_var}' has empty key")
return
run_args.extend(["--env", f"{key}={value}"])
if args.tls_keyfile and args.tls_certfile:
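
The two new defaults above let the removed `None` checks stay gone: `action="append"` with `default=[]` means `--env` always yields an iterable, and `--image-type` always has a value. A standalone sketch of the pattern (argument names mirror the CLI; the rest is illustrative):

```python
import argparse

parser = argparse.ArgumentParser()
# append-action + default=[] -> args.env is always a list, never None
parser.add_argument("--env", action="append", default=[], metavar="KEY=VALUE")
parser.add_argument("--image-type", choices=["conda", "container", "venv"], default="conda")

args = parser.parse_args(["--env", "FOO=1", "--env", "BAR=2"])
for env_var in args.env:                # safe even if --env was never passed
    key, value = env_var.split("=", 1)  # split on the first '=' only, as above
    print(key, value)
print(args.image_type)                  # -> "conda" when not given
```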

@ -6,18 +6,16 @@
import argparse
import asyncio
import functools
import inspect
import json
import os
import signal
import sys
import traceback
import warnings
from contextlib import asynccontextmanager
from importlib.metadata import version as parse_version
from pathlib import Path
from typing import Any, List, Optional, Union
from typing import Any, List, Union
import yaml
from fastapi import Body, FastAPI, HTTPException, Request
@ -118,36 +116,12 @@ def translate_exception(exc: Exception) -> Union[HTTPException, RequestValidatio
)
def handle_signal(app, signum, _) -> None:
async def shutdown(app):
"""Initiate a graceful shutdown of the application.
Handled by the lifespan context manager. The shutdown process involves
shutting down all implementations registered in the application.
"""
Handle incoming signals and initiate a graceful shutdown of the application.
This function is intended to be used as a signal handler for various signals
(e.g., SIGINT, SIGTERM). Upon receiving a signal, it will print a message
indicating the received signal and initiate a shutdown process.
Args:
app: The application instance containing implementations to be shut down.
signum (int): The signal number received.
frame: The current stack frame (not used in this function).
The shutdown process involves:
- Shutting down all implementations registered in the application.
- Gathering all running asyncio tasks.
- Cancelling all gathered tasks.
- Waiting for all tasks to finish.
- Stopping the event loop.
Note:
This function schedules the shutdown process as an asyncio task and does
not block the current execution.
"""
signame = signal.Signals(signum).name
logger.info(f"Received signal {signame} ({signum}). Exiting gracefully...")
async def shutdown():
try:
# Gracefully shut down implementations
for impl in app.__llama_stack_impls__.values():
impl_name = impl.__class__.__name__
logger.info("Shutting down %s", impl_name)
@ -158,38 +132,16 @@ def handle_signal(app, signum, _) -> None:
logger.warning("No shutdown method for %s", impl_name)
except asyncio.TimeoutError:
logger.exception("Shutdown timeout for %s ", impl_name, exc_info=True)
except Exception as e:
except (Exception, asyncio.CancelledError) as e:
logger.exception("Failed to shutdown %s: %s", impl_name, {e})
# Gather all running tasks
loop = asyncio.get_running_loop()
tasks = [task for task in asyncio.all_tasks(loop) if task is not asyncio.current_task()]
# Cancel all tasks
for task in tasks:
task.cancel()
# Wait for all tasks to finish
try:
await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), timeout=10)
except asyncio.TimeoutError:
logger.exception("Timeout while waiting for tasks to finish")
except asyncio.CancelledError:
pass
finally:
loop.stop()
loop = asyncio.get_running_loop()
loop.create_task(shutdown())
@asynccontextmanager
async def lifespan(app: FastAPI):
logger.info("Starting up")
yield
logger.info("Shutting down")
for impl in app.__llama_stack_impls__.values():
await impl.shutdown()
await shutdown(app)
def is_streaming_request(func_name: str, request: Request, **kwargs):
@ -266,7 +218,7 @@ class TracingMiddleware:
self.app = app
async def __call__(self, scope, receive, send):
path = scope["path"]
path = scope.get("path", "")
await start_trace(path, {"__location__": "server"})
try:
return await self.app(scope, receive, send)
@ -314,17 +266,11 @@ class ClientVersionMiddleware:
return await self.app(scope, receive, send)
def main(args: Optional[argparse.Namespace] = None):
def main():
"""Start the LlamaStack server."""
parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
parser.add_argument(
"--yaml-config",
dest="config",
help="(Deprecated) Path to YAML configuration file - use --config instead",
)
parser.add_argument(
"--config",
dest="config",
help="Path to YAML configuration file",
)
parser.add_argument(
@ -354,20 +300,8 @@ def main(args: Optional[argparse.Namespace] = None):
required="--tls-keyfile" in sys.argv,
)
# Determine whether the server args are being passed by the "run" command, if this is the case
# the args will be passed as a Namespace object to the main function, otherwise they will be
# parsed from the command line
if args is None:
args = parser.parse_args()
# Check for deprecated argument usage
if "--yaml-config" in sys.argv:
warnings.warn(
"The '--yaml-config' argument is deprecated and will be removed in a future version. Use '--config' instead.",
DeprecationWarning,
stacklevel=2,
)
if args.env:
for env_pair in args.env:
try:
@ -378,9 +312,9 @@ def main(args: Optional[argparse.Namespace] = None):
logger.error(f"Error: {str(e)}")
sys.exit(1)
if args.config:
if args.yaml_config:
# if the user provided a config file, use it, even if template was specified
config_file = Path(args.config)
config_file = Path(args.yaml_config)
if not config_file.exists():
raise ValueError(f"Config file {config_file} does not exist")
logger.info(f"Using config file: {config_file}")
@ -457,8 +391,6 @@ def main(args: Optional[argparse.Namespace] = None):
app.exception_handler(RequestValidationError)(global_exception_handler)
app.exception_handler(Exception)(global_exception_handler)
signal.signal(signal.SIGINT, functools.partial(handle_signal, app))
signal.signal(signal.SIGTERM, functools.partial(handle_signal, app))
app.__llama_stack_impls__ = impls
@ -489,6 +421,7 @@ def main(args: Optional[argparse.Namespace] = None):
"app": app,
"host": listen_host,
"port": port,
"lifespan": "on",
}
if ssl_config:
uvicorn_config.update(ssl_config)
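
With the signal handlers gone, shutdown is driven entirely by uvicorn's lifespan protocol: `lifespan: "on"` makes uvicorn run the `lifespan` context manager, and the code after `yield` executes once uvicorn has caught SIGINT/SIGTERM and drained in-flight requests. A minimal standalone sketch of the pattern (the `Impl` class is illustrative):

```python
from contextlib import asynccontextmanager

import uvicorn
from fastapi import FastAPI


class Impl:
    """Stand-in for a registered implementation with an async shutdown()."""

    async def shutdown(self) -> None:
        print("impl shut down")


@asynccontextmanager
async def lifespan(app: FastAPI):
    print("Starting up")
    yield  # requests are served while the context manager is suspended here
    print("Shutting down")  # runs after SIGINT/SIGTERM; no signal handler needed
    for impl in app.state.impls.values():
        await impl.shutdown()


app = FastAPI(lifespan=lifespan)
app.state.impls = {"demo": Impl()}

if __name__ == "__main__":
    uvicorn.run(app, host="127.0.0.1", port=8000, lifespan="on")
```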

@ -97,12 +97,13 @@ class CustomRichHandler(RichHandler):
self.markup = original_markup
def setup_logging(category_levels: Dict[str, int]) -> None:
def setup_logging(category_levels: Dict[str, int], log_file: str | None) -> None:
"""
Configure logging based on the provided category log levels.
Configure logging based on the provided category log levels and an optional log file.
Parameters:
category_levels (Dict[str, int]): A dictionary mapping categories to their log levels.
log_file (str | None): Optional path to a log file to additionally pipe the logs into.
"""
log_format = "[dim]%(asctime)s %(name)s:%(lineno)d[/] [yellow dim]%(category)s[/]: %(message)s"
@ -117,6 +118,28 @@ def setup_logging(category_levels: Dict[str, int]) -> None:
# Determine the root logger's level (default to WARNING if not specified)
root_level = category_levels.get("root", logging.WARNING)
handlers = {
"console": {
"()": CustomRichHandler, # Use custom console handler
"formatter": "rich",
"rich_tracebacks": True,
"show_time": False,
"show_path": False,
"markup": True,
"filters": ["category_filter"],
}
}
# Add a file handler if log_file is set
if log_file:
handlers["file"] = {
"class": "logging.FileHandler",
"formatter": "rich",
"filename": log_file,
"mode": "a",
"encoding": "utf-8",
}
logging_config = {
"version": 1,
"disable_existing_loggers": False,
@ -126,17 +149,7 @@ def setup_logging(category_levels: Dict[str, int]) -> None:
"format": log_format,
}
},
"handlers": {
"console": {
"()": CustomRichHandler, # Use our custom handler class
"formatter": "rich",
"rich_tracebacks": True,
"show_time": False,
"show_path": False,
"markup": True,
"filters": ["category_filter"],
}
},
"handlers": handlers,
"filters": {
"category_filter": {
"()": CategoryFilter,
@ -144,14 +157,14 @@ def setup_logging(category_levels: Dict[str, int]) -> None:
},
"loggers": {
category: {
"handlers": ["console"],
"handlers": list(handlers.keys()), # Apply all handlers
"level": category_levels.get(category, DEFAULT_LOG_LEVEL),
"propagate": False, # Disable propagation to root logger
}
for category in CATEGORIES
},
"root": {
"handlers": ["console"],
"handlers": list(handlers.keys()),
"level": root_level, # Set root logger's level dynamically
},
}
@ -180,4 +193,6 @@ if env_config:
cprint(f"Environment variable LLAMA_STACK_LOGGING found: {env_config}", "yellow")
_category_levels.update(parse_environment_config(env_config))
setup_logging(_category_levels)
log_file = os.environ.get("LLAMA_STACK_LOG_FILE")
setup_logging(_category_levels, log_file)
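
Because the module reads both variables at import time (last hunk above), enabling the new file handler only requires setting the environment before `llama_stack` is imported. A sketch; the `get_logger` helper and the `server` category are assumptions about this module's public surface:

```python
import os

# Both are read at import time by the module shown above.
os.environ["LLAMA_STACK_LOGGING"] = "all=debug"    # category log levels
os.environ["LLAMA_STACK_LOG_FILE"] = "server.log"  # mode="a": appended, not truncated

# Import only after the env vars are set so setup_logging() picks them up.
from llama_stack.log import get_logger  # assumed public helper of this module

logger = get_logger(name=__name__, category="server")  # "server": assumed category
logger.info("written to the console and appended to server.log")
```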

@ -4,14 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from typing import Any, Dict
from llama_stack.distribution.datatypes import Api, ProviderSpec
from llama_stack.distribution.datatypes import Api
from .config import MetaReferenceAgentsImplConfig
async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: Dict[Api, ProviderSpec]):
async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: Dict[Api, Any]):
from .agents import MetaReferenceAgentsImpl
impl = MetaReferenceAgentsImpl(
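
This change, and the matching `get_provider_impl` signature changes in the files below, encode what the resolver actually passes at runtime: `deps` maps each `Api` to a live implementation object, not to a `ProviderSpec`, hence `Dict[Api, Any]`. A hedged sketch of the call shape (the stub class and the hand-built dictionary are illustrative, not how the stack wires providers):

```python
from typing import Any, Dict

from llama_stack.distribution.datatypes import Api


class FakeInference:
    """Illustrative stand-in for a resolved inference implementation."""

    async def completion(self, *args: Any, **kwargs: Any) -> Any: ...


# Values are ready-to-use impls keyed by Api -- e.g. the Milvus provider
# below calls MilvusVectorIOAdapter(config, deps[Api.inference]).
deps: Dict[Api, Any] = {Api.inference: FakeInference()}
inference = deps[Api.inference]
```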

@ -12,6 +12,7 @@ import uuid
from typing import AsyncGenerator, List, Optional, Union
from llama_stack.apis.agents import (
Agent,
AgentConfig,
AgentCreateResponse,
Agents,
@ -21,6 +22,8 @@ from llama_stack.apis.agents import (
AgentTurnCreateRequest,
AgentTurnResumeRequest,
Document,
ListAgentSessionsResponse,
ListAgentsResponse,
Session,
Turn,
)
@ -84,7 +87,7 @@ class MetaReferenceAgentsImpl(Agents):
agent_id=agent_id,
)
async def get_agent(self, agent_id: str) -> ChatAgent:
async def _get_agent_impl(self, agent_id: str) -> ChatAgent:
agent_config = await self.persistence_store.get(
key=f"agent:{agent_id}",
)
@ -120,7 +123,7 @@ class MetaReferenceAgentsImpl(Agents):
agent_id: str,
session_name: str,
) -> AgentSessionCreateResponse:
agent = await self.get_agent(agent_id)
agent = await self._get_agent_impl(agent_id)
session_id = await agent.create_session(session_name)
return AgentSessionCreateResponse(
@ -160,7 +163,7 @@ class MetaReferenceAgentsImpl(Agents):
self,
request: AgentTurnCreateRequest,
) -> AsyncGenerator:
agent = await self.get_agent(request.agent_id)
agent = await self._get_agent_impl(request.agent_id)
async for event in agent.create_and_execute_turn(request):
yield event
@ -188,12 +191,12 @@ class MetaReferenceAgentsImpl(Agents):
self,
request: AgentTurnResumeRequest,
) -> AsyncGenerator:
agent = await self.get_agent(request.agent_id)
agent = await self._get_agent_impl(request.agent_id)
async for event in agent.resume_turn(request):
yield event
async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn:
agent = await self.get_agent(agent_id)
agent = await self._get_agent_impl(agent_id)
turn = await agent.storage.get_session_turn(session_id, turn_id)
return turn
@ -210,7 +213,7 @@ class MetaReferenceAgentsImpl(Agents):
session_id: str,
turn_ids: Optional[List[str]] = None,
) -> Session:
agent = await self.get_agent(agent_id)
agent = await self._get_agent_impl(agent_id)
session_info = await agent.storage.get_session_info(session_id)
if session_info is None:
raise ValueError(f"Session {session_id} not found")
@ -232,3 +235,15 @@ class MetaReferenceAgentsImpl(Agents):
async def shutdown(self) -> None:
pass
async def list_agents(self) -> ListAgentsResponse:
pass
async def get_agent(self, agent_id: str) -> Agent:
pass
async def list_agent_sessions(
self,
agent_id: str,
) -> ListAgentSessionsResponse:
pass
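
The three new protocol methods land here as `pass` stubs. For orientation, a hedged sketch of what `get_agent` could look like on top of the same persistence store `_get_agent_impl` reads from; the stored-JSON layout and the `created_at` handling are assumptions, not part of this commit:

```python
import json
from datetime import datetime, timezone


async def get_agent(self, agent_id: str) -> Agent:
    # Same key scheme _get_agent_impl uses above; the value layout is assumed.
    raw = await self.persistence_store.get(key=f"agent:{agent_id}")
    if raw is None:
        raise ValueError(f"Agent {agent_id} not found")
    return Agent(
        agent_id=agent_id,
        agent_config=AgentConfig(**json.loads(raw)),
        # created_at is not visibly persisted in this diff; a real
        # implementation would record it when the agent is created.
        created_at=datetime.now(timezone.utc),
    )
```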

@ -4,12 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from .config import LocalFSDatasetIOConfig
async def get_provider_impl(
config: LocalFSDatasetIOConfig,
_deps,
_deps: Dict[str, Any],
):
from .datasetio import LocalFSDatasetIOImpl

@ -3,16 +3,16 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from typing import Any, Dict
from llama_stack.distribution.datatypes import Api, ProviderSpec
from llama_stack.distribution.datatypes import Api
from .config import MetaReferenceEvalConfig
async def get_provider_impl(
config: MetaReferenceEvalConfig,
deps: Dict[Api, ProviderSpec],
deps: Dict[Api, Any],
):
from .eval import MetaReferenceEvalImpl

@ -4,14 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Union
from typing import Any, Dict, Union
from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
async def get_provider_impl(
config: Union[MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig],
_deps,
_deps: Dict[str, Any],
):
from .inference import MetaReferenceInferenceImpl

@ -4,6 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from llama_stack.providers.inline.inference.sentence_transformers.config import (
SentenceTransformersInferenceConfig,
)
@ -11,7 +13,7 @@ from llama_stack.providers.inline.inference.sentence_transformers.config import
async def get_provider_impl(
config: SentenceTransformersInferenceConfig,
_deps,
_deps: Dict[str, Any],
):
from .sentence_transformers import SentenceTransformersInferenceImpl

@ -4,12 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from typing import Any, Dict
from .config import VLLMConfig
async def get_provider_impl(config: VLLMConfig, _deps) -> Any:
async def get_provider_impl(config: VLLMConfig, _deps: Dict[str, Any]):
from .vllm import VLLMInferenceImpl
impl = VLLMInferenceImpl(config)

@ -4,9 +4,9 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from typing import Any, Dict
from llama_stack.distribution.datatypes import Api, ProviderSpec
from llama_stack.distribution.datatypes import Api
from .config import TorchtunePostTrainingConfig
@ -15,7 +15,7 @@ from .config import TorchtunePostTrainingConfig
async def get_provider_impl(
config: TorchtunePostTrainingConfig,
deps: Dict[Api, ProviderSpec],
deps: Dict[Api, Any],
):
from .post_training import TorchtunePostTrainingImpl

@ -43,6 +43,9 @@ class TorchtunePostTrainingImpl:
self.jobs = {}
self.checkpoints_dict = {}
async def shutdown(self):
pass
async def supervised_fine_tune(
self,
job_uuid: str,

@ -4,10 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from .config import CodeScannerConfig
async def get_provider_impl(config: CodeScannerConfig, deps):
async def get_provider_impl(config: CodeScannerConfig, deps: Dict[str, Any]):
from .code_scanner import MetaReferenceCodeScannerSafetyImpl
impl = MetaReferenceCodeScannerSafetyImpl(config, deps)

@ -4,10 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from .config import LlamaGuardConfig
async def get_provider_impl(config: LlamaGuardConfig, deps):
async def get_provider_impl(config: LlamaGuardConfig, deps: Dict[str, Any]):
from .llama_guard import LlamaGuardSafetyImpl
assert isinstance(config, LlamaGuardConfig), f"Unexpected config type: {type(config)}"

@ -4,10 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from .config import PromptGuardConfig # noqa: F401
async def get_provider_impl(config: PromptGuardConfig, deps):
async def get_provider_impl(config: PromptGuardConfig, deps: Dict[str, Any]):
from .prompt_guard import PromptGuardSafetyImpl
impl = PromptGuardSafetyImpl(config, deps)

@ -3,16 +3,16 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from typing import Any, Dict
from llama_stack.distribution.datatypes import Api, ProviderSpec
from llama_stack.distribution.datatypes import Api
from .config import BasicScoringConfig
async def get_provider_impl(
config: BasicScoringConfig,
deps: Dict[Api, ProviderSpec],
deps: Dict[Api, Any],
):
from .scoring import BasicScoringImpl

@ -3,11 +3,11 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from typing import Any, Dict
from pydantic import BaseModel
from llama_stack.distribution.datatypes import Api, ProviderSpec
from llama_stack.distribution.datatypes import Api
from .config import BraintrustScoringConfig
@ -18,7 +18,7 @@ class BraintrustProviderDataValidator(BaseModel):
async def get_provider_impl(
config: BraintrustScoringConfig,
deps: Dict[Api, ProviderSpec],
deps: Dict[Api, Any],
):
from .braintrust import BraintrustScoringImpl

@ -3,16 +3,16 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from typing import Any, Dict
from llama_stack.distribution.datatypes import Api, ProviderSpec
from llama_stack.distribution.datatypes import Api
from .config import LlmAsJudgeScoringConfig
async def get_provider_impl(
config: LlmAsJudgeScoringConfig,
deps: Dict[Api, ProviderSpec],
deps: Dict[Api, Any],
):
from .scoring import LlmAsJudgeScoringImpl

@ -4,12 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from .config import CodeInterpreterToolConfig
__all__ = ["CodeInterpreterToolConfig", "CodeInterpreterToolRuntimeImpl"]
async def get_provider_impl(config: CodeInterpreterToolConfig, _deps):
async def get_provider_impl(config: CodeInterpreterToolConfig, _deps: Dict[str, Any]):
from .code_interpreter import CodeInterpreterToolRuntimeImpl
impl = CodeInterpreterToolRuntimeImpl(config)

@ -4,14 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from typing import Any, Dict
from llama_stack.providers.datatypes import Api, ProviderSpec
from llama_stack.providers.datatypes import Api
from .config import ChromaVectorIOConfig
async def get_provider_impl(config: ChromaVectorIOConfig, deps: Dict[Api, ProviderSpec]):
async def get_provider_impl(config: ChromaVectorIOConfig, deps: Dict[Api, Any]):
from llama_stack.providers.remote.vector_io.chroma.chroma import (
ChromaVectorIOAdapter,
)

@ -4,14 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from typing import Any, Dict
from llama_stack.providers.datatypes import Api, ProviderSpec
from llama_stack.providers.datatypes import Api
from .config import FaissVectorIOConfig
async def get_provider_impl(config: FaissVectorIOConfig, deps: Dict[Api, ProviderSpec]):
async def get_provider_impl(config: FaissVectorIOConfig, deps: Dict[Api, Any]):
from .faiss import FaissVectorIOAdapter
assert isinstance(config, FaissVectorIOConfig), f"Unexpected config type: {type(config)}"

@ -4,14 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from typing import Any, Dict
from llama_stack.providers.datatypes import Api, ProviderSpec
from llama_stack.providers.datatypes import Api
from .config import MilvusVectorIOConfig
async def get_provider_impl(config: MilvusVectorIOConfig, deps: Dict[Api, ProviderSpec]):
async def get_provider_impl(config: MilvusVectorIOConfig, deps: Dict[Api, Any]):
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusVectorIOAdapter
impl = MilvusVectorIOAdapter(config, deps[Api.inference])

@ -4,14 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from typing import Any, Dict
from llama_stack.providers.datatypes import Api, ProviderSpec
from llama_stack.providers.datatypes import Api
from .config import SQLiteVectorIOConfig
async def get_provider_impl(config: SQLiteVectorIOConfig, deps: Dict[Api, ProviderSpec]):
async def get_provider_impl(config: SQLiteVectorIOConfig, deps: Dict[Api, Any]):
from .sqlite_vec import SQLiteVecVectorIOAdapter
assert isinstance(config, SQLiteVectorIOConfig), f"Unexpected config type: {type(config)}"

@ -24,10 +24,6 @@ MODEL_ENTRIES = [
"accounts/fireworks/models/llama-v3p1-405b-instruct",
CoreModelId.llama3_1_405b_instruct.value,
),
build_hf_repo_model_entry(
"accounts/fireworks/models/llama-v3p2-1b-instruct",
CoreModelId.llama3_2_1b_instruct.value,
),
build_hf_repo_model_entry(
"accounts/fireworks/models/llama-v3p2-3b-instruct",
CoreModelId.llama3_2_3b_instruct.value,

@ -128,16 +128,6 @@ models:
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p1-405b-instruct
model_type: llm
- metadata: {}
model_id: accounts/fireworks/models/llama-v3p2-1b-instruct
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p2-1b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.2-1B-Instruct
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p2-1b-instruct
model_type: llm
- metadata: {}
model_id: accounts/fireworks/models/llama-v3p2-3b-instruct
provider_id: fireworks

@ -186,16 +186,6 @@ models:
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p1-405b-instruct
model_type: llm
- metadata: {}
model_id: accounts/fireworks/models/llama-v3p2-1b-instruct
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p2-1b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.2-1B-Instruct
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p2-1b-instruct
model_type: llm
- metadata: {}
model_id: accounts/fireworks/models/llama-v3p2-3b-instruct
provider_id: fireworks

@ -140,16 +140,6 @@ models:
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p1-405b-instruct
model_type: llm
- metadata: {}
model_id: accounts/fireworks/models/llama-v3p2-1b-instruct
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p2-1b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.2-1B-Instruct
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p2-1b-instruct
model_type: llm
- metadata: {}
model_id: accounts/fireworks/models/llama-v3p2-3b-instruct
provider_id: fireworks

@ -134,16 +134,6 @@ models:
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p1-405b-instruct
model_type: llm
- metadata: {}
model_id: accounts/fireworks/models/llama-v3p2-1b-instruct
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p2-1b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.2-1B-Instruct
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p2-1b-instruct
model_type: llm
- metadata: {}
model_id: accounts/fireworks/models/llama-v3p2-3b-instruct
provider_id: fireworks